From a6e226c74fca382a8d48d3dbdd970549891616b1 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Tue, 10 Feb 2026 18:49:28 +0100 Subject: [PATCH 1/7] wip new lazy to_multiscale() for labels --- src/spatialdata/_io/io_raster.py | 19 ++---- src/spatialdata/datasets.py | 63 ++++++++++++------ src/spatialdata/models/models.py | 33 +++++----- src/spatialdata/models/pyramids_utils.py | 84 ++++++++++++++++++++++++ tests/models/test_pyramids_utils.py | 52 +++++++++++++++ 5 files changed, 199 insertions(+), 52 deletions(-) create mode 100644 src/spatialdata/models/pyramids_utils.py create mode 100644 tests/models/test_pyramids_utils.py diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index bc8206db..86b537d0 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -13,7 +13,7 @@ from ome_zarr.writer import write_labels as write_labels_ngff from ome_zarr.writer import write_multiscale as write_multiscale_ngff from ome_zarr.writer import write_multiscale_labels as write_multiscale_labels_ngff -from xarray import DataArray, Dataset, DataTree +from xarray import DataArray, DataTree from spatialdata._io._utils import ( _get_transformations_from_ngff_dict, @@ -27,6 +27,7 @@ from spatialdata._utils import get_pyramid_levels from spatialdata.models._utils import get_channel_names from spatialdata.models.models import ATTRS_KEY +from spatialdata.models.pyramids_utils import dask_arrays_to_datatree from spatialdata.transformations._utils import ( _get_transformations, _get_transformations_xarray, @@ -91,20 +92,8 @@ def _read_multiscale( channels = [d["label"] for d in omero_metadata["channels"]] axes = [i["name"] for i in node.metadata["axes"]] if len(datasets) > 1: - multiscale_image = {} - for i, d in enumerate(datasets): - data = node.load(Multiscales).array(resolution=d) - multiscale_image[f"scale{i}"] = Dataset( - { - "image": DataArray( - data, - name="image", - dims=axes, - coords={"c": channels} if channels is not None else {}, - ) - } - ) - msi = DataTree.from_dict(multiscale_image) + arrays = [node.load(Multiscales).array(resolution=d) for d in datasets] + msi = dask_arrays_to_datatree(arrays, dims=axes, channels=channels) _set_transformations(msi, transformations) return compute_coordinates(msi) diff --git a/src/spatialdata/datasets.py b/src/spatialdata/datasets.py index ea38d739..4df422d8 100644 --- a/src/spatialdata/datasets.py +++ b/src/spatialdata/datasets.py @@ -20,7 +20,15 @@ from spatialdata._core.query.relational_query import get_element_instances from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike -from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, ShapesModel, TableModel +from spatialdata.models import ( + Image2DModel, + Image3DModel, + Labels2DModel, + Labels3DModel, + PointsModel, + ShapesModel, + TableModel, +) from spatialdata.transformations import Identity __all__ = ["blobs", "raccoon"] @@ -172,37 +180,47 @@ def _image_blobs( n_channels: int = 3, c_coords: str | list[str] | None = None, multiscale: bool = False, + ndim: int = 2, ) -> DataArray | DataTree: masks = [] for i in range(n_channels): - mask = self._generate_blobs(length=length, seed=i) + mask = self._generate_blobs(length=length, seed=i, ndim=ndim) mask = (mask - mask.min()) / np.ptp(mask) masks.append(mask) x = np.stack(masks, axis=0) - dims = ["c", "y", "x"] + if ndim == 2: + dims = ["c", "y", "x"] + model = Image2DModel + else: + dims = ["c", "z", "y", "x"] + model = Image3DModel if not 
multiscale: - return Image2DModel.parse(x, transformations=transformations, dims=dims, c_coords=c_coords) - return Image2DModel.parse( - x, transformations=transformations, dims=dims, c_coords=c_coords, scale_factors=[2, 2] - ) + return model.parse(x, transformations=transformations, dims=dims, c_coords=c_coords) + return model.parse(x, transformations=transformations, dims=dims, c_coords=c_coords, scale_factors=[2, 2]) def _labels_blobs( - self, transformations: dict[str, Any] | None = None, length: int = 512, multiscale: bool = False + self, + transformations: dict[str, Any] | None = None, + length: int = 512, + multiscale: bool = False, + ndim: int = 2, ) -> DataArray | DataTree: - """Create a 2D labels.""" + """Create labels in 2D or 3D.""" from scipy.ndimage import watershed_ift # from skimage - mask = self._generate_blobs(length=length) + mask = self._generate_blobs(length=length, ndim=ndim) threshold = np.percentile(mask, 100 * (1 - 0.3)) inputs = np.logical_not(mask < threshold).astype(np.uint8) # use watershed from scipy - xm, ym = np.ogrid[0:length:10, 0:length:10] + grid = np.ogrid[tuple(slice(0, length, 10) for _ in range(ndim))] markers = np.zeros_like(inputs).astype(np.int16) - markers[xm, ym] = np.arange(xm.size * ym.size).reshape((xm.size, ym.size)) + grid_shape = tuple(g.size for g in grid) + markers[tuple(grid)] = np.arange(np.prod(grid_shape)).reshape(grid_shape) out = watershed_ift(inputs, markers) - out[xm, ym] = out[xm - 1, ym - 1] # remove the isolate seeds + shifted = tuple(g - 1 for g in grid) + out[tuple(grid)] = out[tuple(shifted)] # remove the isolated seeds # reindex by frequency val, counts = np.unique(out, return_counts=True) sorted_idx = np.argsort(counts) @@ -211,20 +229,25 @@ def _labels_blobs( out[out == val[idx]] = 0 else: out[out == val[idx]] = i - dims = ["y", "x"] + if ndim == 2: + dims = ["y", "x"] + model = Labels2DModel + else: + dims = ["z", "y", "x"] + model = Labels3DModel if not multiscale: - return Labels2DModel.parse(out, transformations=transformations, dims=dims) - return Labels2DModel.parse(out, transformations=transformations, dims=dims, scale_factors=[2, 2]) + return model.parse(out, transformations=transformations, dims=dims) + return model.parse(out, transformations=transformations, dims=dims, scale_factors=[2, 2]) - def _generate_blobs(self, length: int = 512, seed: int | None = None) -> ArrayLike: + def _generate_blobs(self, length: int = 512, seed: int | None = None, ndim: int = 2) -> ArrayLike: from scipy.ndimage import gaussian_filter rng = default_rng(42) if seed is None else default_rng(seed) # from skimage - shape = tuple([length] * 2) + shape = (length,) * ndim mask = np.zeros(shape) - n_pts = max(int(1.0 / 0.1) ** 2, 1) - points = (length * rng.random((2, n_pts))).astype(int) + n_pts = max(int(1.0 / 0.1) ** ndim, 1) + points = (length * rng.random((ndim, n_pts))).astype(int) mask[tuple(indices for indices in points)] = 1 mask = gaussian_filter(mask, sigma=0.25 * length * 0.1) assert isinstance(mask, np.ndarray) diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 2d54c709..4de58136 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -14,7 +14,7 @@ from dask.array.core import from_array from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame, GeoSeries -from multiscale_spatial_image import to_multiscale +from multiscale_spatial_image import to_multiscale as to_multiscale_msi from 
multiscale_spatial_image.to_multiscale.to_multiscale import Methods from pandas import CategoricalDtype from shapely._geometry import GeometryType @@ -38,6 +38,8 @@ _validate_mapping_to_coordinate_system_type, convert_region_column_to_categorical, ) +from spatialdata.models.pyramids_utils import Chunks_t, ScaleFactors_t +from spatialdata.models.pyramids_utils import to_multiscale as to_multiscale_ozp # ozp -> ome-zarr-py from spatialdata.transformations._utils import ( _get_transformations, _set_transformations, @@ -45,10 +47,6 @@ ) from spatialdata.transformations.transformations import Identity -# Types -Chunks_t: TypeAlias = int | tuple[int, ...] | tuple[tuple[int, ...], ...] | Mapping[Any, None | int | tuple[int, ...]] -ScaleFactors_t = Sequence[dict[str, int] | int] - ATTRS_KEY = "spatialdata_attrs" @@ -225,12 +223,19 @@ def parse( chunks = {dim: chunks[index] for index, dim in enumerate(data.dims)} if isinstance(chunks, float): chunks = {dim: chunks for index, dim in data.dims} - data = to_multiscale( - data, - scale_factors=scale_factors, - method=method, - chunks=chunks, - ) + if method is not None: + data = to_multiscale_msi( + data, + scale_factors=scale_factors, + method=method, + chunks=chunks, + ) + else: + data = to_multiscale_ozp( + data, + scale_factors=scale_factors, + chunks=chunks, + ) _parse_transformations(data, parsed_transform) else: # Chunk single scale images @@ -375,9 +380,6 @@ def parse( # noqa: D102 ) -> DataArray | DataTree: if kwargs.get("c_coords") is not None: raise ValueError("`c_coords` is not supported for labels") - if kwargs.get("scale_factors") is not None and kwargs.get("method") is None: - # Override default scaling method to preserve labels - kwargs["method"] = Methods.DASK_IMAGE_NEAREST return super().parse(*args, **kwargs) @@ -388,9 +390,6 @@ class Labels3DModel(RasterSchema): def parse(self, *args: Any, **kwargs: Any) -> DataArray | DataTree: # noqa: D102 if kwargs.get("c_coords") is not None: raise ValueError("`c_coords` is not supported for labels") - if kwargs.get("scale_factors") is not None and kwargs.get("method") is None: - # Override default scaling method to preserve labels - kwargs["method"] = Methods.DASK_IMAGE_NEAREST return super().parse(*args, **kwargs) diff --git a/src/spatialdata/models/pyramids_utils.py b/src/spatialdata/models/pyramids_utils.py new file mode 100644 index 00000000..e610b827 --- /dev/null +++ b/src/spatialdata/models/pyramids_utils.py @@ -0,0 +1,84 @@ +from collections.abc import Mapping, Sequence +from typing import Any, TypeAlias + +import dask.array as da +from ome_zarr.dask_utils import resize +from xarray import DataArray, Dataset, DataTree + +Chunks_t: TypeAlias = int | tuple[int, ...] | tuple[tuple[int, ...], ...] | Mapping[Any, None | int | tuple[int, ...]] +ScaleFactors_t = Sequence[dict[str, int] | int] + + +def dask_arrays_to_datatree( + arrays: Sequence[da.Array], + dims: Sequence[str], + channels: list[Any] | None = None, +) -> DataTree: + """Build a multiscale DataTree from a sequence of dask arrays. + + Parameters + ---------- + arrays + Sequence of dask arrays, one per scale level (scale0, scale1, ...). + dims + Dimension names for the arrays (e.g. ``("c", "y", "x")``). + channels + Optional channel coordinate values. If provided, a ``"c"`` coordinate + is added to each scale level. + + Returns + ------- + DataTree with one child per scale level. 
+ """ + coords = {"c": channels} if channels is not None else {} + d = {} + for i, arr in enumerate(arrays): + d[f"scale{i}"] = Dataset( + { + "image": DataArray( + arr, + name="image", + dims=list(dims), + coords=coords, + ) + } + ) + return DataTree.from_dict(d) + + +def to_multiscale( + image: DataArray, + scale_factors: ScaleFactors_t, + chunks: Chunks_t | None = None, +) -> DataTree: + dims = [str(dim) for dim in image.dims] + spatial_dims = [d for d in dims if d != "c"] + order = 1 if "c" in dims else 0 + pyramid = [image.data] + for sf in scale_factors: + prev = pyramid[-1] + # Compute per-axis scale factors: int applies to spatial axes only, dict to specific ones. + sf_by_axis = dict.fromkeys(dims, 1) + if isinstance(sf, int): + sf_by_axis.update(dict.fromkeys(spatial_dims, sf)) + else: + sf_by_axis.update(sf) + # Clamp: skip axes where the scale factor exceeds the axis size. + for ax, factor in sf_by_axis.items(): + ax_size = prev.shape[dims.index(ax)] + if factor > ax_size: + sf_by_axis[ax] = 1 + output_shape = tuple(prev.shape[dims.index(ax)] // f for ax, f in sf_by_axis.items()) + resized = resize( + image=prev.astype(float), + output_shape=output_shape, + order=order, + mode="reflect", + anti_aliasing=False, + ) + pyramid.append(resized.astype(prev.dtype)) + if chunks is not None: + if isinstance(chunks, Mapping): + chunks = {dims.index(k) if isinstance(k, str) else k: v for k, v in chunks.items()} + pyramid = [arr.rechunk(chunks) for arr in pyramid] + return dask_arrays_to_datatree(pyramid, dims=dims) diff --git a/tests/models/test_pyramids_utils.py b/tests/models/test_pyramids_utils.py new file mode 100644 index 00000000..6382d191 --- /dev/null +++ b/tests/models/test_pyramids_utils.py @@ -0,0 +1,52 @@ +import dask +import numpy as np +import pytest +from multiscale_spatial_image.to_multiscale.to_multiscale import Methods + +from spatialdata.datasets import BlobsDataset +from spatialdata.models import Image2DModel, Image3DModel, Labels2DModel, Labels3DModel + +CHUNK_SIZE = 32 + + +@pytest.mark.parametrize( + ("model", "length", "ndim", "n_channels", "scale_factors", "method"), + [ + (Image2DModel, 128, 2, 3, (2, 2), Methods.XARRAY_COARSEN), + (Image3DModel, 32, 3, 3, (2, 2), Methods.XARRAY_COARSEN), + (Labels2DModel, 128, 2, 0, (2, 2), Methods.DASK_IMAGE_NEAREST), + (Labels3DModel, 32, 3, 0, (2, 2), Methods.DASK_IMAGE_NEAREST), + ], +) +def test_to_multiscale_via_ome_zarr_scaler(model, length, ndim, n_channels, scale_factors, method): + blob_gen = BlobsDataset() + + if n_channels > 0: + # Image: stack multiple blob channels + masks = [] + for i in range(n_channels): + mask = blob_gen._generate_blobs(length=length, seed=i, ndim=ndim) + mask = (mask - mask.min()) / np.ptp(mask) + masks.append(mask) + array = np.stack(masks, axis=0) + else: + # Labels: threshold blob pattern to get integer labels + mask = blob_gen._generate_blobs(length=length, ndim=ndim) + threshold = np.percentile(mask, 70) + array = (mask >= threshold).astype(np.int64) + + dims = model.dims + dask_data = dask.array.from_array(array).rechunk(CHUNK_SIZE) + + # multiscale-spatial-image path (explicit method) + result_msi = model.parse(dask_data, dims=dims, scale_factors=scale_factors, chunks=CHUNK_SIZE, method=method) + + # ome-zarr-py scaler path (method=None triggers the ome-zarr-py scaler) + result_ozp = model.parse(dask_data, dims=dims, scale_factors=scale_factors, chunks=CHUNK_SIZE) + + # Compare data values at each scale level + for scale_name in result_msi.children: + msi_arr = 
result_msi[scale_name].ds["image"] + ozp_arr = result_ozp[scale_name].ds["image"] + assert msi_arr.sizes == ozp_arr.sizes + np.testing.assert_allclose(msi_arr.values, ozp_arr.values) From 19b4d5fa4d02789626ae0b1814e0c36d3a50e170 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Wed, 11 Feb 2026 13:36:44 +0100 Subject: [PATCH 2/7] wip some plots for debugging purposes --- src/spatialdata/models/pyramids_utils.py | 5 ++-- tests/models/test_pyramids_utils.py | 30 ++++++++++++++++++++---- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/spatialdata/models/pyramids_utils.py b/src/spatialdata/models/pyramids_utils.py index e610b827..cd902b97 100644 --- a/src/spatialdata/models/pyramids_utils.py +++ b/src/spatialdata/models/pyramids_utils.py @@ -54,6 +54,7 @@ def to_multiscale( dims = [str(dim) for dim in image.dims] spatial_dims = [d for d in dims if d != "c"] order = 1 if "c" in dims else 0 + channels = None if "c" not in dims else image.coords["c"].values pyramid = [image.data] for sf in scale_factors: prev = pyramid[-1] @@ -63,7 +64,7 @@ def to_multiscale( sf_by_axis.update(dict.fromkeys(spatial_dims, sf)) else: sf_by_axis.update(sf) - # Clamp: skip axes where the scale factor exceeds the axis size. + # skip axes where the scale factor exceeds the axis size. for ax, factor in sf_by_axis.items(): ax_size = prev.shape[dims.index(ax)] if factor > ax_size: @@ -81,4 +82,4 @@ def to_multiscale( if isinstance(chunks, Mapping): chunks = {dims.index(k) if isinstance(k, str) else k: v for k, v in chunks.items()} pyramid = [arr.rechunk(chunks) for arr in pyramid] - return dask_arrays_to_datatree(pyramid, dims=dims) + return dask_arrays_to_datatree(pyramid, dims=dims, channels=channels) diff --git a/tests/models/test_pyramids_utils.py b/tests/models/test_pyramids_utils.py index 6382d191..9f5c548f 100644 --- a/tests/models/test_pyramids_utils.py +++ b/tests/models/test_pyramids_utils.py @@ -12,8 +12,8 @@ @pytest.mark.parametrize( ("model", "length", "ndim", "n_channels", "scale_factors", "method"), [ - (Image2DModel, 128, 2, 3, (2, 2), Methods.XARRAY_COARSEN), - (Image3DModel, 32, 3, 3, (2, 2), Methods.XARRAY_COARSEN), + # (Image2DModel, 128, 2, 3, (2, 2), Methods.XARRAY_COARSEN), + # (Image3DModel, 32, 3, 3, (2, 2), Methods.XARRAY_COARSEN), (Labels2DModel, 128, 2, 0, (2, 2), Methods.DASK_IMAGE_NEAREST), (Labels3DModel, 32, 3, 0, (2, 2), Methods.DASK_IMAGE_NEAREST), ], @@ -38,15 +38,35 @@ def test_to_multiscale_via_ome_zarr_scaler(model, length, ndim, n_channels, scal dims = model.dims dask_data = dask.array.from_array(array).rechunk(CHUNK_SIZE) - # multiscale-spatial-image path (explicit method) + # # multiscale-spatial-image path (explicit method) result_msi = model.parse(dask_data, dims=dims, scale_factors=scale_factors, chunks=CHUNK_SIZE, method=method) # ome-zarr-py scaler path (method=None triggers the ome-zarr-py scaler) result_ozp = model.parse(dask_data, dims=dims, scale_factors=scale_factors, chunks=CHUNK_SIZE) + # ## + # from napari_spatialdata import Interactive + # from spatialdata import SpatialData + # + # sdata = SpatialData.init_from_elements({'msi': result_msi, 'ozp': result_ozp}) + # Interactive(sdata) + + ## + # Compare data values at each scale level - for scale_name in result_msi.children: + import matplotlib.pyplot as plt + _, axes = plt.subplots(len(result_msi.children), 2, figsize=(8, 4 * len(result_msi.children))) + for i, scale_name in enumerate(result_msi.children): msi_arr = result_msi[scale_name].ds["image"] ozp_arr = result_ozp[scale_name].ds["image"] 
assert msi_arr.sizes == ozp_arr.sizes - np.testing.assert_allclose(msi_arr.values, ozp_arr.values) + + if msi_arr.ndim == 3: + msi_arr = msi_arr[0] + ozp_arr = ozp_arr[0] + axes[i, 0].imshow(msi_arr.values) + axes[i, 1].imshow(ozp_arr.values) + pass + # np.testing.assert_allclose(msi_arr.values, ozp_arr.values) + plt.tight_layout() + plt.show() From 7d7dd5d322813d31affb65a89e7c70125a9efc4a Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Wed, 11 Feb 2026 13:54:15 +0100 Subject: [PATCH 3/7] fix pre-commit and tests --- src/spatialdata/datasets.py | 2 ++ src/spatialdata/models/pyramids_utils.py | 30 ++++++++++++++++++ tests/models/test_pyramids_utils.py | 40 ++++++++++-------------- 3 files changed, 48 insertions(+), 24 deletions(-) diff --git a/src/spatialdata/datasets.py b/src/spatialdata/datasets.py index 4df422d8..8dfd9bf3 100644 --- a/src/spatialdata/datasets.py +++ b/src/spatialdata/datasets.py @@ -189,6 +189,7 @@ def _image_blobs( masks.append(mask) x = np.stack(masks, axis=0) + model: type[Image2DModel] | type[Image3DModel] if ndim == 2: dims = ["c", "y", "x"] model = Image2DModel @@ -229,6 +230,7 @@ def _labels_blobs( out[out == val[idx]] = 0 else: out[out == val[idx]] = i + model: type[Labels2DModel] | type[Labels3DModel] if ndim == 2: dims = ["y", "x"] model = Labels2DModel diff --git a/src/spatialdata/models/pyramids_utils.py b/src/spatialdata/models/pyramids_utils.py index cd902b97..ec68facb 100644 --- a/src/spatialdata/models/pyramids_utils.py +++ b/src/spatialdata/models/pyramids_utils.py @@ -51,6 +51,36 @@ def to_multiscale( scale_factors: ScaleFactors_t, chunks: Chunks_t | None = None, ) -> DataTree: + """Build a multiscale pyramid DataTree from a single-scale image. + + Iteratively downscales the image by the given scale factors using + interpolation (order 1 for images with a channel dimension, order 0 + for labels) and assembles all levels into a DataTree. + + Makes uses of internal ome-zarr-py APIs for dask downscaling. + + TODO: ome-zarr-py will support 3D downscaling once https://github.com/ome/ome-zarr-py/pull/516 is merged, and this + function could make use of it. Also the PR will introduce new downscaling methods such as "nearest". Nevertheless, + this function supports different scaling factors per axis, which is not supported by ome-zarr-py yet. + + Parameters + ---------- + image + Input image/labels as an xarray DataArray (e.g. with dims ``("c", "y", "x")`` + or ``("y", "x")``). Supports both 2D/3D images and 2D/3D labels. + scale_factors + Sequence of per-level scale factors. Each element is either an int + (applied to all spatial axes) or a dict mapping dimension names to + per-axis factors (e.g. ``{"y": 2, "x": 2}``). + chunks + Optional chunk specification passed to :meth:`dask.array.Array.rechunk` + after building the pyramid. + + Returns + ------- + DataTree + Multiscale DataTree with children ``scale0``, ``scale1``, etc. 
+ """ dims = [str(dim) for dim in image.dims] spatial_dims = [d for d in dims if d != "c"] order = 1 if "c" in dims else 0 diff --git a/tests/models/test_pyramids_utils.py b/tests/models/test_pyramids_utils.py index 9f5c548f..fa49decc 100644 --- a/tests/models/test_pyramids_utils.py +++ b/tests/models/test_pyramids_utils.py @@ -12,8 +12,8 @@ @pytest.mark.parametrize( ("model", "length", "ndim", "n_channels", "scale_factors", "method"), [ - # (Image2DModel, 128, 2, 3, (2, 2), Methods.XARRAY_COARSEN), - # (Image3DModel, 32, 3, 3, (2, 2), Methods.XARRAY_COARSEN), + (Image2DModel, 128, 2, 3, (2, 2), Methods.XARRAY_COARSEN), + (Image3DModel, 32, 3, 3, (2, 2), Methods.XARRAY_COARSEN), (Labels2DModel, 128, 2, 0, (2, 2), Methods.DASK_IMAGE_NEAREST), (Labels3DModel, 32, 3, 0, (2, 2), Methods.DASK_IMAGE_NEAREST), ], @@ -38,35 +38,27 @@ def test_to_multiscale_via_ome_zarr_scaler(model, length, ndim, n_channels, scal dims = model.dims dask_data = dask.array.from_array(array).rechunk(CHUNK_SIZE) - # # multiscale-spatial-image path (explicit method) + # multiscale-spatial-image path (explicit method) result_msi = model.parse(dask_data, dims=dims, scale_factors=scale_factors, chunks=CHUNK_SIZE, method=method) # ome-zarr-py scaler path (method=None triggers the ome-zarr-py scaler) result_ozp = model.parse(dask_data, dims=dims, scale_factors=scale_factors, chunks=CHUNK_SIZE) - # ## - # from napari_spatialdata import Interactive - # from spatialdata import SpatialData - # - # sdata = SpatialData.init_from_elements({'msi': result_msi, 'ozp': result_ozp}) - # Interactive(sdata) - - ## - # Compare data values at each scale level - import matplotlib.pyplot as plt - _, axes = plt.subplots(len(result_msi.children), 2, figsize=(8, 4 * len(result_msi.children))) for i, scale_name in enumerate(result_msi.children): msi_arr = result_msi[scale_name].ds["image"] ozp_arr = result_ozp[scale_name].ds["image"] assert msi_arr.sizes == ozp_arr.sizes - - if msi_arr.ndim == 3: - msi_arr = msi_arr[0] - ozp_arr = ozp_arr[0] - axes[i, 0].imshow(msi_arr.values) - axes[i, 1].imshow(ozp_arr.values) - pass - # np.testing.assert_allclose(msi_arr.values, ozp_arr.values) - plt.tight_layout() - plt.show() + if model in [Image2DModel, Image3DModel]: + # exact comparison for images + np.testing.assert_allclose(msi_arr.values, ozp_arr.values) + else: + if i == 0: + # no downscaling is performed, so they must be equal + np.testing.assert_array_equal(msi_arr.values, ozp_arr.values) + else: + # we expect differences: ngff-zarr uses "nearest", ozp uses "resize" + # TODO: when https://github.com/ome/ome-zarr-py/pull/516 is merged we can use nearest for labels and + # expect a much stricter adherence + fraction_non_equal = np.sum(msi_arr.values != ozp_arr.values) / np.prod(msi_arr.values.shape) + assert fraction_non_equal < 0.5 From 529393ff56030b570c87e21be50c33d191038474 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Wed, 11 Feb 2026 15:17:04 +0100 Subject: [PATCH 4/7] code cleanup, moved here (and fixed) normalize_chunks from spatialdata-io --- src/spatialdata/datasets.py | 18 +++---- src/spatialdata/models/chunks_utils.py | 60 ++++++++++++++++++++++++ src/spatialdata/models/models.py | 7 ++- src/spatialdata/models/pyramids_utils.py | 25 ++++++---- tests/models/test_chunks_utils.py | 45 ++++++++++++++++++ tests/models/test_pyramids_utils.py | 26 ++++------ 6 files changed, 144 insertions(+), 37 deletions(-) create mode 100644 src/spatialdata/models/chunks_utils.py create mode 100644 tests/models/test_chunks_utils.py diff --git 
a/src/spatialdata/datasets.py b/src/spatialdata/datasets.py index 8dfd9bf3..de144be7 100644 --- a/src/spatialdata/datasets.py +++ b/src/spatialdata/datasets.py @@ -151,10 +151,10 @@ def blobs( """Blobs dataset.""" image = self._image_blobs(self.transformations, self.length, self.n_channels, self.c_coords) multiscale_image = self._image_blobs( - self.transformations, self.length, self.n_channels, self.c_coords, multiscale=True + self.transformations, self.length, self.n_channels, self.c_coords, scale_factors=[2, 2] ) labels = self._labels_blobs(self.transformations, self.length) - multiscale_labels = self._labels_blobs(self.transformations, self.length, multiscale=True) + multiscale_labels = self._labels_blobs(self.transformations, self.length, scale_factors=[2, 2]) points = self._points_blobs(self.transformations, self.length, self.n_points) circles = self._circles_blobs(self.transformations, self.length, self.n_shapes) polygons = self._polygons_blobs(self.transformations, self.length, self.n_shapes) @@ -179,7 +179,7 @@ def _image_blobs( length: int = 512, n_channels: int = 3, c_coords: str | list[str] | None = None, - multiscale: bool = False, + scale_factors: list[int] | None = None, ndim: int = 2, ) -> DataArray | DataTree: masks = [] @@ -196,15 +196,17 @@ def _image_blobs( else: dims = ["c", "z", "y", "x"] model = Image3DModel - if not multiscale: + if scale_factors is None: return model.parse(x, transformations=transformations, dims=dims, c_coords=c_coords) - return model.parse(x, transformations=transformations, dims=dims, c_coords=c_coords, scale_factors=[2, 2]) + return model.parse( + x, transformations=transformations, dims=dims, c_coords=c_coords, scale_factors=scale_factors + ) def _labels_blobs( self, transformations: dict[str, Any] | None = None, length: int = 512, - multiscale: bool = False, + scale_factors: list[int] | None = None, ndim: int = 2, ) -> DataArray | DataTree: """Create labels in 2D or 3D.""" @@ -237,9 +239,9 @@ def _labels_blobs( else: dims = ["z", "y", "x"] model = Labels3DModel - if not multiscale: + if scale_factors is None: return model.parse(out, transformations=transformations, dims=dims) - return model.parse(out, transformations=transformations, dims=dims, scale_factors=[2, 2]) + return model.parse(out, transformations=transformations, dims=dims, scale_factors=scale_factors) def _generate_blobs(self, length: int = 512, seed: int | None = None, ndim: int = 2) -> ArrayLike: from scipy.ndimage import gaussian_filter diff --git a/src/spatialdata/models/chunks_utils.py b/src/spatialdata/models/chunks_utils.py new file mode 100644 index 00000000..adda82ba --- /dev/null +++ b/src/spatialdata/models/chunks_utils.py @@ -0,0 +1,60 @@ +from collections.abc import Mapping, Sequence +from typing import Any, TypeAlias + +Chunks_t: TypeAlias = int | tuple[int, ...] | tuple[tuple[int, ...], ...] | Mapping[Any, None | int | tuple[int, ...]] + + +def normalize_chunks( + chunks: Chunks_t, + axes: Sequence[str], +) -> dict[str, None | int | tuple[int, ...]]: + """Normalize chunk specification to dict format. + + This function converts various chunk formats to a dict mapping dimension names + to chunk sizes. The dict format is preferred because it's explicit about which + dimension gets which chunk size. + + Parameters + ---------- + chunks + Chunk specification. 
Can be: + - int: Applied to all axes + - tuple[int, ...]: Chunk sizes in order corresponding to axes + - tuple[tuple[int, ...], ...]: Explicit per-block chunk sizes per axis + - dict: Mapping of axis names to chunk sizes. Values can be: + - int: uniform chunk size for that axis + - tuple[int, ...]: explicit per-block chunk sizes + - None: keep existing chunks / use full dimension (dask semantics) + axes + Tuple of axis names that defines the expected dimensions (e.g., ('c', 'y', 'x')). + + Returns + ------- + dict[str, None | int | tuple[int, ...]] + Dict mapping axis names to chunk sizes. ``None`` values are preserved + with dask semantics (keep existing chunks in ``rechunk``, or use full + dimension size in array creation). + + Raises + ------ + ValueError + If chunks format is not supported or incompatible with axes. + """ + if isinstance(chunks, int): + return dict.fromkeys(axes, chunks) + + if isinstance(chunks, Mapping): + chunks_dict = dict(chunks) + missing = set(axes) - set(chunks_dict.keys()) + if missing: + raise ValueError(f"chunks dict missing keys for axes {missing}, got: {list(chunks_dict.keys())}") + return {ax: chunks_dict[ax] for ax in axes} + + if isinstance(chunks, tuple): + if len(chunks) != len(axes): + raise ValueError(f"chunks tuple length {len(chunks)} doesn't match axes {axes} (length {len(axes)})") + if not all(isinstance(c, (int, tuple)) for c in chunks): + raise ValueError(f"All elements in chunks tuple must be int or tuple[int, ...], got: {chunks}") + return dict(zip(axes, chunks, strict=True)) # type: ignore[arg-type] + + raise ValueError(f"Unsupported chunks type: {type(chunks)}. Expected int, tuple, dict, or None.") diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 4de58136..3fe0a272 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -38,8 +38,9 @@ _validate_mapping_to_coordinate_system_type, convert_region_column_to_categorical, ) -from spatialdata.models.pyramids_utils import Chunks_t, ScaleFactors_t -from spatialdata.models.pyramids_utils import to_multiscale as to_multiscale_ozp # ozp -> ome-zarr-py +from spatialdata.models.chunks_utils import Chunks_t +from spatialdata.models.pyramids_utils import ScaleFactors_t # ozp -> ome-zarr-py +from spatialdata.models.pyramids_utils import to_multiscale as to_multiscale_ozp from spatialdata.transformations._utils import ( _get_transformations, _set_transformations, @@ -47,6 +48,8 @@ ) from spatialdata.transformations.transformations import Identity +__all__ = ["Chunks_t", "ScaleFactors_t"] + ATTRS_KEY = "spatialdata_attrs" diff --git a/src/spatialdata/models/pyramids_utils.py b/src/spatialdata/models/pyramids_utils.py index ec68facb..8a0a0ba1 100644 --- a/src/spatialdata/models/pyramids_utils.py +++ b/src/spatialdata/models/pyramids_utils.py @@ -1,11 +1,12 @@ -from collections.abc import Mapping, Sequence -from typing import Any, TypeAlias +from collections.abc import Sequence +from typing import Any import dask.array as da from ome_zarr.dask_utils import resize from xarray import DataArray, Dataset, DataTree -Chunks_t: TypeAlias = int | tuple[int, ...] | tuple[tuple[int, ...], ...] | Mapping[Any, None | int | tuple[int, ...]] +from spatialdata.models.chunks_utils import Chunks_t, normalize_chunks + ScaleFactors_t = Sequence[dict[str, int] | int] @@ -30,6 +31,8 @@ def dask_arrays_to_datatree( ------- DataTree with one child per scale level. 
""" + if "c" in dims and channels is None: + raise ValueError("channels must be provided if the image has a channel dimension") coords = {"c": channels} if channels is not None else {} d = {} for i, arr in enumerate(arrays): @@ -59,9 +62,13 @@ def to_multiscale( Makes uses of internal ome-zarr-py APIs for dask downscaling. - TODO: ome-zarr-py will support 3D downscaling once https://github.com/ome/ome-zarr-py/pull/516 is merged, and this - function could make use of it. Also the PR will introduce new downscaling methods such as "nearest". Nevertheless, - this function supports different scaling factors per axis, which is not supported by ome-zarr-py yet. + ome-zarr-py will support 3D downscaling once https://github.com/ome/ome-zarr-py/pull/516 is merged, and this + function could make use of it. Also the PR will introduce new downscaling methods such as "nearest". Nevertheless, + this function supports different scaling factors per axis, a feature that could be also added to ome-zarr-py. + + TODO: once the PR above is merged, use the new APIs for 3D downscaling and additional downscaling methods + TODO: once the PR above is merged, consider adding support for per-axis scale factors to ome-zarr-py so that this + function can be simplified even further. Parameters ---------- @@ -109,7 +116,7 @@ def to_multiscale( ) pyramid.append(resized.astype(prev.dtype)) if chunks is not None: - if isinstance(chunks, Mapping): - chunks = {dims.index(k) if isinstance(k, str) else k: v for k, v in chunks.items()} - pyramid = [arr.rechunk(chunks) for arr in pyramid] + chunks_dict = normalize_chunks(chunks, axes=dims) + chunks_tuple = tuple(chunks_dict[d] for d in dims) + pyramid = [arr.rechunk(chunks_tuple) for arr in pyramid] return dask_arrays_to_datatree(pyramid, dims=dims, channels=channels) diff --git a/tests/models/test_chunks_utils.py b/tests/models/test_chunks_utils.py new file mode 100644 index 00000000..529e6946 --- /dev/null +++ b/tests/models/test_chunks_utils.py @@ -0,0 +1,45 @@ +import pytest + +from spatialdata.models.chunks_utils import Chunks_t, normalize_chunks + + +@pytest.mark.parametrize( + "chunks, axes, expected", + [ + # 2D (y, x) + (256, ("y", "x"), {"y": 256, "x": 256}), + ((200, 100), ("x", "y"), {"y": 100, "x": 200}), + ({"y": 300, "x": 400}, ("x", "y"), {"y": 300, "x": 400}), + # 2D with channel (c, y, x) + (256, ("c", "y", "x"), {"c": 256, "y": 256, "x": 256}), + ((1, 100, 200), ("c", "y", "x"), {"c": 1, "y": 100, "x": 200}), + ({"c": 1, "y": 300, "x": 400}, ("c", "y", "x"), {"c": 1, "y": 300, "x": 400}), + # 3D (z, y, x) + ((10, 100, 200), ("z", "y", "x"), {"z": 10, "y": 100, "x": 200}), + ({"z": 10, "y": 300, "x": 400}, ("z", "y", "x"), {"z": 10, "y": 300, "x": 400}), + # Mapping with None values (passed through with dask semantics: keep existing / full dimension) + ({"y": None, "x": 400}, ("y", "x"), {"y": None, "x": 400}), + ({"c": None, "y": None, "x": None}, ("c", "y", "x"), {"c": None, "y": None, "x": None}), + # Mapping with tuple[int, ...] 
values (explicit per-block chunk sizes) + ({"y": (256, 256, 128), "x": 512}, ("y", "x"), {"y": (256, 256, 128), "x": 512}), + ({"c": 1, "y": (100, 100), "x": (200, 50)}, ("c", "y", "x"), {"c": 1, "y": (100, 100), "x": (200, 50)}), + # Tuple of tuples (explicit per-block chunk sizes per axis) + (((256, 256, 128), (512, 512)), ("y", "x"), {"y": (256, 256, 128), "x": (512, 512)}), + ], +) +def test_normalize_chunks_valid(chunks: Chunks_t, axes: tuple[str, ...], expected: dict[str, int]) -> None: + assert normalize_chunks(chunks, axes=axes) == expected + + +@pytest.mark.parametrize( + "chunks, axes, match", + [ + ({"y": 100}, ("y", "x"), "missing keys for axes"), + ((1, 2, 3), ("y", "x"), "doesn't match axes"), + ((1.5, 2), ("y", "x"), "must be int or tuple"), + ("invalid", ("y", "x"), "Unsupported chunks type"), + ], +) +def test_normalize_chunks_errors(chunks: Chunks_t, axes: tuple[str, ...], match: str) -> None: + with pytest.raises(ValueError, match=match): + normalize_chunks(chunks, axes=axes) diff --git a/tests/models/test_pyramids_utils.py b/tests/models/test_pyramids_utils.py index fa49decc..14a45fe6 100644 --- a/tests/models/test_pyramids_utils.py +++ b/tests/models/test_pyramids_utils.py @@ -1,4 +1,3 @@ -import dask import numpy as np import pytest from multiscale_spatial_image.to_multiscale.to_multiscale import Methods @@ -12,31 +11,22 @@ @pytest.mark.parametrize( ("model", "length", "ndim", "n_channels", "scale_factors", "method"), [ - (Image2DModel, 128, 2, 3, (2, 2), Methods.XARRAY_COARSEN), - (Image3DModel, 32, 3, 3, (2, 2), Methods.XARRAY_COARSEN), - (Labels2DModel, 128, 2, 0, (2, 2), Methods.DASK_IMAGE_NEAREST), - (Labels3DModel, 32, 3, 0, (2, 2), Methods.DASK_IMAGE_NEAREST), + (Image2DModel, 128, 2, 3, [2, 2], Methods.XARRAY_COARSEN), + (Image3DModel, 32, 3, 3, [2, 2], Methods.XARRAY_COARSEN), + (Labels2DModel, 128, 2, 0, [2, 2], Methods.DASK_IMAGE_NEAREST), + (Labels3DModel, 32, 3, 0, [2, 2], Methods.DASK_IMAGE_NEAREST), ], ) def test_to_multiscale_via_ome_zarr_scaler(model, length, ndim, n_channels, scale_factors, method): blob_gen = BlobsDataset() - if n_channels > 0: - # Image: stack multiple blob channels - masks = [] - for i in range(n_channels): - mask = blob_gen._generate_blobs(length=length, seed=i, ndim=ndim) - mask = (mask - mask.min()) / np.ptp(mask) - masks.append(mask) - array = np.stack(masks, axis=0) + if model in [Image2DModel, Image3DModel]: + array = blob_gen._image_blobs(length=length, n_channels=n_channels, ndim=ndim).data else: - # Labels: threshold blob pattern to get integer labels - mask = blob_gen._generate_blobs(length=length, ndim=ndim) - threshold = np.percentile(mask, 70) - array = (mask >= threshold).astype(np.int64) + array = blob_gen._labels_blobs(length=length, ndim=ndim).data dims = model.dims - dask_data = dask.array.from_array(array).rechunk(CHUNK_SIZE) + dask_data = array.rechunk(CHUNK_SIZE) # multiscale-spatial-image path (explicit method) result_msi = model.parse(dask_data, dims=dims, scale_factors=scale_factors, chunks=CHUNK_SIZE, method=method) From d8e66ed81e361ed372f0cb143281de3ca01eba1f Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Wed, 11 Feb 2026 15:25:05 +0100 Subject: [PATCH 5/7] small code/comments cleanup --- src/spatialdata/models/chunks_utils.py | 6 +++--- tests/models/test_chunks_utils.py | 2 +- tests/models/test_pyramids_utils.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spatialdata/models/chunks_utils.py b/src/spatialdata/models/chunks_utils.py index adda82ba..14ac208e 100644 --- 
a/src/spatialdata/models/chunks_utils.py +++ b/src/spatialdata/models/chunks_utils.py @@ -24,7 +24,7 @@ def normalize_chunks( - dict: Mapping of axis names to chunk sizes. Values can be: - int: uniform chunk size for that axis - tuple[int, ...]: explicit per-block chunk sizes - - None: keep existing chunks / use full dimension (dask semantics) + - None: keep existing chunks (or use full dimension when no chunks were available) axes Tuple of axis names that defines the expected dimensions (e.g., ('c', 'y', 'x')). @@ -32,8 +32,8 @@ def normalize_chunks( ------- dict[str, None | int | tuple[int, ...]] Dict mapping axis names to chunk sizes. ``None`` values are preserved - with dask semantics (keep existing chunks in ``rechunk``, or use full - dimension size in array creation). + with dask semantics (keep existing chunks, or use full dimension size if chunks + where not available and are being created). Raises ------ diff --git a/tests/models/test_chunks_utils.py b/tests/models/test_chunks_utils.py index 529e6946..5bfcfa81 100644 --- a/tests/models/test_chunks_utils.py +++ b/tests/models/test_chunks_utils.py @@ -17,7 +17,7 @@ # 3D (z, y, x) ((10, 100, 200), ("z", "y", "x"), {"z": 10, "y": 100, "x": 200}), ({"z": 10, "y": 300, "x": 400}, ("z", "y", "x"), {"z": 10, "y": 300, "x": 400}), - # Mapping with None values (passed through with dask semantics: keep existing / full dimension) + # Mapping with None values (passed through) ({"y": None, "x": 400}, ("y", "x"), {"y": None, "x": 400}), ({"c": None, "y": None, "x": None}, ("c", "y", "x"), {"c": None, "y": None, "x": None}), # Mapping with tuple[int, ...] values (explicit per-block chunk sizes) diff --git a/tests/models/test_pyramids_utils.py b/tests/models/test_pyramids_utils.py index 14a45fe6..eb9dab65 100644 --- a/tests/models/test_pyramids_utils.py +++ b/tests/models/test_pyramids_utils.py @@ -28,10 +28,10 @@ def test_to_multiscale_via_ome_zarr_scaler(model, length, ndim, n_channels, scal dims = model.dims dask_data = array.rechunk(CHUNK_SIZE) - # multiscale-spatial-image path (explicit method) + # multiscale-spatial-image (method is not None) result_msi = model.parse(dask_data, dims=dims, scale_factors=scale_factors, chunks=CHUNK_SIZE, method=method) - # ome-zarr-py scaler path (method=None triggers the ome-zarr-py scaler) + # ome-zarr-py scaler (method=None triggers the ome-zarr-py scaler) result_ozp = model.parse(dask_data, dims=dims, scale_factors=scale_factors, chunks=CHUNK_SIZE) # Compare data values at each scale level From 0620cca68c327efd4c5bc2d1386bf19c403241a5 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Wed, 11 Feb 2026 16:00:34 +0100 Subject: [PATCH 6/7] better test asserts; testing also per-axis scaling; plotting (to be removed) --- tests/models/test_pyramids_utils.py | 60 +++++++++++++++++++++-------- 1 file changed, 45 insertions(+), 15 deletions(-) diff --git a/tests/models/test_pyramids_utils.py b/tests/models/test_pyramids_utils.py index eb9dab65..43d89cb5 100644 --- a/tests/models/test_pyramids_utils.py +++ b/tests/models/test_pyramids_utils.py @@ -11,10 +11,10 @@ @pytest.mark.parametrize( ("model", "length", "ndim", "n_channels", "scale_factors", "method"), [ - (Image2DModel, 128, 2, 3, [2, 2], Methods.XARRAY_COARSEN), - (Image3DModel, 32, 3, 3, [2, 2], Methods.XARRAY_COARSEN), + (Image2DModel, 128, 2, 3, [2, 3], Methods.XARRAY_COARSEN), + (Image3DModel, 32, 3, 3, [{"x": 3, "y": 2, "z": 1}, {"x": 1, "y": 2, "z": 2}], Methods.XARRAY_COARSEN), (Labels2DModel, 128, 2, 0, [2, 2], Methods.DASK_IMAGE_NEAREST), - 
(Labels3DModel, 32, 3, 0, [2, 2], Methods.DASK_IMAGE_NEAREST), + (Labels3DModel, 32, 3, 0, [{"x": 2, "y": 2, "z": 3}, 2], Methods.DASK_IMAGE_NEAREST), ], ) def test_to_multiscale_via_ome_zarr_scaler(model, length, ndim, n_channels, scale_factors, method): @@ -34,21 +34,51 @@ def test_to_multiscale_via_ome_zarr_scaler(model, length, ndim, n_channels, scal # ome-zarr-py scaler (method=None triggers the ome-zarr-py scaler) result_ozp = model.parse(dask_data, dims=dims, scale_factors=scale_factors, chunks=CHUNK_SIZE) - # Compare data values at each scale level + # Compare data values and plot each scale level + import matplotlib.pyplot as plt + + n_scales = len(result_msi.children) + fig, axes = plt.subplots(n_scales, 2, figsize=(8, 4 * n_scales), squeeze=False) + fig.suptitle(f"{model.__name__} scale_factors={scale_factors}", fontsize=12) + axes[0, 0].set_title("multiscale-spatial-image") + axes[0, 1].set_title("ome-zarr-py") + for i, scale_name in enumerate(result_msi.children): msi_arr = result_msi[scale_name].ds["image"] ozp_arr = result_ozp[scale_name].ds["image"] assert msi_arr.sizes == ozp_arr.sizes - if model in [Image2DModel, Image3DModel]: - # exact comparison for images - np.testing.assert_allclose(msi_arr.values, ozp_arr.values) + + if i == 0: + # scale0 is the original data, must be identical + np.testing.assert_array_equal(msi_arr.values, ozp_arr.values) else: - if i == 0: - # no downscaling is performed, so they must be equal - np.testing.assert_array_equal(msi_arr.values, ozp_arr.values) - else: - # we expect differences: ngff-zarr uses "nearest", ozp uses "resize" - # TODO: when https://github.com/ome/ome-zarr-py/pull/516 is merged we can use nearest for labels and - # expect a much stricter adherence + if model in (Labels2DModel, Labels3DModel): + # labels use different nearest-like methods; expect <50% non-identical entries fraction_non_equal = np.sum(msi_arr.values != ozp_arr.values) / np.prod(msi_arr.values.shape) - assert fraction_non_equal < 0.5 + assert fraction_non_equal < 0.5, ( + f"{scale_name}: {fraction_non_equal:.1%} non-identical entries (expected <50%)" + ) + else: + # images use fundamentally different algorithms (coarsen vs spline interpolation); + # just check that the value ranges are similar + msi_vals, ozp_vals = msi_arr.values, ozp_arr.values + np.testing.assert_allclose(msi_vals.mean(), ozp_vals.mean(), rtol=0.5) + np.testing.assert_allclose(msi_vals.std(), ozp_vals.std(), rtol=0.5) + + # Select a 2D slice for plotting + msi_plot = msi_arr + ozp_plot = ozp_arr + if msi_plot.ndim == 4: + msi_plot = msi_plot[0, 0] + ozp_plot = ozp_plot[0, 0] + elif msi_plot.ndim == 3: + msi_plot = msi_plot[0] + ozp_plot = ozp_plot[0] + + shape_str = "x".join(str(s) for s in msi_arr.shape) + axes[i, 0].imshow(msi_plot.values) + axes[i, 0].set_ylabel(f"{scale_name}\n{shape_str}", rotation=0, ha="right", va="center") + axes[i, 1].imshow(ozp_plot.values) + + plt.tight_layout() + plt.show() From 0d4ae59dc2cff58f63ec6a1f74a832c751bbd05f Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Wed, 11 Feb 2026 16:04:41 +0100 Subject: [PATCH 7/7] remove plotting code --- tests/models/test_pyramids_utils.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/tests/models/test_pyramids_utils.py b/tests/models/test_pyramids_utils.py index 43d89cb5..da0a3354 100644 --- a/tests/models/test_pyramids_utils.py +++ b/tests/models/test_pyramids_utils.py @@ -34,15 +34,6 @@ def test_to_multiscale_via_ome_zarr_scaler(model, length, ndim, n_channels, scal # ome-zarr-py 
scaler (method=None triggers the ome-zarr-py scaler) result_ozp = model.parse(dask_data, dims=dims, scale_factors=scale_factors, chunks=CHUNK_SIZE) - # Compare data values and plot each scale level - import matplotlib.pyplot as plt - - n_scales = len(result_msi.children) - fig, axes = plt.subplots(n_scales, 2, figsize=(8, 4 * n_scales), squeeze=False) - fig.suptitle(f"{model.__name__} scale_factors={scale_factors}", fontsize=12) - axes[0, 0].set_title("multiscale-spatial-image") - axes[0, 1].set_title("ome-zarr-py") - for i, scale_name in enumerate(result_msi.children): msi_arr = result_msi[scale_name].ds["image"] ozp_arr = result_ozp[scale_name].ds["image"] @@ -64,21 +55,3 @@ def test_to_multiscale_via_ome_zarr_scaler(model, length, ndim, n_channels, scal msi_vals, ozp_vals = msi_arr.values, ozp_arr.values np.testing.assert_allclose(msi_vals.mean(), ozp_vals.mean(), rtol=0.5) np.testing.assert_allclose(msi_vals.std(), ozp_vals.std(), rtol=0.5) - - # Select a 2D slice for plotting - msi_plot = msi_arr - ozp_plot = ozp_arr - if msi_plot.ndim == 4: - msi_plot = msi_plot[0, 0] - ozp_plot = ozp_plot[0, 0] - elif msi_plot.ndim == 3: - msi_plot = msi_plot[0] - ozp_plot = ozp_plot[0] - - shape_str = "x".join(str(s) for s in msi_arr.shape) - axes[i, 0].imshow(msi_plot.values) - axes[i, 0].set_ylabel(f"{scale_name}\n{shape_str}", rotation=0, ha="right", va="center") - axes[i, 1].imshow(ozp_plot.values) - - plt.tight_layout() - plt.show()
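Note for reviewers (not part of the patches): a minimal usage sketch of the new code path, assuming this branch of spatialdata is installed; the label array, sizes and chunking below are illustrative and not taken from the commits. With method=None, parse() now builds the pyramid through the ome-zarr-py-based to_multiscale() in pyramids_utils, while passing an explicit method keeps the multiscale-spatial-image scaler.

    # Usage sketch (assumes this branch is installed); array and parameters are illustrative.
    import numpy as np
    from multiscale_spatial_image.to_multiscale.to_multiscale import Methods

    from spatialdata.models import Labels2DModel

    # Synthetic integer labels, only meant to exercise the two pyramid-building paths.
    labels = np.random.default_rng(0).integers(0, 10, size=(128, 128), dtype=np.int64)

    # New lazy path: method=None routes parse() to the ome-zarr-py based
    # spatialdata.models.pyramids_utils.to_multiscale (order-0 resize for labels).
    ozp = Labels2DModel.parse(labels, dims=["y", "x"], scale_factors=[2, 2], chunks=32)

    # Existing path: an explicit method keeps using multiscale-spatial-image.
    msi = Labels2DModel.parse(
        labels, dims=["y", "x"], scale_factors=[2, 2], chunks=32, method=Methods.DASK_IMAGE_NEAREST
    )

    print(list(ozp.children))               # ['scale0', 'scale1', 'scale2']
    print(ozp["scale1"].ds["image"].sizes)  # -> y: 64, x: 64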