Curation Tutorial

After spike sorting and computing quality metrics, you can automatically curate the spike sorting output using the quality metrics that you have calculated.

Import the modules and/or functions necessary from spikeinterface

import spikeinterface.core as si
import spikeinterface.extractors as se

from spikeinterface.postprocessing import compute_principal_components
from spikeinterface.qualitymetrics import compute_quality_metrics

Let’s download a simulated dataset from the repository https://gin.g-node.org/NeuralEnsemble/ephy_testing_data

Let’s imagine that the ground-truth sorting is in fact the output of a sorter.

local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5")
recording, sorting = se.read_mearec(file_path=local_path)
print(recording)
print(sorting)
MEArecRecordingExtractor: 32 channels - 32.0kHz - 1 segments - 320,000 samples - 10.00s
                          float32 dtype - 39.06 MiB
  file_path: /home/docs/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5
MEArecSortingExtractor: 10 units - 1 segments - 32.0kHz
  file_path: /home/docs/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5

Create SortingAnalyzer

For this example, we first need to create a SortingAnalyzer and compute a few extensions on it.

analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording, format="memory")
analyzer.compute(["random_spikes", "waveforms", "templates", "noise_levels"])

analyzer.compute("principal_components", n_components=3, mode="by_channel_local")
print(analyzer)
/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface/checkouts/3072/src/spikeinterface/core/job_tools.py:103: UserWarning: `n_jobs` is not set so parallel processing is disabled! To speed up computations, it is recommended to set n_jobs either globally (with the `spikeinterface.set_global_job_kwargs()` function) or locally (with the `n_jobs` argument). Use `spikeinterface.set_global_job_kwargs?` for more information about job_kwargs.
  warnings.warn(

SortingAnalyzer: 32 channels - 10 units - 1 segments - memory - sparse - has recording
Loaded 5 extensions: random_spikes, waveforms, templates, noise_levels, principal_components
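
The warning above only affects speed: parallel processing stays disabled until n_jobs is set. As the warning suggests, job kwargs can be set once, globally, before the compute calls. A minimal sketch with illustrative values (not part of the original tutorial):

import spikeinterface.core as si

# Illustrative values only: choose n_jobs to match the cores available on your machine
si.set_global_job_kwargs(n_jobs=4, progress_bar=True)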

Then we compute some quality metrics:

metrics = compute_quality_metrics(analyzer, metric_names=["snr", "isi_violation", "nearest_neighbor"])
print(metrics)
          snr  isi_violations_ratio  ...  nn_hit_rate  nn_miss_rate
#0  23.892211                     0  ...          1.0      0.001289
#1  25.605818                     0  ...         0.99      0.000744
#2  13.791474                     0  ...     0.976744      0.005831
#3  21.804222                     0  ...          1.0           0.0
#4   7.468614                     0  ...     0.989583       0.00101
#5   7.458332                     0  ...     0.993243      0.002653
#6  20.871042                     0  ...     0.995098           0.0
#7   7.396984                     0  ...     0.986486      0.010753
#8   8.027916                     0  ...     0.989744      0.001506
#9   9.033094                     0  ...     0.996124      0.003348

[10 rows x 5 columns]
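
pandas truncates the printout above with "...". To see every metric column, plain pandas display options work (nothing spikeinterface-specific):

import pandas as pd

# Temporarily show all columns of the metrics DataFrame
with pd.option_context("display.max_columns", None):
    print(metrics)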

We can now threshold each quality metric and select units based on some rules.

The easiest and most intuitive way is to use boolean masking on the metrics DataFrame.

Then we create a list of the unit ids that we want to keep:

keep_mask = (metrics["snr"] > 7.5) & (metrics["isi_violations_ratio"] < 0.2) & (metrics["nn_hit_rate"] > 0.90)
print(keep_mask)

keep_unit_ids = keep_mask[keep_mask].index.values
keep_unit_ids = [unit_id for unit_id in keep_unit_ids]  # convert the index array to a plain Python list of unit ids
print(keep_unit_ids)
#0     True
#1     True
#2     True
#3     True
#4    False
#5    False
#6     True
#7    False
#8     True
#9     True
dtype: boolean
['#0', '#1', '#2', '#3', '#6', '#8', '#9']
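
Equivalently, the same selection can be written with a plain pandas query (an alternative sketch, not the tutorial's original approach):

# Same thresholds expressed as a pandas query string
keep_unit_ids = metrics.query("snr > 7.5 and isi_violations_ratio < 0.2 and nn_hit_rate > 0.9").index.tolist()
print(keep_unit_ids)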

And now let’s create a sorting that contains only curated units and save it.

curated_sorting = sorting.select_units(keep_unit_ids)
print(curated_sorting)


curated_sorting.save(folder="curated_sorting")
UnitsSelectionSorting: 7 units - 1 segments - 32.0kHz
NumpyFolder: 7 units - 1 segments - 32.0kHz
Unit IDs
    ['#0' '#1' '#2' '#3' '#6' '#8' '#9']
Annotations
    Unit Properties
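
The saved folder can be read back in a later session. A minimal sketch using si.load_extractor (the folder path matches the save call above):

# Reload the curated sorting saved in the "curated_sorting" folder
reloaded_sorting = si.load_extractor("curated_sorting")
print(reloaded_sorting)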


We can also save the analyzer with only these units.

clean_analyzer = analyzer.select_units(unit_ids=keep_unit_ids, format="zarr", folder="clean_analyzer")

print(clean_analyzer)

SortingAnalyzer: 32 channels - 7 units - 1 segments - zarr - sparse - has recording
Loaded 6 extensions: random_spikes, waveforms, templates, noise_levels, principal_components, quality_metrics
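
Later, the curated analyzer can be loaded back from disk. A minimal sketch, assuming the usual behaviour where spikeinterface appends a .zarr suffix to the folder name given above (check the path actually created on your system):

# Hypothetical reload in a later session; the saved folder is typically "clean_analyzer.zarr"
reloaded_analyzer = si.load_sorting_analyzer("clean_analyzer.zarr")
print(reloaded_analyzer)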