diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py
index bc8206db..86b537d0 100644
--- a/src/spatialdata/_io/io_raster.py
+++ b/src/spatialdata/_io/io_raster.py
@@ -13,7 +13,7 @@
 from ome_zarr.writer import write_labels as write_labels_ngff
 from ome_zarr.writer import write_multiscale as write_multiscale_ngff
 from ome_zarr.writer import write_multiscale_labels as write_multiscale_labels_ngff
-from xarray import DataArray, Dataset, DataTree
+from xarray import DataArray, DataTree
 
 from spatialdata._io._utils import (
     _get_transformations_from_ngff_dict,
@@ -27,6 +27,7 @@
 from spatialdata._utils import get_pyramid_levels
 from spatialdata.models._utils import get_channel_names
 from spatialdata.models.models import ATTRS_KEY
+from spatialdata.models.pyramids_utils import dask_arrays_to_datatree
 from spatialdata.transformations._utils import (
     _get_transformations,
     _get_transformations_xarray,
@@ -91,20 +92,8 @@ def _read_multiscale(
     channels = [d["label"] for d in omero_metadata["channels"]]
     axes = [i["name"] for i in node.metadata["axes"]]
     if len(datasets) > 1:
-        multiscale_image = {}
-        for i, d in enumerate(datasets):
-            data = node.load(Multiscales).array(resolution=d)
-            multiscale_image[f"scale{i}"] = Dataset(
-                {
-                    "image": DataArray(
-                        data,
-                        name="image",
-                        dims=axes,
-                        coords={"c": channels} if channels is not None else {},
-                    )
-                }
-            )
-        msi = DataTree.from_dict(multiscale_image)
+        arrays = [node.load(Multiscales).array(resolution=d) for d in datasets]
+        msi = dask_arrays_to_datatree(arrays, dims=axes, channels=channels)
     _set_transformations(msi, transformations)
     return compute_coordinates(msi)
 
diff --git a/src/spatialdata/datasets.py b/src/spatialdata/datasets.py
index ea38d739..de144be7 100644
--- a/src/spatialdata/datasets.py
+++ b/src/spatialdata/datasets.py
@@ -20,7 +20,15 @@
 from spatialdata._core.query.relational_query import get_element_instances
 from spatialdata._core.spatialdata import SpatialData
 from spatialdata._types import ArrayLike
-from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, ShapesModel, TableModel
+from spatialdata.models import (
+    Image2DModel,
+    Image3DModel,
+    Labels2DModel,
+    Labels3DModel,
+    PointsModel,
+    ShapesModel,
+    TableModel,
+)
 from spatialdata.transformations import Identity
 
 __all__ = ["blobs", "raccoon"]
@@ -143,10 +151,10 @@
         """Blobs dataset."""
         image = self._image_blobs(self.transformations, self.length, self.n_channels, self.c_coords)
         multiscale_image = self._image_blobs(
-            self.transformations, self.length, self.n_channels, self.c_coords, multiscale=True
+            self.transformations, self.length, self.n_channels, self.c_coords, scale_factors=[2, 2]
         )
         labels = self._labels_blobs(self.transformations, self.length)
-        multiscale_labels = self._labels_blobs(self.transformations, self.length, multiscale=True)
+        multiscale_labels = self._labels_blobs(self.transformations, self.length, scale_factors=[2, 2])
         points = self._points_blobs(self.transformations, self.length, self.n_points)
         circles = self._circles_blobs(self.transformations, self.length, self.n_shapes)
         polygons = self._polygons_blobs(self.transformations, self.length, self.n_shapes)
@@ -171,38 +179,51 @@ def _image_blobs(
         self,
         transformations: dict[str, Any] | None = None,
         length: int = 512,
         n_channels: int = 3,
         c_coords: str | list[str] | None = None,
-        multiscale: bool = False,
+        scale_factors: list[int] | None = None,
+        ndim: int = 2,
     ) -> DataArray | DataTree:
         masks = []
         for i in range(n_channels):
-            mask = self._generate_blobs(length=length, seed=i)
+            mask = self._generate_blobs(length=length, seed=i, ndim=ndim)
             mask = (mask - mask.min()) / np.ptp(mask)
             masks.append(mask)
         x = np.stack(masks, axis=0)
-        dims = ["c", "y", "x"]
-        if not multiscale:
-            return Image2DModel.parse(x, transformations=transformations, dims=dims, c_coords=c_coords)
-        return Image2DModel.parse(
-            x, transformations=transformations, dims=dims, c_coords=c_coords, scale_factors=[2, 2]
-        )
+        model: type[Image2DModel] | type[Image3DModel]
+        if ndim == 2:
+            dims = ["c", "y", "x"]
+            model = Image2DModel
+        else:
+            dims = ["c", "z", "y", "x"]
+            model = Image3DModel
+        if scale_factors is None:
+            return model.parse(x, transformations=transformations, dims=dims, c_coords=c_coords)
+        return model.parse(
+            x, transformations=transformations, dims=dims, c_coords=c_coords, scale_factors=scale_factors
+        )
 
     def _labels_blobs(
-        self, transformations: dict[str, Any] | None = None, length: int = 512, multiscale: bool = False
+        self,
+        transformations: dict[str, Any] | None = None,
+        length: int = 512,
+        scale_factors: list[int] | None = None,
+        ndim: int = 2,
     ) -> DataArray | DataTree:
-        """Create a 2D labels."""
+        """Create labels in 2D or 3D."""
         from scipy.ndimage import watershed_ift  # from skimage
 
-        mask = self._generate_blobs(length=length)
+        mask = self._generate_blobs(length=length, ndim=ndim)
         threshold = np.percentile(mask, 100 * (1 - 0.3))
         inputs = np.logical_not(mask < threshold).astype(np.uint8)
         # use watershed from scipy
-        xm, ym = np.ogrid[0:length:10, 0:length:10]
+        grid = np.ogrid[tuple(slice(0, length, 10) for _ in range(ndim))]
         markers = np.zeros_like(inputs).astype(np.int16)
-        markers[xm, ym] = np.arange(xm.size * ym.size).reshape((xm.size, ym.size))
+        grid_shape = tuple(g.size for g in grid)
+        markers[tuple(grid)] = np.arange(np.prod(grid_shape)).reshape(grid_shape)
         out = watershed_ift(inputs, markers)
-        out[xm, ym] = out[xm - 1, ym - 1]  # remove the isolate seeds
+        shifted = tuple(g - 1 for g in grid)
+        out[tuple(grid)] = out[tuple(shifted)]  # remove the isolated seeds
         # reindex by frequency
         val, counts = np.unique(out, return_counts=True)
         sorted_idx = np.argsort(counts)
@@ -211,20 +232,26 @@
                 out[out == val[idx]] = 0
             else:
                 out[out == val[idx]] = i
-        dims = ["y", "x"]
-        if not multiscale:
-            return Labels2DModel.parse(out, transformations=transformations, dims=dims)
-        return Labels2DModel.parse(out, transformations=transformations, dims=dims, scale_factors=[2, 2])
-
-    def _generate_blobs(self, length: int = 512, seed: int | None = None) -> ArrayLike:
+        model: type[Labels2DModel] | type[Labels3DModel]
+        if ndim == 2:
+            dims = ["y", "x"]
+            model = Labels2DModel
+        else:
+            dims = ["z", "y", "x"]
+            model = Labels3DModel
+        if scale_factors is None:
+            return model.parse(out, transformations=transformations, dims=dims)
+        return model.parse(out, transformations=transformations, dims=dims, scale_factors=scale_factors)
+
+    def _generate_blobs(self, length: int = 512, seed: int | None = None, ndim: int = 2) -> ArrayLike:
         from scipy.ndimage import gaussian_filter
 
         rng = default_rng(42) if seed is None else default_rng(seed)
         # from skimage
-        shape = tuple([length] * 2)
+        shape = (length,) * ndim
         mask = np.zeros(shape)
-        n_pts = max(int(1.0 / 0.1) ** 2, 1)
-        points = (length * rng.random((2, n_pts))).astype(int)
+        n_pts = max(int(1.0 / 0.1) ** ndim, 1)
+        points = (length * rng.random((ndim, n_pts))).astype(int)
         mask[tuple(indices for indices in points)] = 1
         mask = gaussian_filter(mask, sigma=0.25 * length * 0.1)
         assert isinstance(mask, np.ndarray)
diff --git a/src/spatialdata/models/chunks_utils.py b/src/spatialdata/models/chunks_utils.py
new file mode 100644
index 00000000..14ac208e
--- /dev/null
+++ b/src/spatialdata/models/chunks_utils.py
@@ -0,0 +1,60 @@
+from collections.abc import Mapping, Sequence
+from typing import Any, TypeAlias
+
+Chunks_t: TypeAlias = int | tuple[int, ...] | tuple[tuple[int, ...], ...] | Mapping[Any, None | int | tuple[int, ...]]
+
+
+def normalize_chunks(
+    chunks: Chunks_t,
+    axes: Sequence[str],
+) -> dict[str, None | int | tuple[int, ...]]:
+    """Normalize chunk specification to dict format.
+
+    This function converts various chunk formats to a dict mapping dimension names
+    to chunk sizes. The dict format is preferred because it's explicit about which
+    dimension gets which chunk size.
+
+    Parameters
+    ----------
+    chunks
+        Chunk specification. Can be:
+        - int: applied to all axes
+        - tuple[int, ...]: chunk sizes in the order corresponding to ``axes``
+        - tuple[tuple[int, ...], ...]: explicit per-block chunk sizes per axis
+        - dict: mapping of axis names to chunk sizes. Values can be:
+            - int: uniform chunk size for that axis
+            - tuple[int, ...]: explicit per-block chunk sizes
+            - None: keep existing chunks (or use the full dimension when no chunks were available)
+    axes
+        Sequence of axis names that defines the expected dimensions (e.g., ``('c', 'y', 'x')``).
+
+    Returns
+    -------
+    dict[str, None | int | tuple[int, ...]]
+        Dict mapping axis names to chunk sizes. ``None`` values are preserved
+        with dask semantics (keep existing chunks, or use the full dimension size if chunks
+        were not available and are being created).
+
+    Raises
+    ------
+    ValueError
+        If the chunks format is not supported or is incompatible with ``axes``.
+    """
+    if isinstance(chunks, int):
+        return dict.fromkeys(axes, chunks)
+
+    if isinstance(chunks, Mapping):
+        chunks_dict = dict(chunks)
+        missing = set(axes) - set(chunks_dict.keys())
+        if missing:
+            raise ValueError(f"chunks dict missing keys for axes {missing}, got: {list(chunks_dict.keys())}")
+        return {ax: chunks_dict[ax] for ax in axes}
+
+    if isinstance(chunks, tuple):
+        if len(chunks) != len(axes):
+            raise ValueError(f"chunks tuple length {len(chunks)} doesn't match axes {axes} (length {len(axes)})")
+        if not all(isinstance(c, (int, tuple)) for c in chunks):
+            raise ValueError(f"All elements in chunks tuple must be int or tuple[int, ...], got: {chunks}")
+        return dict(zip(axes, chunks, strict=True))  # type: ignore[arg-type]
+
+    raise ValueError(f"Unsupported chunks type: {type(chunks)}. Expected int, tuple, or dict.")
diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py
index 2d54c709..3fe0a272 100644
--- a/src/spatialdata/models/models.py
+++ b/src/spatialdata/models/models.py
@@ -14,7 +14,7 @@
 from dask.array.core import from_array
 from dask.dataframe import DataFrame as DaskDataFrame
 from geopandas import GeoDataFrame, GeoSeries
-from multiscale_spatial_image import to_multiscale
+from multiscale_spatial_image import to_multiscale as to_multiscale_msi
 from multiscale_spatial_image.to_multiscale.to_multiscale import Methods
 from pandas import CategoricalDtype
 from shapely._geometry import GeometryType
@@ -38,6 +38,9 @@
     _validate_mapping_to_coordinate_system_type,
     convert_region_column_to_categorical,
 )
+from spatialdata.models.chunks_utils import Chunks_t
+from spatialdata.models.pyramids_utils import ScaleFactors_t
+from spatialdata.models.pyramids_utils import to_multiscale as to_multiscale_ozp  # ozp -> ome-zarr-py
 from spatialdata.transformations._utils import (
     _get_transformations,
     _set_transformations,
@@ -45,9 +48,7 @@
 )
 from spatialdata.transformations.transformations import Identity
 
-# Types
-Chunks_t: TypeAlias = int | tuple[int, ...] | tuple[tuple[int, ...], ...] | Mapping[Any, None | int | tuple[int, ...]]
-ScaleFactors_t = Sequence[dict[str, int] | int]
+__all__ = ["Chunks_t", "ScaleFactors_t"]
 
 ATTRS_KEY = "spatialdata_attrs"
 
@@ -225,12 +226,19 @@ def parse(
                 chunks = {dim: chunks[index] for index, dim in enumerate(data.dims)}
             if isinstance(chunks, float):
                 chunks = {dim: chunks for index, dim in data.dims}
-            data = to_multiscale(
-                data,
-                scale_factors=scale_factors,
-                method=method,
-                chunks=chunks,
-            )
+            if method is not None:
+                data = to_multiscale_msi(
+                    data,
+                    scale_factors=scale_factors,
+                    method=method,
+                    chunks=chunks,
+                )
+            else:
+                data = to_multiscale_ozp(
+                    data,
+                    scale_factors=scale_factors,
+                    chunks=chunks,
+                )
             _parse_transformations(data, parsed_transform)
         else:
             # Chunk single scale images
@@ -375,9 +383,6 @@
     ) -> DataArray | DataTree:
         if kwargs.get("c_coords") is not None:
             raise ValueError("`c_coords` is not supported for labels")
-        if kwargs.get("scale_factors") is not None and kwargs.get("method") is None:
-            # Override default scaling method to preserve labels
-            kwargs["method"] = Methods.DASK_IMAGE_NEAREST
         return super().parse(*args, **kwargs)
 
@@ -388,9 +393,6 @@ class Labels3DModel(RasterSchema):
     def parse(self, *args: Any, **kwargs: Any) -> DataArray | DataTree:  # noqa: D102
         if kwargs.get("c_coords") is not None:
             raise ValueError("`c_coords` is not supported for labels")
-        if kwargs.get("scale_factors") is not None and kwargs.get("method") is None:
-            # Override default scaling method to preserve labels
-            kwargs["method"] = Methods.DASK_IMAGE_NEAREST
         return super().parse(*args, **kwargs)
 
diff --git a/src/spatialdata/models/pyramids_utils.py b/src/spatialdata/models/pyramids_utils.py
new file mode 100644
index 00000000..8a0a0ba1
--- /dev/null
+++ b/src/spatialdata/models/pyramids_utils.py
@@ -0,0 +1,122 @@
+from collections.abc import Sequence
+from typing import Any
+
+import dask.array as da
+from ome_zarr.dask_utils import resize
+from xarray import DataArray, Dataset, DataTree
+
+from spatialdata.models.chunks_utils import Chunks_t, normalize_chunks
+
+ScaleFactors_t = Sequence[dict[str, int] | int]
+
+
+def dask_arrays_to_datatree(
+    arrays: Sequence[da.Array],
+    dims: Sequence[str],
+    channels: list[Any] | None = None,
+) -> DataTree:
+    """Build a multiscale DataTree from a sequence of dask arrays.
+
+    Parameters
+    ----------
+    arrays
+        Sequence of dask arrays, one per scale level (scale0, scale1, ...).
+    dims
+        Dimension names for the arrays (e.g. ``("c", "y", "x")``).
+    channels
+        Optional channel coordinate values. If provided, a ``"c"`` coordinate
+        is added to each scale level.
+
+    Returns
+    -------
+    DataTree
+        A DataTree with one child per scale level.
+    """
+    if "c" in dims and channels is None:
+        raise ValueError("channels must be provided if the image has a channel dimension")
+    coords = {"c": channels} if channels is not None else {}
+    d = {}
+    for i, arr in enumerate(arrays):
+        d[f"scale{i}"] = Dataset(
+            {
+                "image": DataArray(
+                    arr,
+                    name="image",
+                    dims=list(dims),
+                    coords=coords,
+                )
+            }
+        )
+    return DataTree.from_dict(d)
+
+
+def to_multiscale(
+    image: DataArray,
+    scale_factors: ScaleFactors_t,
+    chunks: Chunks_t | None = None,
+) -> DataTree:
+    """Build a multiscale pyramid DataTree from a single-scale image.
+
+    Iteratively downscales the image by the given scale factors using
+    interpolation (order 1 for images with a channel dimension, order 0
+    for labels) and assembles all levels into a DataTree.
+
+    Makes use of internal ome-zarr-py APIs for dask downscaling.
+
+    ome-zarr-py will support 3D downscaling once https://github.com/ome/ome-zarr-py/pull/516 is merged, and this
+    function could then make use of it. The PR will also introduce new downscaling methods such as "nearest".
+    Nevertheless, this function supports different scale factors per axis, a feature that could also be added to
+    ome-zarr-py.
+
+    TODO: once the PR above is merged, use the new APIs for 3D downscaling and additional downscaling methods.
+    TODO: once the PR above is merged, consider adding support for per-axis scale factors to ome-zarr-py so that
+    this function can be simplified even further.
+
+    Parameters
+    ----------
+    image
+        Input image/labels as an xarray DataArray (e.g. with dims ``("c", "y", "x")``
+        or ``("y", "x")``). Supports both 2D/3D images and 2D/3D labels.
+    scale_factors
+        Sequence of per-level scale factors. Each element is either an int
+        (applied to all spatial axes) or a dict mapping dimension names to
+        per-axis factors (e.g. ``{"y": 2, "x": 2}``).
+    chunks
+        Optional chunk specification passed to :meth:`dask.array.Array.rechunk`
+        after building the pyramid.
+
+    Returns
+    -------
+    DataTree
+        Multiscale DataTree with children ``scale0``, ``scale1``, etc.
+    """
+    dims = [str(dim) for dim in image.dims]
+    spatial_dims = [d for d in dims if d != "c"]
+    order = 1 if "c" in dims else 0
+    channels = None if "c" not in dims else image.coords["c"].values
+    pyramid = [image.data]
+    for sf in scale_factors:
+        prev = pyramid[-1]
+        # Compute per-axis scale factors: int applies to spatial axes only, dict to specific ones.
+        sf_by_axis = dict.fromkeys(dims, 1)
+        if isinstance(sf, int):
+            sf_by_axis.update(dict.fromkeys(spatial_dims, sf))
+        else:
+            sf_by_axis.update(sf)
+        # Skip axes where the scale factor exceeds the axis size.
+        for ax, factor in sf_by_axis.items():
+            ax_size = prev.shape[dims.index(ax)]
+            if factor > ax_size:
+                sf_by_axis[ax] = 1
+        output_shape = tuple(prev.shape[dims.index(ax)] // f for ax, f in sf_by_axis.items())
+        resized = resize(
+            image=prev.astype(float),
+            output_shape=output_shape,
+            order=order,
+            mode="reflect",
+            anti_aliasing=False,
+        )
+        pyramid.append(resized.astype(prev.dtype))
+    if chunks is not None:
+        chunks_dict = normalize_chunks(chunks, axes=dims)
+        chunks_tuple = tuple(chunks_dict[d] for d in dims)
+        pyramid = [arr.rechunk(chunks_tuple) for arr in pyramid]
+    return dask_arrays_to_datatree(pyramid, dims=dims, channels=channels)
diff --git a/tests/models/test_chunks_utils.py b/tests/models/test_chunks_utils.py
new file mode 100644
index 00000000..5bfcfa81
--- /dev/null
+++ b/tests/models/test_chunks_utils.py
@@ -0,0 +1,45 @@
+import pytest
+
+from spatialdata.models.chunks_utils import Chunks_t, normalize_chunks
+
+
+@pytest.mark.parametrize(
+    "chunks, axes, expected",
+    [
+        # 2D (y, x)
+        (256, ("y", "x"), {"y": 256, "x": 256}),
+        ((200, 100), ("x", "y"), {"y": 100, "x": 200}),
+        ({"y": 300, "x": 400}, ("x", "y"), {"y": 300, "x": 400}),
+        # 2D with channel (c, y, x)
+        (256, ("c", "y", "x"), {"c": 256, "y": 256, "x": 256}),
+        ((1, 100, 200), ("c", "y", "x"), {"c": 1, "y": 100, "x": 200}),
+        ({"c": 1, "y": 300, "x": 400}, ("c", "y", "x"), {"c": 1, "y": 300, "x": 400}),
+        # 3D (z, y, x)
+        ((10, 100, 200), ("z", "y", "x"), {"z": 10, "y": 100, "x": 200}),
+        ({"z": 10, "y": 300, "x": 400}, ("z", "y", "x"), {"z": 10, "y": 300, "x": 400}),
+        # Mapping with None values (passed through)
+        ({"y": None, "x": 400}, ("y", "x"), {"y": None, "x": 400}),
+        ({"c": None, "y": None, "x": None}, ("c", "y", "x"), {"c": None, "y": None, "x": None}),
+        # Mapping with tuple[int, ...] values (explicit per-block chunk sizes)
+        ({"y": (256, 256, 128), "x": 512}, ("y", "x"), {"y": (256, 256, 128), "x": 512}),
+        ({"c": 1, "y": (100, 100), "x": (200, 50)}, ("c", "y", "x"), {"c": 1, "y": (100, 100), "x": (200, 50)}),
+        # Tuple of tuples (explicit per-block chunk sizes per axis)
+        (((256, 256, 128), (512, 512)), ("y", "x"), {"y": (256, 256, 128), "x": (512, 512)}),
+    ],
+)
+def test_normalize_chunks_valid(
+    chunks: Chunks_t, axes: tuple[str, ...], expected: dict[str, None | int | tuple[int, ...]]
+) -> None:
+    assert normalize_chunks(chunks, axes=axes) == expected
+
+
+@pytest.mark.parametrize(
+    "chunks, axes, match",
+    [
+        ({"y": 100}, ("y", "x"), "missing keys for axes"),
+        ((1, 2, 3), ("y", "x"), "doesn't match axes"),
+        ((1.5, 2), ("y", "x"), "must be int or tuple"),
+        ("invalid", ("y", "x"), "Unsupported chunks type"),
+    ],
+)
+def test_normalize_chunks_errors(chunks: Chunks_t, axes: tuple[str, ...], match: str) -> None:
+    with pytest.raises(ValueError, match=match):
+        normalize_chunks(chunks, axes=axes)
diff --git a/tests/models/test_pyramids_utils.py b/tests/models/test_pyramids_utils.py
new file mode 100644
index 00000000..da0a3354
--- /dev/null
+++ b/tests/models/test_pyramids_utils.py
@@ -0,0 +1,57 @@
+import numpy as np
+import pytest
+from multiscale_spatial_image.to_multiscale.to_multiscale import Methods
+
+from spatialdata.datasets import BlobsDataset
+from spatialdata.models import Image2DModel, Image3DModel, Labels2DModel, Labels3DModel
+
+CHUNK_SIZE = 32
+
+
+@pytest.mark.parametrize(
+    ("model", "length", "ndim", "n_channels", "scale_factors", "method"),
+    [
+        (Image2DModel, 128, 2, 3, [2, 3], Methods.XARRAY_COARSEN),
+        (Image3DModel, 32, 3, 3, [{"x": 3, "y": 2, "z": 1}, {"x": 1, "y": 2, "z": 2}], Methods.XARRAY_COARSEN),
+        (Labels2DModel, 128, 2, 0, [2, 2], Methods.DASK_IMAGE_NEAREST),
+        (Labels3DModel, 32, 3, 0, [{"x": 2, "y": 2, "z": 3}, 2], Methods.DASK_IMAGE_NEAREST),
+    ],
+)
+def test_to_multiscale_via_ome_zarr_scaler(model, length, ndim, n_channels, scale_factors, method):
+    blob_gen = BlobsDataset()
+
+    if model in [Image2DModel, Image3DModel]:
+        array = blob_gen._image_blobs(length=length, n_channels=n_channels, ndim=ndim).data
+    else:
+        array = blob_gen._labels_blobs(length=length, ndim=ndim).data
+
+    dims = model.dims
+    dask_data = array.rechunk(CHUNK_SIZE)
+
+    # multiscale-spatial-image (method is not None)
+    result_msi = model.parse(dask_data, dims=dims, scale_factors=scale_factors, chunks=CHUNK_SIZE, method=method)
+
+    # ome-zarr-py scaler (method=None triggers the ome-zarr-py scaler)
+    result_ozp = model.parse(dask_data, dims=dims, scale_factors=scale_factors, chunks=CHUNK_SIZE)
+
+    for i, scale_name in enumerate(result_msi.children):
+        msi_arr = result_msi[scale_name].ds["image"]
+        ozp_arr = result_ozp[scale_name].ds["image"]
+        assert msi_arr.sizes == ozp_arr.sizes
+
+        if i == 0:
+            # scale0 is the original data, must be identical
+            np.testing.assert_array_equal(msi_arr.values, ozp_arr.values)
+        else:
+            if model in (Labels2DModel, Labels3DModel):
+                # labels use different nearest-like methods; expect <50% non-identical entries
+                fraction_non_equal = np.sum(msi_arr.values != ozp_arr.values) / np.prod(msi_arr.values.shape)
+                assert fraction_non_equal < 0.5, (
+                    f"{scale_name}: {fraction_non_equal:.1%} non-identical entries (expected <50%)"
+                )
+            else:
+                # images use fundamentally different algorithms (coarsen vs spline interpolation);
+                # just check that the value ranges are similar
+                msi_vals, ozp_vals = msi_arr.values, ozp_arr.values
+                np.testing.assert_allclose(msi_vals.mean(), ozp_vals.mean(), rtol=0.5)
+                np.testing.assert_allclose(msi_vals.std(), ozp_vals.std(), rtol=0.5)