Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 28 additions & 29 deletions doc/how_to/auto_label_units.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,18 @@ curation:


1. Quality-metrics based curation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
----------------------------------

A simple solution is to use a filter based on quality metrics. To do so,
we can use the ``spikeinterface.curation.qualitymetrics_label_units``
we can use the ``spikeinterface.curation.threshold_metrics_label_units``
function and provide a set of thresholds.

.. code:: ipython3

qm_thresholds = {
"snr": {"min": 5},
"firing_rate": {"min": 0.1, "max": 200},
"rp_contamination": {"max": 0.5}
"snr": {"greater": 5},
"firing_rate": {"greater": 0.1, "less": 200},
"rp_contamination": {"less": 0.5}
}

.. code:: ipython3
Expand Down Expand Up @@ -143,7 +143,7 @@ across all units:
.. image:: auto_label_units_files/auto_label_units_14_0.png


1. Bombcell
2. Bombcell
-----------

**Bombcell** ([Fabre]_) is another threshold-based method that also uses
Expand All @@ -161,24 +161,24 @@ file.

.. parsed-literal::

{'mua': {'amplitude_cutoff': {'max': 0.2, 'min': None},
'amplitude_median': {'max': None, 'min': 40},
'drift_ptp': {'max': 100, 'min': None},
'num_spikes': {'max': None, 'min': 300},
'presence_ratio': {'max': None, 'min': 0.7},
'rp_contamination': {'max': 0.1, 'min': None},
'snr': {'max': None, 'min': 5}},
'noise': {'exp_decay': {'max': 0.1, 'min': 0.01},
'num_negative_peaks': {'max': 1, 'min': None},
'num_positive_peaks': {'max': 2, 'min': None},
'peak_after_to_trough_ratio': {'max': 0.8, 'min': None},
'peak_to_trough_duration': {'max': 0.00115, 'min': 0.0001},
'waveform_baseline_flatness': {'max': 0.5, 'min': None}},
'non-somatic': {'main_peak_to_trough_ratio': {'max': 0.8, 'min': None},
'peak_before_to_peak_after_ratio': {'max': 3, 'min': None},
'peak_before_to_trough_ratio': {'max': 3, 'min': None},
'peak_before_width': {'max': None, 'min': 0.00015},
'trough_width': {'max': None, 'min': 0.0002}}}
{'mua': {'amplitude_cutoff': {'greater': None, 'less': 0.2},
'amplitude_median': {'abs': True, 'greater': 30, 'less': None},
'drift_ptp': {'greater': None, 'less': 100},
'num_spikes': {'greater': 300, 'less': None},
'presence_ratio': {'greater': 0.7, 'less': None},
'rp_contamination': {'greater': None, 'less': 0.1},
'snr': {'greater': 5, 'less': None}},
'noise': {'exp_decay': {'greater': 0.01, 'less': 0.1},
'num_negative_peaks': {'greater': None, 'less': 1},
'num_positive_peaks': {'greater': None, 'less': 2},
'peak_after_to_trough_ratio': {'greater': None, 'less': 0.8},
'peak_to_trough_duration': {'greater': 0.0001, 'less': 0.00115},
'waveform_baseline_flatness': {'greater': None, 'less': 0.5}},
'non-somatic': {'main_peak_to_trough_ratio': {'greater': None, 'less': 0.8},
'peak_before_to_peak_after_ratio': {'greater': None, 'less': 3},
'peak_before_to_trough_ratio': {'greater': None, 'less': 3},
'peak_before_width': {'greater': 0.00015, 'less': None},
'trough_width': {'greater': 0.0002, 'less': None}}}


.. code:: ipython3
Expand Down Expand Up @@ -248,8 +248,8 @@ contamination (``rp_contamination``).
.. image:: auto_label_units_files/auto_label_units_23_1.png


UnitRefine
----------
3. UnitRefine
-------------

**UnitRefine** ([Jain]_) also uses quality and template metrics, but in
a different way. It uses pre-trained classifiers to trained on
Expand Down Expand Up @@ -305,12 +305,11 @@ sorting with different strategies. We recommend running **Bombcell** and
**UnitRefine** as part of your pipeline. These methods will facilitate
further curation and make downstream analysis cleaner.

To remove units from your ``SortingAnalyzer``, you can simply use the
``select_units`` function:

Remove units from ``SortingAnalyzer``
-------------------------------------

To remove units from your ``SortingAnalyzer``, you can use the ``select_units`` function.

After auto-labeling, we can easily remove the “noise” units for
downstream analysis:

Expand Down
4 changes: 2 additions & 2 deletions doc/modules/curation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ which applies a set of thresholds based on the available metrics (template/quali
labels = threshold_metrics_label_units(
sorting_analyzer=sorting_analyzer,
thresholds={
"snr": {"min": 5},
"rp_contamination": {"max": 0.2},
"snr": {"greater": 5},
"rp_contamination": {"less": 0.2},
},
pass_label="good",
fail_label="bad",
Expand Down
6 changes: 3 additions & 3 deletions examples/how_to/auto_label_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@

# %%
qm_thresholds = {
"snr": {"min": 5},
"firing_rate": {"min": 0.1, "max": 200},
"rp_contamination": {"max": 0.5}
"snr": {"greater": 5},
"firing_rate": {"greater": 0.1, "less": 200},
"rp_contamination": {"less": 0.5}
}

# %%
Expand Down
46 changes: 23 additions & 23 deletions src/spikeinterface/curation/bombcell_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,34 +50,34 @@ def bombcell_get_default_thresholds() -> dict:
"""
bombcell - Returns default thresholds for unit labeling.

Each metric has 'min' and 'max' values. Use None to disable a threshold (e.g. to ignore a metric completely
or to only have a min or a max threshold)
Each metric has 'greater' and 'less' values. Use None to disable a threshold (e.g. to ignore a metric completely
or to only have a greater or a less threshold)
"""
# bombcell
return {
"noise": { # failures -> NOISE
"num_positive_peaks": {"min": None, "max": 2},
"num_negative_peaks": {"min": None, "max": 1},
"peak_to_trough_duration": {"min": 0.0001, "max": 0.00115}, # seconds
"waveform_baseline_flatness": {"min": None, "max": 0.5},
"peak_after_to_trough_ratio": {"min": None, "max": 0.8},
"exp_decay": {"min": 0.01, "max": 0.1},
"num_positive_peaks": {"greater": None, "less": 2},
"num_negative_peaks": {"greater": None, "less": 1},
"peak_to_trough_duration": {"greater": 0.0001, "less": 0.00115}, # seconds
"waveform_baseline_flatness": {"greater": None, "less": 0.5},
"peak_after_to_trough_ratio": {"greater": None, "less": 0.8},
"exp_decay": {"greater": 0.01, "less": 0.1},
},
"mua": { # failures -> MUA, only applied to units that passed noise thresholds
"amplitude_median": {"min": 30, "max": None, "abs": True}, # uV
"snr": {"min": 5, "max": None},
"amplitude_cutoff": {"min": None, "max": 0.2},
"num_spikes": {"min": 300, "max": None},
"rp_contamination": {"min": None, "max": 0.1},
"presence_ratio": {"min": 0.7, "max": None},
"drift_ptp": {"min": None, "max": 100}, # um
"amplitude_median": {"greater": 30, "less": None, "abs": True}, # uV
"snr": {"greater": 5, "less": None},
"amplitude_cutoff": {"greater": None, "less": 0.2},
"num_spikes": {"greater": 300, "less": None},
"rp_contamination": {"greater": None, "less": 0.1},
"presence_ratio": {"greater": 0.7, "less": None},
"drift_ptp": {"greater": None, "less": 100}, # um
},
"non-somatic": {
"peak_before_to_trough_ratio": {"min": None, "max": 3},
"peak_before_width": {"min": 0.00015, "max": None}, # seconds
"trough_width": {"min": 0.0002, "max": None}, # seconds
"peak_before_to_peak_after_ratio": {"min": None, "max": 3},
"main_peak_to_trough_ratio": {"min": None, "max": 0.8},
"peak_before_to_trough_ratio": {"greater": None, "less": 3},
"peak_before_width": {"greater": 0.00015, "less": None}, # seconds
"trough_width": {"greater": 0.0002, "less": None}, # seconds
"peak_before_to_peak_after_ratio": {"greater": None, "less": 3},
"main_peak_to_trough_ratio": {"greater": None, "less": 0.8},
},
}

Expand Down Expand Up @@ -123,7 +123,7 @@ def bombcell_label_units(
If provided, metrics are extracted automatically using get_metrics_extension_data().
thresholds : dict | str | Path | None
Threshold dict or JSON file, including a three sections ("noise", "mua", "non-somatic") of
{"metric": {"min": val, "max": val}}.
{"metric": {"greater": val, "less": val}}.
If None, default Bombcell thresholds are used.
label_non_somatic : bool, default: True
If True, detect non-somatic (dendritic, axonal) units.
Expand Down Expand Up @@ -336,8 +336,8 @@ def save_bombcell_results(
continue
value = metrics.loc[unit_id, metric_name]
thresh = flat_thresholds[metric_name]
thresh_min = thresh.get("min", None)
thresh_max = thresh.get("max", None)
thresh_min = thresh.get("greater", None)
thresh_max = thresh.get("less", None)

# Determine pass/fail
passed = True
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/curation/tests/test_bombcell_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def test_bombcell_label_units_with_threshold_file(sorting_analyzer_with_metrics,

# Define custom thresholds
custom_thresholds = {
"snr": {"min": 5, "max": 100},
"isi_violations": {"min": None, "max": 0.2},
"snr": {"greater": 5, "less": 100},
"isi_violations": {"greater": None, "less": 0.2},
}

# Save thresholds to a temporary JSON file
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def test_threshold_metrics_label_units_with_dataframe():
index=[0, 1, 2],
)
thresholds = {
"snr": {"min": 5.0},
"firing_rate": {"min": 0.1, "max": 20.0},
"snr": {"greater": 5.0},
"firing_rate": {"greater": 0.1, "less": 20.0},
}

labels = threshold_metrics_label_units(metrics, thresholds)
Expand All @@ -39,8 +39,8 @@ def test_threshold_metrics_label_units_with_file(tmp_path):
index=[0, 1],
)
thresholds = {
"snr": {"min": 5.0},
"firing_rate": {"min": 0.1},
"snr": {"greater": 5.0},
"firing_rate": {"greater": 0.1},
}

thresholds_file = tmp_path / "thresholds.json"
Expand All @@ -63,8 +63,8 @@ def test_threshold_metrics_label_external_labels():
index=[0, 1],
)
thresholds = {
"snr": {"min": 5.0},
"firing_rate": {"min": 0.1},
"snr": {"greater": 5.0},
"firing_rate": {"greater": 0.1},
}

labels = threshold_metrics_label_units(
Expand All @@ -86,7 +86,7 @@ def test_threshold_metrics_label_units_operator_or_with_dataframe():
},
index=[0, 1, 2, 3],
)
thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}}
thresholds = {"m1": {"greater": 0.0}, "m2": {"greater": 0.0}}

labels_and = threshold_metrics_label_units(
metrics,
Expand Down Expand Up @@ -115,7 +115,7 @@ def test_threshold_metrics_label_units_nan_policy_fail_vs_ignore_and():
},
index=[10, 11, 12],
)
thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}}
thresholds = {"m1": {"greater": 0.0}, "m2": {"greater": 0.0}}

labels_fail = threshold_metrics_label_units(
metrics,
Expand Down Expand Up @@ -147,7 +147,7 @@ def test_threshold_metrics_label_units_nan_policy_ignore_with_or():
},
index=[20, 21],
)
thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}}
thresholds = {"m1": {"greater": 0.0}, "m2": {"greater": 0.0}}

labels_ignore_or = threshold_metrics_label_units(
metrics,
Expand All @@ -170,7 +170,7 @@ def test_threshold_metrics_label_units_nan_policy_pass_and_or():
},
index=[30, 31, 32, 33],
)
thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}}
thresholds = {"m1": {"greater": 0.0}, "m2": {"greater": 0.0}}

labels_and = threshold_metrics_label_units(
metrics,
Expand Down Expand Up @@ -198,7 +198,7 @@ def test_threshold_metrics_label_units_invalid_operator_raises():
import pandas as pd

metrics = pd.DataFrame({"m1": [1.0]}, index=[0])
thresholds = {"m1": {"min": 0.0}}
thresholds = {"m1": {"greater": 0.0}}
with pytest.raises(ValueError, match="operator must be 'and' or 'or'"):
threshold_metrics_label_units(metrics, thresholds, operator="xor")

Expand All @@ -207,7 +207,7 @@ def test_threshold_metrics_label_units_invalid_nan_policy_raises():
import pandas as pd

metrics = pd.DataFrame({"m1": [1.0]}, index=[0])
thresholds = {"m1": {"min": 0.0}}
thresholds = {"m1": {"greater": 0.0}}
with pytest.raises(ValueError, match="nan_policy must be"):
threshold_metrics_label_units(metrics, thresholds, nan_policy="omit")

Expand All @@ -216,6 +216,15 @@ def test_threshold_metrics_label_units_missing_metric_raises():
import pandas as pd

metrics = pd.DataFrame({"m1": [1.0]}, index=[0])
thresholds = {"does_not_exist": {"min": 0.0}}
thresholds = {"does_not_exist": {"greater": 0.0}}
with pytest.raises(ValueError, match="specified in thresholds are not present"):
threshold_metrics_label_units(metrics, thresholds)


def test_threshold_metrics_label_units_invalid_threshold_keys_raises():
import pandas as pd

metrics = pd.DataFrame({"m1": [1.0]}, index=[0])
thresholds = {"m1": {"greater": 0.0, "invalid_key": 1.0}}
with pytest.raises(ValueError, match="contains invalid keys"):
threshold_metrics_label_units(metrics, thresholds)
17 changes: 13 additions & 4 deletions src/spikeinterface/curation/threshold_metrics_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ def threshold_metrics_label_units(
thresholds : dict | str | Path
A dictionary or JSON file path where keys are metric names and values are threshold values for labeling units.
Each key should correspond to a quality metric present in the analyzer's quality metrics DataFrame. Values
should contain at least "min" and/or "max" keys to specify threshold ranges. Optionally, an "abs": True entry
can be included to indicate that the metric should be treated as an absolute value when applying thresholds.
should contain at least "greater" and/or "less" keys to specify threshold ranges. Thresholds are inclusive, i.e.
"greater" is >= and "less" is <=. Optionally, an "abs": True entry can be included to indicate that the metric
should be treated as an absolute value when applying thresholds.
pass_label : str, default: "good"
The label to assign to units that pass all thresholds.
fail_label : str, default: "noise"
Expand Down Expand Up @@ -74,6 +75,14 @@ def threshold_metrics_label_units(
f"Available metrics are: {metrics.columns.tolist()}"
)

# Check that threshold dictionaries contain only valid keys
valid_keys = {"greater", "less", "abs"}
for metric_name, threshold in thresholds_dict.items():
if not set(threshold).issubset(valid_keys):
raise ValueError(
f"Threshold for metric '{metric_name}' contains invalid keys {set(threshold) - valid_keys}."
)

if operator not in ("and", "or"):
raise ValueError("operator must be 'and' or 'or'")

Expand All @@ -88,8 +97,8 @@ def threshold_metrics_label_units(
any_threshold_applied = False

for metric_name, threshold in thresholds_dict.items():
min_value = threshold.get("min", None)
max_value = threshold.get("max", None)
min_value = threshold.get("greater", None)
max_value = threshold.get("less", None)
abs_value = threshold.get("abs", False)

# If both disabled, ignore this metric
Expand Down
10 changes: 5 additions & 5 deletions src/spikeinterface/widgets/bombcell_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class BombcellUpsetPlotWidget(BaseWidget):
"non_soma", "non_soma_good", "non_soma_mua".
thresholds : dict, optional
Threshold dictionary with structure "noise", "mua", "non-somatic" as sections. Each section contains
metric names keys with "min" and "max" thresholds.
metric names keys with "greater" and "less" thresholds.
If None, uses default thresholds.
unit_labels_to_plot : list of str, optional
List of unit labels to include in the plot. If None, defaults to all labels in thresholds.
Expand Down Expand Up @@ -197,10 +197,10 @@ def _build_failure_table(self, metrics, thresholds):
values = np.abs(values)

failed = np.isnan(values)
if not is_threshold_disabled(thresh.get("min", None)):
failed |= values < thresh["min"]
if not is_threshold_disabled(thresh.get("max", None)):
failed |= values > thresh["max"]
if not is_threshold_disabled(thresh.get("greater", None)):
failed |= values < thresh["greater"]
if not is_threshold_disabled(thresh.get("less", None)):
failed |= values > thresh["less"]
failure_data[metric_name] = failed

return pd.DataFrame(failure_data, index=metrics.index)
Expand Down
Loading