.. code:: ipython3

    %matplotlib inline
    %load_ext autoreload
    %autoreload 2

Handle motion/drift with spikeinterface
=======================================

Spikeinterface offers a very flexible framework to handle drift as a
preprocessing step. If you want to know more, please read the
``motion_correction`` section of the documentation.

Here is a short demo of how to handle drift using the high-level function
``spikeinterface.preprocessing.correct_motion()``.

This function takes a preprocessed recording as input, internally runs
several steps (this can be slow!), and returns a lazy recording that
interpolates the traces on the fly to compensate for the motion.

Internally this function runs the following steps:

::

    1. localize_peaks()
    2. select_peaks() (optional)
    3. estimate_motion()
    4. interpolate_motion()

All these sub-steps can be run with different methods and have many
parameters. The high-level function offers several predefined “presets”,
and we will explore them using a well-known public dataset recorded by
Nick Steinmetz: `Imposed motion datasets `__

This dataset contains 3 recordings, and each recording contains a
Neuropixels 1 and a Neuropixels 2 probe.

Here we will use *dataset1* with *neuropixel1*.

This dataset is the *“hello world”* for drift correction in the spike
sorting community!

.. code:: ipython3

    from pathlib import Path
    import shutil

    import matplotlib.pyplot as plt
    import numpy as np

    import spikeinterface.full as si
    from spikeinterface.preprocessing import get_motion_parameters_preset, get_motion_presets

.. code:: ipython3

    base_folder = Path("/mnt/data/sam/DataSpikeSorting/imposed_motion_nick")
    dataset_folder = base_folder / "dataset1/NP1"

.. code:: ipython3

    # read the file
    raw_rec = si.read_spikeglx(dataset_folder)
    raw_rec

.. parsed-literal::

    SpikeGLXRecordingExtractor: 384 channels - 30.0kHz - 1 segments - 58,715,724 samples - 1,957.19s (32.62 minutes) - int16 dtype - 42.00 GiB

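To build some intuition for what ``estimate_motion()`` (step 3 in the list
of internal steps) has to solve, here is a deliberately simplified rigid
estimator on synthetic peaks, in plain NumPy. It is only a sketch under strong
assumptions: one displacement per time bin, the first time bin used as the
reference, and candidate shifts scored by the overlap of depth histograms.
``estimate_rigid_drift`` is a hypothetical helper written for this
illustration; it is not the algorithm behind any spikeinterface preset (the
``kilosort_like`` preset, for example, uses the ``iterative_template`` method).

.. code:: ipython3

    import numpy as np


    def estimate_rigid_drift(peak_times_s, peak_depths_um, bin_s=2.0, bin_um=5.0, max_shift_um=50.0):
        """Toy rigid drift estimate: one displacement (in um) per time bin.

        Each time bin's histogram of peak depths is compared with the histogram
        of the first time bin at every candidate shift; the shift with the
        largest overlap (dot product) is taken as that bin's displacement.
        """
        # Time bins of bin_s seconds; depth bins of bin_um micrometers, padded by
        # max_shift_um (plus half a bin) so shifted peaks stay inside the range
        t_edges = np.arange(0.0, peak_times_s.max() + bin_s, bin_s)
        d_lo = peak_depths_um.min() - max_shift_um - bin_um / 2
        d_edges = np.arange(d_lo, peak_depths_um.max() + max_shift_um + bin_um, bin_um)
        hist, _, _ = np.histogram2d(peak_times_s, peak_depths_um, bins=(t_edges, d_edges))

        reference = hist[0]
        max_shift_bins = int(round(max_shift_um / bin_um))
        shifts = np.arange(-max_shift_bins, max_shift_bins + 1)

        displacement_um = np.zeros(hist.shape[0])
        for i, row in enumerate(hist):
            # Move this bin's histogram back by s depth bins and score the overlap
            scores = [np.dot(np.roll(row, -s), reference) for s in shifts]
            displacement_um[i] = shifts[int(np.argmax(scores))] * bin_um
        return t_edges, displacement_um


    # Synthetic data: 5 "units" at fixed depths, each with 2000 peaks over 100 s,
    # and an imposed rigid step of +20 um at t = 50 s
    rng = np.random.default_rng(0)
    unit_depths_um = np.array([100.0, 230.0, 310.0, 480.0, 620.0])
    times_s = rng.uniform(0.0, 100.0, size=(unit_depths_um.size, 2000))
    depths_um = unit_depths_um[:, None] + np.where(times_s >= 50.0, 20.0, 0.0)

    t_edges, disp_um = estimate_rigid_drift(times_s.ravel(), depths_um.ravel())
    bin_starts = t_edges[:-1]
    print(disp_um[bin_starts < 50.0].max(), disp_um[bin_starts >= 50.0].min())

Taking the first time bin as the reference keeps the sketch short but makes it
sensitive to noise in that single bin; the real methods use more robust
references, non-rigid windows (``win_step_um`` / ``win_scale_um`` in the
``kilosort_like`` parameters) and smoothing.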
We preprocess the recording with a bandpass filter and a common median
reference. Note that it is better not to whiten the recording before motion
estimation, to get a better estimate of peak locations!

.. code:: ipython3

    def preprocess_chain(rec):
        rec = rec.astype('float32')
        rec = si.bandpass_filter(rec, freq_min=300.0, freq_max=6000.0)
        rec = si.common_reference(rec, reference="global", operator="median")
        return rec

.. code:: ipython3

    rec = preprocess_chain(raw_rec)

.. code:: ipython3

    job_kwargs = dict(n_jobs=40, chunk_duration="1s", progress_bar=True)

preset and parameters
~~~~~~~~~~~~~~~~~~~~~

Motion correction has several steps, and every step is controlled by a method
and its related parameters. A preset is a nested dict that contains these
methods/parameters.

.. code:: ipython3

    preset_keys = get_motion_presets()
    preset_keys

.. parsed-literal::

    ['dredge', 'dredge_fast', 'nonrigid_accurate', 'nonrigid_fast_and_accurate', 'rigid_fast', 'kilosort_like']

.. code:: ipython3

    one_preset_params = get_motion_parameters_preset("kilosort_like")
    one_preset_params

.. parsed-literal::

    {'doc': 'Mimic the drift correction of kilosort (grid_convolution + iterative_template)',
     'detect_kwargs': {'peak_sign': 'neg',
      'detect_threshold': 8.0,
      'exclude_sweep_ms': 0.1,
      'radius_um': 50,
      'noise_levels': None,
      'random_chunk_kwargs': {},
      'method': 'locally_exclusive'},
     'select_kwargs': {},
     'localize_peaks_kwargs': {'radius_um': 40.0,
      'upsampling_um': 5.0,
      'sigma_ms': 0.25,
      'margin_um': 50.0,
      'prototype': None,
      'percentile': 5.0,
      'peak_sign': 'neg',
      'weight_method': {'mode': 'gaussian_2d',
       'sigma_list_um': array([ 5., 10., 15., 20., 25.])},
      'method': 'grid_convolution'},
     'estimate_motion_kwargs': {'direction': 'y',
      'rigid': False,
      'win_shape': 'rect',
      'win_step_um': 200.0,
      'win_scale_um': 400.0,
      'win_margin_um': None,
      'bin_um': 10.0,
      'hist_margin_um': 0,
      'bin_s': 2.0,
      'num_amp_bins': 20,
      'num_shifts_global': 15,
      'num_iterations': 10,
      'num_shifts_block': 5,
      'smoothing_sigma': 0.5,
      'kriging_sigma': 1,
      'kriging_p': 2,
      'kriging_d': 2,
      'method': 'iterative_template'},
     'interpolate_motion_kwargs': {'border_mode': 'force_extrapolate',
      'spatial_interpolation_method': 'kriging',
      'sigma_um': 20.0,
      'p': 2}}

Run motion correction with one function!
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Correcting for drift is easy! You just need to run a single function. We will
try this function with several presets.

Here we also save the motion correction results into a folder so that they can
be loaded later.

.. code:: ipython3

    # let's try these presets
    some_presets = ("rigid_fast", "kilosort_like", "nonrigid_accurate", "nonrigid_fast_and_accurate", "dredge", "dredge_fast")

.. code:: ipython3

    # compute motion with these presets
    for preset in some_presets:
        print("Computing with", preset)
        folder = base_folder / "motion_folder_dataset1" / preset
        if folder.exists():
            shutil.rmtree(folder)
        recording_corrected, motion, motion_info = si.correct_motion(
            rec, preset=preset, folder=folder, output_motion=True, output_motion_info=True, **job_kwargs
        )

.. parsed-literal::

    Computing with rigid_fast

.. parsed-literal::

    detect and localize: 0%| | 0/1958 [00:00

int(sr * time_lim0)) & (peaks["sample_index"] < int(sr * time_lim1)) sl = slice(None, None, 5) amps = np.abs(peaks["amplitude"][mask][sl]) amps /= np.quantile(amps, 0.95) c = plt.get_cmap("inferno")(amps) color_kargs = dict(alpha=0.2, s=2, c=c) peak_locations = motion_info["peak_locations"] # color='black', ax.scatter(peak_locations["x"][mask][sl], peak_locations["y"][mask][sl], **color_kargs) peak_locations2 = correct_motion_on_peaks(peaks, peak_locations, motion,rec) ax = axs[1] si.plot_probe_map(rec, ax=ax) # color='black', ax.scatter(peak_locations2["x"][mask][sl], peak_locations2["y"][mask][sl], **color_kargs) ax.set_ylim(400, 600) fig.suptitle(f"{preset=}")

.. image:: handle_drift_files/handle_drift_19_0.png

.. image:: handle_drift_files/handle_drift_19_1.png

.. image:: handle_drift_files/handle_drift_19_2.png

.. image:: handle_drift_files/handle_drift_19_3.png

.. image:: handle_drift_files/handle_drift_19_4.png

.. image:: handle_drift_files/handle_drift_19_5.png

run times
---------

Presets and their underlying methods differ not only in accuracy but also in
computation speed. It is good to keep this in mind!

.. code:: ipython3

    run_times = []
    for preset in some_presets:
        folder = base_folder / "motion_folder_dataset1" / preset
        motion_info = si.load_motion_info(folder)
        run_times.append(motion_info["run_times"])
    keys = run_times[0].keys()

    # stacked bar chart: one bar per preset, one segment per step
    bottom = np.zeros(len(run_times))
    fig, ax = plt.subplots(figsize=(14, 6))
    for k in keys:
        rtimes = np.array([rt[k] for rt in run_times])
        if np.any(rtimes > 0.0):
            ax.bar(some_presets, rtimes, bottom=bottom, label=k)
        bottom += rtimes
    ax.legend()

.. image:: handle_drift_files/handle_drift_21_1.png
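
The right-hand panel of each per-preset figure shows peak locations after
``correct_motion_on_peaks()``. Conceptually, correcting a peak means
subtracting the estimated displacement at that peak's time (and, for
non-rigid motion, at its depth) from its y position. The following plain
NumPy sketch illustrates that idea, assuming the motion is stored as a
displacement array of shape (time bins, depth bins) with bin centers in
seconds and micrometers, and that a positive displacement means a shift
toward larger depth. ``correct_peak_depths`` is a hypothetical helper, not
spikeinterface's implementation.

.. code:: ipython3

    import numpy as np


    def correct_peak_depths(peak_times_s, peak_depths_um, temporal_bins_s, spatial_bins_um, displacement_um):
        """Return peak depths with the estimated displacement subtracted.

        displacement_um has shape (n_temporal_bins, n_spatial_bins): one value per
        temporal bin center and depth bin center. A single column means rigid
        motion (same displacement at every depth).
        """
        peak_times_s = np.asarray(peak_times_s, dtype=float)
        peak_depths_um = np.asarray(peak_depths_um, dtype=float)
        # Index of the nearest temporal bin center for each peak
        t_idx = np.abs(peak_times_s[:, None] - temporal_bins_s[None, :]).argmin(axis=1)

        corrected = np.empty_like(peak_depths_um)
        for i, (ti, depth) in enumerate(zip(t_idx, peak_depths_um)):
            if displacement_um.shape[1] == 1:
                drift = displacement_um[ti, 0]
            else:
                # Non-rigid: linear interpolation across depth; np.interp holds the
                # edge values constant outside the range of spatial_bins_um
                drift = np.interp(depth, spatial_bins_um, displacement_um[ti])
            corrected[i] = depth - drift
        return corrected


    temporal_bins_s = np.array([1.0, 3.0, 5.0])
    spatial_bins_um = np.array([0.0, 100.0])
    displacement_um = np.array([[0.0, 0.0], [10.0, 20.0], [10.0, 30.0]])

    corrected = correct_peak_depths([0.5, 3.2, 4.9], [50.0, 50.0, 100.0],
                                    temporal_bins_s, spatial_bins_um, displacement_um)
    print(corrected)  # prints [50. 35. 70.]

The nearest-bin lookup via a broadcasted difference is fine for a few thousand
peaks but allocates an (n_peaks, n_bins) array; for a full recording with
regularly spaced bins, ``np.searchsorted`` would be the lighter choice.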