diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 3b5faa1381..5231f8f6c8 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -8,6 +8,7 @@ from spikeinterface.core.baserecording import BaseRecording from .filter import fix_dtype +from functools import cache class CommonReferenceRecording(BasePreprocessor): @@ -83,13 +84,14 @@ def __init__( groups: list | None = None, ref_channel_ids: list | str | int | None = None, local_radius: tuple[float, float] = (30.0, 55.0), + min_local_neighbors: int = 5, dtype: str | np.dtype | None = None, ): num_chans = recording.get_num_channels() neighbors = None # some checks if reference not in ("global", "single", "local"): - raise ValueError("'reference' must be either 'global', 'single' or 'local'") + raise ValueError("'reference' must be either 'global', 'single', 'local'") if operator not in ("median", "average"): raise ValueError("'operator' must be either 'median', 'average'") @@ -117,10 +119,10 @@ def __init__( closest_inds, dist = get_closest_channels(recording) neighbors = {} for i in range(num_chans): - mask = (dist[i, :] > local_radius[0]) & (dist[i, :] <= local_radius[1]) + mask = dist[i, :] > local_radius[0] + nn = np.cumsum(mask) + mask &= (dist[i, :] <= local_radius[1]) | ((0 < nn) & (nn <= min_local_neighbors)) neighbors[i] = closest_inds[i, mask] - assert len(neighbors[i]) > 0, "No reference channels available in the local annulus for selection." - dtype_ = fix_dtype(recording, dtype) BasePreprocessor.__init__(self, recording, dtype=dtype_) @@ -136,7 +138,14 @@ def __init__( for parent_segment in recording._recording_segments: rec_segment = CommonReferenceRecordingSegment( - parent_segment, reference, operator, group_indices, ref_channel_indices, local_radius, neighbors, dtype_ + parent_segment, + reference, + operator, + group_indices, + ref_channel_indices, + local_radius, + neighbors, + dtype_, ) self.add_recording_segment(rec_segment) @@ -147,6 +156,7 @@ def __init__( operator=operator, ref_channel_ids=ref_channel_ids, local_radius=local_radius, + min_local_neighbors=min_local_neighbors, dtype=dtype_.str, )