From 9487c57c04f20aecb640330d4b087b0e0ba931f0 Mon Sep 17 00:00:00 2001 From: ZZUOHAN Date: Thu, 26 Feb 2026 16:25:34 -0500 Subject: [PATCH 1/4] introduce backup reference and knn reference Use a backup reference for the local reference when there are too few local channels; knn can be more robust than local as it guarantees the number of reference channels --- .../preprocessing/common_reference.py | 83 +++++++++++++++---- 1 file changed, 67 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 3b5faa1381..07e7ff7794 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -8,6 +8,7 @@ from spikeinterface.core.baserecording import BaseRecording from .filter import fix_dtype +from functools import cache class CommonReferenceRecording(BasePreprocessor): @@ -78,26 +79,30 @@ class CommonReferenceRecording(BasePreprocessor): def __init__( self, recording: BaseRecording, - reference: Literal["global", "single", "local"] = "global", + reference: Literal["global", "single", "local", 'knn'] = "global", operator: Literal["median", "average"] = "median", groups: list | None = None, ref_channel_ids: list | str | int | None = None, local_radius: tuple[float, float] = (30.0, 55.0), + nneighbors: int | None = None, + backup_reference: Literal["global", "single", "knn"] = "global", + backup_thr: int = 1, dtype: str | np.dtype | None = None, ): num_chans = recording.get_num_channels() neighbors = None + knearest_neighbors = None # some checks - if reference not in ("global", "single", "local"): - raise ValueError("'reference' must be either 'global', 'single' or 'local'") + if reference not in ("global", "single", "local", "knn"): + raise ValueError("'reference' must be either 'global', 'single', 'local' or 'knn'") if operator not in ("median", "average"): raise ValueError("'operator' must be either 'median', 'average'") - if reference == 
"global": + if reference == "global" or backup_reference == "global": if ref_channel_ids is not None: if not isinstance(ref_channel_ids, list): raise ValueError("With 'global' reference, provide 'ref_channel_ids' as a list") - elif reference == "single": + if reference == "single" or reference == 'single': assert ref_channel_ids is not None, "With 'single' reference, provide 'ref_channel_ids'" if groups is not None: assert len(ref_channel_ids) == len(groups), "'ref_channel_ids' and 'groups' must have the same length" @@ -112,15 +117,19 @@ def __init__( assert np.all( [ch in recording.channel_ids for ch in ref_channel_ids] ), "Some 'ref_channel_ids' are wrong!" - elif reference == "local": + if reference == "local": assert groups is None, "With 'local' CAR, the group option should not be used." closest_inds, dist = get_closest_channels(recording) neighbors = {} for i in range(num_chans): mask = (dist[i, :] > local_radius[0]) & (dist[i, :] <= local_radius[1]) neighbors[i] = closest_inds[i, mask] - assert len(neighbors[i]) > 0, "No reference channels available in the local annulus for selection." - + # assert len(neighbors[i]) > 0, "No reference channels available in the local annulus for selection." + if reference == "knn" or backup_reference == 'knn': + assert groups is None, "With 'knn' CAR, the group option should not be used." 
+ assert nneighbors is not None, "With 'knn' reference, provide 'nneighbors'" + assert nneighbors > 0, "'nneighbors' must be positive" + knearest_neighbors, _ = get_closest_channels(recording, num_channels=min(nneighbors, num_chans)) dtype_ = fix_dtype(recording, dtype) BasePreprocessor.__init__(self, recording, dtype=dtype_) @@ -136,7 +145,7 @@ def __init__( for parent_segment in recording._recording_segments: rec_segment = CommonReferenceRecordingSegment( - parent_segment, reference, operator, group_indices, ref_channel_indices, local_radius, neighbors, dtype_ + parent_segment, reference, operator, group_indices, ref_channel_indices, local_radius, neighbors, knearest_neighbors, backup_reference, backup_thr, dtype_ ) self.add_recording_segment(rec_segment) @@ -147,6 +156,9 @@ def __init__( operator=operator, ref_channel_ids=ref_channel_ids, local_radius=local_radius, + nneighbors=nneighbors, + backup_reference=backup_reference, + backup_thr=backup_thr, dtype=dtype_.str, ) @@ -161,11 +173,17 @@ def __init__( ref_channel_indices, local_radius, neighbors, + knearest_neighbors, + backup_reference, + backup_thr, dtype, ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.reference = reference + self.knearest_neighbors = knearest_neighbors + self.backup_reference = backup_reference + self.backup_thr = backup_thr self.operator = operator self.group_indices = group_indices self.ref_channel_indices = ref_channel_indices @@ -181,23 +199,56 @@ def get_traces(self, start_frame, end_frame, channel_indices): # We need all the channels to calculate the reference traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) - if self.reference == "global": + @cache + def _global(keepdims=True): if self.ref_channel_indices is None: - shift = self.operator_func(traces, axis=1, keepdims=True) + shift = self.operator_func(traces, axis=1, keepdims=keepdims) else: - shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, 
keepdims=True) - re_referenced_traces = traces[:, channel_indices] - shift - elif self.reference == "single": + shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, keepdims=keepdims) + return shift + + @cache + def _single(): # single channel -> no need of operator shift = traces[:, self.ref_channel_indices] - re_referenced_traces = traces[:, channel_indices] - shift - else: # then it must be local + return shift + + def _local(): channel_indices_array = np.arange(traces.shape[1])[channel_indices] re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32") for i, channel_index in enumerate(channel_indices_array): channel_neighborhood = self.neighbors[channel_index] + if len(channel_neighborhood) < self.backup_thr: + if self.backup_reference == 'global': + channel_shift = _global(False) + elif self.backup_reference == 'single': + channel_shift = _single() + else: + channel_neighborhood = self.knearest_neighbors[channel_index] + channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1) + else: + channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1) + re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift + return re_referenced_traces + + + def _knn(): + channel_indices_array = np.arange(traces.shape[1])[channel_indices] + re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32") + for i, channel_index in enumerate(channel_indices_array): + channel_neighborhood = self.knearest_neighbors[channel_index] channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1) re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift + return re_referenced_traces + + if self.reference == 'global': + re_referenced_traces = traces[:, channel_indices] - _global() + elif self.reference == 'single': + re_referenced_traces = traces[:, channel_indices] - _single() + elif self.reference == 'knn': + re_referenced_traces 
= _knn() + else: + re_referenced_traces = _local() return re_referenced_traces.astype(self.dtype, copy=False) From 013d810be15099888c1644ff8d8ecd7fe2d80b31 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Feb 2026 21:44:17 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../preprocessing/common_reference.py | 29 ++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 07e7ff7794..d3f59ff799 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -79,7 +79,7 @@ class CommonReferenceRecording(BasePreprocessor): def __init__( self, recording: BaseRecording, - reference: Literal["global", "single", "local", 'knn'] = "global", + reference: Literal["global", "single", "local", "knn"] = "global", operator: Literal["median", "average"] = "median", groups: list | None = None, ref_channel_ids: list | str | int | None = None, @@ -102,7 +102,7 @@ def __init__( if ref_channel_ids is not None: if not isinstance(ref_channel_ids, list): raise ValueError("With 'global' reference, provide 'ref_channel_ids' as a list") - if reference == "single" or reference == 'single': + if reference == "single" or reference == "single": assert ref_channel_ids is not None, "With 'single' reference, provide 'ref_channel_ids'" if groups is not None: assert len(ref_channel_ids) == len(groups), "'ref_channel_ids' and 'groups' must have the same length" @@ -125,7 +125,7 @@ def __init__( mask = (dist[i, :] > local_radius[0]) & (dist[i, :] <= local_radius[1]) neighbors[i] = closest_inds[i, mask] # assert len(neighbors[i]) > 0, "No reference channels available in the local annulus for selection." 
- if reference == "knn" or backup_reference == 'knn': + if reference == "knn" or backup_reference == "knn": assert groups is None, "With 'knn' CAR, the group option should not be used." assert nneighbors is not None, "With 'knn' reference, provide 'nneighbors'" assert nneighbors > 0, "'nneighbors' must be positive" @@ -145,7 +145,17 @@ def __init__( for parent_segment in recording._recording_segments: rec_segment = CommonReferenceRecordingSegment( - parent_segment, reference, operator, group_indices, ref_channel_indices, local_radius, neighbors, knearest_neighbors, backup_reference, backup_thr, dtype_ + parent_segment, + reference, + operator, + group_indices, + ref_channel_indices, + local_radius, + neighbors, + knearest_neighbors, + backup_reference, + backup_thr, + dtype_, ) self.add_recording_segment(rec_segment) @@ -219,9 +229,9 @@ def _local(): for i, channel_index in enumerate(channel_indices_array): channel_neighborhood = self.neighbors[channel_index] if len(channel_neighborhood) < self.backup_thr: - if self.backup_reference == 'global': + if self.backup_reference == "global": channel_shift = _global(False) - elif self.backup_reference == 'single': + elif self.backup_reference == "single": channel_shift = _single() else: channel_neighborhood = self.knearest_neighbors[channel_index] @@ -231,7 +241,6 @@ def _local(): re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift return re_referenced_traces - def _knn(): channel_indices_array = np.arange(traces.shape[1])[channel_indices] re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32") @@ -241,11 +250,11 @@ def _knn(): re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift return re_referenced_traces - if self.reference == 'global': + if self.reference == "global": re_referenced_traces = traces[:, channel_indices] - _global() - elif self.reference == 'single': + elif self.reference == "single": re_referenced_traces = traces[:, 
channel_indices] - _single() - elif self.reference == 'knn': + elif self.reference == "knn": re_referenced_traces = _knn() else: re_referenced_traces = _local() From ff1811177bfc95be6de27426116107aaf095b4cd Mon Sep 17 00:00:00 2001 From: ZZUOHAN Date: Fri, 27 Feb 2026 11:49:28 -0500 Subject: [PATCH 3/4] Remove KNN as a reference method & limit backup behavior to KNN only --- .../preprocessing/common_reference.py | 86 ++++--------------- 1 file changed, 18 insertions(+), 68 deletions(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index d3f59ff799..352afba263 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -79,30 +79,27 @@ class CommonReferenceRecording(BasePreprocessor): def __init__( self, recording: BaseRecording, - reference: Literal["global", "single", "local", "knn"] = "global", + reference: Literal["global", "single", "local"] = "global", operator: Literal["median", "average"] = "median", groups: list | None = None, ref_channel_ids: list | str | int | None = None, local_radius: tuple[float, float] = (30.0, 55.0), - nneighbors: int | None = None, - backup_reference: Literal["global", "single", "knn"] = "global", - backup_thr: int = 1, + min_local_neighbors: int = 5, dtype: str | np.dtype | None = None, ): num_chans = recording.get_num_channels() neighbors = None - knearest_neighbors = None # some checks - if reference not in ("global", "single", "local", "knn"): - raise ValueError("'reference' must be either 'global', 'single', 'local' or 'knn'") + if reference not in ("global", "single", "local"): + raise ValueError("'reference' must be either 'global', 'single', 'local'") if operator not in ("median", "average"): raise ValueError("'operator' must be either 'median', 'average'") - if reference == "global" or backup_reference == "global": + if reference == "global": if ref_channel_ids is not None: if not 
isinstance(ref_channel_ids, list): raise ValueError("With 'global' reference, provide 'ref_channel_ids' as a list") - if reference == "single" or reference == "single": + elif reference == "single": assert ref_channel_ids is not None, "With 'single' reference, provide 'ref_channel_ids'" if groups is not None: assert len(ref_channel_ids) == len(groups), "'ref_channel_ids' and 'groups' must have the same length" @@ -117,19 +114,15 @@ def __init__( assert np.all( [ch in recording.channel_ids for ch in ref_channel_ids] ), "Some 'ref_channel_ids' are wrong!" - if reference == "local": + elif reference == "local": assert groups is None, "With 'local' CAR, the group option should not be used." closest_inds, dist = get_closest_channels(recording) neighbors = {} for i in range(num_chans): - mask = (dist[i, :] > local_radius[0]) & (dist[i, :] <= local_radius[1]) + mask = (dist[i, :] > local_radius[0]) + nn = np.cumsum(mask) + mask &= (dist[i, :] <= local_radius[1]) | ((0 < nn) & (nn <= min_local_neighbors)) neighbors[i] = closest_inds[i, mask] - # assert len(neighbors[i]) > 0, "No reference channels available in the local annulus for selection." - if reference == "knn" or backup_reference == "knn": - assert groups is None, "With 'knn' CAR, the group option should not be used." 
- assert nneighbors is not None, "With 'knn' reference, provide 'nneighbors'" - assert nneighbors > 0, "'nneighbors' must be positive" - knearest_neighbors, _ = get_closest_channels(recording, num_channels=min(nneighbors, num_chans)) dtype_ = fix_dtype(recording, dtype) BasePreprocessor.__init__(self, recording, dtype=dtype_) @@ -152,9 +145,6 @@ def __init__( ref_channel_indices, local_radius, neighbors, - knearest_neighbors, - backup_reference, - backup_thr, dtype_, ) self.add_recording_segment(rec_segment) @@ -166,9 +156,7 @@ def __init__( operator=operator, ref_channel_ids=ref_channel_ids, local_radius=local_radius, - nneighbors=nneighbors, - backup_reference=backup_reference, - backup_thr=backup_thr, + min_local_neighbors=min_local_neighbors, dtype=dtype_.str, ) @@ -183,17 +171,11 @@ def __init__( ref_channel_indices, local_radius, neighbors, - knearest_neighbors, - backup_reference, - backup_thr, dtype, ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.reference = reference - self.knearest_neighbors = knearest_neighbors - self.backup_reference = backup_reference - self.backup_thr = backup_thr self.operator = operator self.group_indices = group_indices self.ref_channel_indices = ref_channel_indices @@ -209,55 +191,23 @@ def get_traces(self, start_frame, end_frame, channel_indices): # We need all the channels to calculate the reference traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) - @cache - def _global(keepdims=True): + if self.reference == "global": if self.ref_channel_indices is None: - shift = self.operator_func(traces, axis=1, keepdims=keepdims) + shift = self.operator_func(traces, axis=1, keepdims=True) else: - shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, keepdims=keepdims) - return shift - - @cache - def _single(): + shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, keepdims=True) + re_referenced_traces = traces[:, channel_indices] - shift + 
elif self.reference == "single": # single channel -> no need of operator shift = traces[:, self.ref_channel_indices] - return shift - - def _local(): + re_referenced_traces = traces[:, channel_indices] - shift + else: # then it must be local channel_indices_array = np.arange(traces.shape[1])[channel_indices] re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32") for i, channel_index in enumerate(channel_indices_array): channel_neighborhood = self.neighbors[channel_index] - if len(channel_neighborhood) < self.backup_thr: - if self.backup_reference == "global": - channel_shift = _global(False) - elif self.backup_reference == "single": - channel_shift = _single() - else: - channel_neighborhood = self.knearest_neighbors[channel_index] - channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1) - else: - channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1) - re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift - return re_referenced_traces - - def _knn(): - channel_indices_array = np.arange(traces.shape[1])[channel_indices] - re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32") - for i, channel_index in enumerate(channel_indices_array): - channel_neighborhood = self.knearest_neighbors[channel_index] channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1) re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift - return re_referenced_traces - - if self.reference == "global": - re_referenced_traces = traces[:, channel_indices] - _global() - elif self.reference == "single": - re_referenced_traces = traces[:, channel_indices] - _single() - elif self.reference == "knn": - re_referenced_traces = _knn() - else: - re_referenced_traces = _local() return re_referenced_traces.astype(self.dtype, copy=False) From 1972164caa4468db9fe142fbd0729c1c9f269f32 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Feb 2026 16:50:02 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/common_reference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 352afba263..5231f8f6c8 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -119,7 +119,7 @@ def __init__( closest_inds, dist = get_closest_channels(recording) neighbors = {} for i in range(num_chans): - mask = (dist[i, :] > local_radius[0]) + mask = dist[i, :] > local_radius[0] nn = np.cumsum(mask) mask &= (dist[i, :] <= local_radius[1]) | ((0 < nn) & (nn <= min_local_neighbors)) neighbors[i] = closest_inds[i, mask]