diff --git a/pyproject.toml b/pyproject.toml
index 0e813ead..89e9e023 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,7 +48,6 @@ dependencies = [
     "typing_extensions>=4.8.0",
     "universal_pathlib>=0.2.6",
     "xarray>=2024.10.0",
-    "xarray-schema",
     "xarray-spatial>=0.3.5",
     "zarr>=3.0.0",
 ]
diff --git a/src/spatialdata/_core/_deepcopy.py b/src/spatialdata/_core/_deepcopy.py
index 8b8c0b5c..67d709f0 100644
--- a/src/spatialdata/_core/_deepcopy.py
+++ b/src/spatialdata/_core/_deepcopy.py
@@ -79,7 +79,7 @@ def _(element: DataTree) -> DataTree:
             msi[key][variable].data = from_array(msi[key][variable].data)
             element[key][variable].data = from_array(element[key][variable].data)
     assert model in [Image2DModel, Image3DModel, Labels2DModel, Labels3DModel]
-    model().validate(msi)
+    model.validate(msi)
     return msi
diff --git a/src/spatialdata/_core/_elements.py b/src/spatialdata/_core/_elements.py
index 99ff9d33..88205c7e 100644
--- a/src/spatialdata/_core/_elements.py
+++ b/src/spatialdata/_core/_elements.py
@@ -72,10 +72,10 @@ def __setitem__(self, key: str, value: Raster_T) -> None:
             raise TypeError(f"Unknown element type with schema: {schema!r}.")
         ndim = len(get_axes_names(value))
         if ndim == 3:
-            Image2DModel().validate(value)
+            Image2DModel.validate(value)
             super().__setitem__(key, value)
         elif ndim == 4:
-            Image3DModel().validate(value)
+            Image3DModel.validate(value)
             super().__setitem__(key, value)
         else:
             NotImplementedError("TODO: implement for ndim > 4.")
@@ -89,10 +89,10 @@ def __setitem__(self, key: str, value: Raster_T) -> None:
             raise TypeError(f"Unknown element type with schema: {schema!r}.")
         ndim = len(get_axes_names(value))
         if ndim == 2:
-            Labels2DModel().validate(value)
+            Labels2DModel.validate(value)
             super().__setitem__(key, value)
         elif ndim == 3:
-            Labels3DModel().validate(value)
+            Labels3DModel.validate(value)
             super().__setitem__(key, value)
         else:
             NotImplementedError("TODO: implement for ndim > 3.")
@@ -104,7 +104,7 @@ def __setitem__(self, key: str, value: GeoDataFrame) -> None:
         schema = get_model(value)
         if schema != ShapesModel:
             raise TypeError(f"Unknown element type with schema: {schema!r}.")
-        ShapesModel().validate(value)
+        ShapesModel.validate(value)
         super().__setitem__(key, value)
@@ -114,7 +114,7 @@ def __setitem__(self, key: str, value: DaskDataFrame) -> None:
         schema = get_model(value)
         if schema != PointsModel:
             raise TypeError(f"Unknown element type with schema: {schema!r}.")
-        PointsModel().validate(value)
+        PointsModel.validate(value)
         super().__setitem__(key, value)
@@ -124,5 +124,5 @@ def __setitem__(self, key: str, value: AnnData) -> None:
         schema = get_model(value)
         if schema != TableModel:
             raise TypeError(f"Unknown element type with schema: {schema!r}.")
-        TableModel().validate(value)
+        TableModel.validate(value)
         super().__setitem__(key, value)
diff --git a/src/spatialdata/_core/operations/rasterize.py b/src/spatialdata/_core/operations/rasterize.py
index 6da0a7cc..12bbb68f 100644
--- a/src/spatialdata/_core/operations/rasterize.py
+++ b/src/spatialdata/_core/operations/rasterize.py
@@ -602,7 +602,7 @@ def rasterize_images_labels(
     set_transformation(transformed_data, sequence, target_coordinate_system)
     transformed_data = compute_coordinates(transformed_data)
-    schema().validate(transformed_data)
+    schema.validate(transformed_data)
     return transformed_data
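Reviewer note: the recurring change in this patch swaps instance-based validation (`Model().validate(...)`) for classmethod validation (`Model.validate(...)`), so no throwaway schema instance is constructed at each call site. A minimal sketch of the new call pattern (array shape and dtype are illustrative only):

```python
import numpy as np
from spatialdata.models import Labels2DModel

# parse() wraps the array as a dask-backed DataArray and attaches an Identity
# transformation, which validate() requires.
labels = Labels2DModel.parse(np.zeros((8, 8), dtype=np.uint8))

# before this patch: Labels2DModel().validate(labels)  # instantiate, then validate
# after this patch: validate is a classmethod, no instance needed
Labels2DModel.validate(labels)
```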
diff --git a/src/spatialdata/_core/operations/transform.py b/src/spatialdata/_core/operations/transform.py
index 84258aae..e821edcf 100644
--- a/src/spatialdata/_core/operations/transform.py
+++ b/src/spatialdata/_core/operations/transform.py
@@ -333,7 +333,7 @@ def _(
         to_coordinate_system=to_coordinate_system,
     )
     transformed_data = compute_coordinates(transformed_data)
-    schema().validate(transformed_data)
+    schema.validate(transformed_data)
     return transformed_data
@@ -419,7 +419,7 @@ def _(
         to_coordinate_system=to_coordinate_system,
     )
     transformed_data = compute_coordinates(transformed_data)
-    schema().validate(transformed_data)
+    schema.validate(transformed_data)
     return transformed_data
diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py
index c06e62b7..739b225f 100644
--- a/src/spatialdata/_core/spatialdata.py
+++ b/src/spatialdata/_core/spatialdata.py
@@ -57,15 +57,6 @@
     SpatialDataFormatType,
 )
 
-# schema for elements
-Label2D_s = Labels2DModel()
-Label3D_s = Labels3DModel()
-Image2D_s = Image2DModel()
-Image3D_s = Image3DModel()
-Shape_s = ShapesModel()
-Point_s = PointsModel()
-Table_s = TableModel()
-
 
 class SpatialData:
     """
@@ -199,7 +190,7 @@ def validate_table_in_spatialdata(self, table: AnnData) -> None:
         UserWarning
             The dtypes of the instance key column in the table and the annotation target do not match.
         """
-        TableModel().validate(table)
+        TableModel.validate(table)
         if TableModel.ATTRS_KEY in table.uns:
             region, _, instance_key = get_table_keys(table)
             region = region if isinstance(region, list) else [region]
@@ -349,8 +340,13 @@ def _set_table_annotation_target(
         ValueError
             If `instance_key` is not present in the `table.obs` columns.
         """
-        TableModel()._validate_set_region_key(table, region_key)
-        TableModel()._validate_set_instance_key(table, instance_key)
+        old_attrs = table.uns.get(TableModel.ATTRS_KEY)
+        # _validate_set_region_key and _validate_set_instance_key will raise an error if table.uns[ATTRS_KEY] is None,
+        # so let's initialize it here. Below it will be replaced with the actual metadata.
+        if old_attrs is None:
+            table.uns[TableModel.ATTRS_KEY] = {}
+        TableModel._validate_set_region_key(table, region_key)
+        TableModel._validate_set_instance_key(table, instance_key)
         attrs = {
             TableModel.REGION_KEY: region,
             TableModel.REGION_KEY_KEY: region_key,
@@ -393,8 +389,8 @@ def _change_table_annotation_target(
         attrs = table.uns[TableModel.ATTRS_KEY]
         table_region_key = region_key if region_key else attrs.get(TableModel.REGION_KEY_KEY)
 
-        TableModel()._validate_set_region_key(table, region_key)
-        TableModel()._validate_set_instance_key(table, instance_key)
+        TableModel._validate_set_region_key(table, region_key)
+        TableModel._validate_set_instance_key(table, instance_key)
         check_target_region_column_symmetry(table, table_region_key, region)
 
         attrs[TableModel.REGION_KEY] = region
@@ -1822,7 +1818,7 @@ def tables(self, tables: dict[str, AnnData]) -> None:
         self._shared_keys = self._shared_keys - set(self._tables.keys())
         self._tables = Tables(shared_keys=self._shared_keys)
         for k, v in tables.items():
-            TableModel().validate(v)
+            TableModel.validate(v)
             self._tables[k] = v
 
     @staticmethod
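The guard added in `_set_table_annotation_target` is worth spelling out: the key validators now raise when `table.uns["spatialdata_attrs"]` is absent, so the caller seeds an empty dict first and overwrites it with the real metadata immediately afterwards. A sketch of that pattern (the `AnnData` construction is illustrative only):

```python
import numpy as np
from anndata import AnnData
from spatialdata.models import TableModel

# An illustrative table without spatialdata metadata in .uns.
table = AnnData(np.zeros((3, 1)))

# _validate_set_region_key / _validate_set_instance_key now raise if
# table.uns["spatialdata_attrs"] is missing, so seed a placeholder dict;
# the caller then replaces it with the real region/instance metadata.
if table.uns.get(TableModel.ATTRS_KEY) is None:
    table.uns[TableModel.ATTRS_KEY] = {}
```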
diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py
index 719c9c1a..d2e5c3cc 100644
--- a/src/spatialdata/_io/io_table.py
+++ b/src/spatialdata/_io/io_table.py
@@ -56,7 +56,7 @@ def write_table(
 ) -> None:
     if TableModel.ATTRS_KEY in table.uns:
         region, region_key, instance_key = get_table_keys(table)
-        TableModel().validate(table)
+        TableModel.validate(table)
     else:
         region, region_key, instance_key = (None, None, None)
diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py
index f5bc1257..64dd7638 100644
--- a/src/spatialdata/_utils.py
+++ b/src/spatialdata/_utils.py
@@ -308,7 +308,7 @@ def _error_message_add_element() -> None:
 
 
 def _check_match_length_channels_c_dim(
-    data: DaskArray | DataArray | DataTree, c_coords: str | list[str], dims: tuple[str]
+    data: DaskArray | DataArray | DataTree, c_coords: str | list[str], dims: tuple[str, ...]
 ) -> list[str]:
     """
     Check whether channel names `c_coords` are of equal length to the `c` dimension of the data.
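The one-character annotation fix above is subtle: in Python typing, `tuple[str]` and `tuple[str, ...]` are different types, and only the latter matches the dims tuples actually passed in. For illustration:

```python
# `tuple[str]` means a tuple of exactly one string; `tuple[str, ...]` means a
# homogeneous tuple of any length, which is what the dims argument actually is.
dims_2d: tuple[str, ...] = ("y", "x")        # accepted by type checkers
dims_3d: tuple[str, ...] = ("c", "y", "x")   # accepted
dims_bad: tuple[str] = ("y", "x")            # rejected: annotation expects exactly one item
```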
diff --git a/src/spatialdata/models/_utils.py b/src/spatialdata/models/_utils.py
index f2840f5e..6631f6c9 100644
--- a/src/spatialdata/models/_utils.py
+++ b/src/spatialdata/models/_utils.py
@@ -401,7 +401,7 @@ def set_channel_names(element: DataArray | DataTree, channel_names: str | list[s
 
     # get_model cannot be used due to circular import so get_axes_names is used instead
     if model in [Image2DModel, Image3DModel]:
-        channel_names = _check_match_length_channels_c_dim(element, channel_names, model.dims.dims)  # type: ignore[union-attr]
+        channel_names = _check_match_length_channels_c_dim(element, channel_names, model.dims)  # type: ignore[union-attr]
         if isinstance(element, DataArray):
             element = element.assign_coords(c=channel_names)
         else:
diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py
index 6a126b02..2d54c709 100644
--- a/src/spatialdata/models/models.py
+++ b/src/spatialdata/models/models.py
@@ -23,13 +23,6 @@
 from shapely.io import from_geojson, from_ragged_array
 from spatial_image import to_spatial_image
 from xarray import DataArray, DataTree
-from xarray_schema.components import (
-    ArrayTypeSchema,
-    AttrSchema,
-    AttrsSchema,
-    DimsSchema,
-)
-from xarray_schema.dataarray import DataArraySchema
 
 from spatialdata._core.validation import validate_table_attr_keys
 from spatialdata._logging import logger
@@ -50,13 +43,12 @@
     _set_transformations,
     compute_coordinates,
 )
-from spatialdata.transformations.transformations import BaseTransformation, Identity
+from spatialdata.transformations.transformations import Identity
 
 # Types
 Chunks_t: TypeAlias = int | tuple[int, ...] | tuple[tuple[int, ...], ...] | Mapping[Any, None | int | tuple[int, ...]]
 ScaleFactors_t = Sequence[dict[str, int] | int]
-Transform_s = AttrSchema(BaseTransformation, None)
 
 ATTRS_KEY = "spatialdata_attrs"
@@ -83,11 +75,12 @@ def _parse_transformations(element: SpatialElement, transformations: MappingToCo
     _set_transformations(element, parsed_transformations)
 
 
-class RasterSchema(DataArraySchema):
+class RasterSchema:
     """Base schema for raster data."""
 
     # TODO add DataTree validation, validate has scale0... etc and each scale contains 1 image in .variables.
     ATTRS_KEY = ATTRS_KEY
+    dims: tuple[str, ...]
 
     @classmethod
     def parse(
@@ -179,29 +172,29 @@ def parse(
             else:
                 dims = data.dims
             # but if dims don't match the model's dims, throw error
-            if set(dims).symmetric_difference(cls.dims.dims):
-                raise ValueError(f"Wrong `dims`: {dims}. Expected {cls.dims.dims}.")
+            if set(dims).symmetric_difference(cls.dims):
+                raise ValueError(f"Wrong `dims`: {dims}. Expected {cls.dims}.")
             _reindex = lambda d: d
         # if there are no dims in the data, use the model's dims or provided dims
         elif isinstance(data, np.ndarray | DaskArray):
             if not isinstance(data, DaskArray):  # numpy -> dask
                 data = from_array(data)
             if dims is None:
                 dims = cls.dims
             else:
-                if len(set(dims).symmetric_difference(cls.dims.dims)) > 0:
-                    raise ValueError(f"Wrong `dims`: {dims}. Expected {cls.dims.dims}.")
+                if len(set(dims).symmetric_difference(cls.dims)) > 0:
+                    raise ValueError(f"Wrong `dims`: {dims}. Expected {cls.dims}.")
                 _reindex = lambda d: dims.index(d)
         else:
             raise ValueError(f"Unsupported data type: {type(data)}.")
Expected {cls.dims}.") _reindex = lambda d: dims.index(d) else: raise ValueError(f"Unsupported data type: {type(data)}.") # transpose if possible - if tuple(dims) != cls.dims.dims: + if tuple(dims) != cls.dims: try: if isinstance(data, DataArray): - data = data.transpose(*list(cls.dims.dims)) + data = data.transpose(*list(cls.dims)) elif isinstance(data, DaskArray): - data = data.transpose(*[_reindex(d) for d in cls.dims.dims]) + data = data.transpose(*[_reindex(d) for d in cls.dims]) else: raise ValueError(f"Unsupported data type: {type(data)}.") except ValueError as e: @@ -212,15 +205,15 @@ def parse( # finally convert to spatial image if c_coords is not None: - c_coords = _check_match_length_channels_c_dim(data, c_coords, cls.dims.dims) + c_coords = _check_match_length_channels_c_dim(data, c_coords, cls.dims) - if c_coords is not None and len(c_coords) != data.shape[cls.dims.dims.index("c")]: + if c_coords is not None and len(c_coords) != data.shape[cls.dims.index("c")]: raise ValueError( f"The number of channel names `{len(c_coords)}` does not match the length of dimension 'c'" - f" with length {data.shape[cls.dims.dims.index('c')]}." + f" with length {data.shape[cls.dims.index('c')]}." ) - data = to_spatial_image(array_like=data, dims=cls.dims.dims, c_coords=c_coords, **kwargs) + data = to_spatial_image(array_like=data, dims=cls.dims, c_coords=c_coords, **kwargs) # parse transformations _parse_transformations(data, transformations) # convert to multiscale if needed @@ -245,12 +238,13 @@ def parse( if isinstance(chunks, tuple): chunks = {dim: chunks[index] for index, dim in enumerate(data.dims)} data = data.chunk(chunks=chunks) - cls()._check_chunk_size_not_too_large(data) # recompute coordinates for (multiscale) spatial image - return compute_coordinates(data) + data = compute_coordinates(data) + cls.validate(data) + return data - @singledispatchmethod - def validate(self, data: Any) -> None: + @classmethod + def validate(cls, data: Any) -> None: """ Validate data. @@ -264,19 +258,20 @@ def validate(self, data: Any) -> None: ValueError If data is not valid. """ - raise ValueError( - f"Unsupported data type: {type(data)}. Please use .parse() from Image2DModel, Image3DModel, Labels2DModel " - "or Labels3DModel to construct data that is guaranteed to be valid." - ) - - @validate.register(DataArray) - def _(self, data: DataArray) -> None: - super().validate(data) - self._check_chunk_size_not_too_large(data) - self._check_transforms_present(data) + if isinstance(data, DataArray): + cls._validate_dataarray(data) + elif isinstance(data, DataTree): + cls._validate_datatree(data) + else: + raise ValueError( + f"Unsupported data type: {type(data)}. Please use .parse() from Image2DModel, Image3DModel, " + "Labels2DModel or Labels3DModel to construct data that is guaranteed to be valid." 
+            )
+        cls._check_chunk_size_not_too_large(data)
+        cls._check_transforms_present(data)
 
-    @validate.register(DataTree)
-    def _(self, data: DataTree) -> None:
+    @classmethod
+    def _validate_datatree(cls, data: DataTree) -> None:
         for j, k in zip(data.keys(), [f"scale{i}" for i in np.arange(len(data.keys()))], strict=True):
             if j != k:
                 raise ValueError(f"Wrong key for multiscale data, found: `{j}`, expected: `{k}`.")
@@ -285,11 +280,37 @@ def _(self, data: DataTree) -> None:
             raise ValueError(f"Expected exactly one data variable for the datatree: found `{name}`.")
         name = list(name)[0]
         for d in data:
-            super().validate(data[d][name])
-        self._check_chunk_size_not_too_large(data)
-        self._check_transforms_present(data)
+            cls._validate_dataarray(data[d][name])
+
+    @classmethod
+    def _validate_dataarray(cls, data: DataArray) -> None:
+        """Validate a single DataArray against this schema's dims, array type, and attrs."""
+        if not isinstance(data, DataArray):
+            raise ValueError(f"Expected DataArray, got {type(data)}")
+        cls._validate_dims(data)
+        cls._validate_array_type(data)
+        cls._validate_attrs(data)
+
+    @classmethod
+    def _validate_dims(cls, data: DataArray) -> None:
+        if len(data.dims) != len(cls.dims):
+            raise ValueError(f"Expected {len(cls.dims)} dimensions, got {len(data.dims)}: {data.dims}")
+        for expected, actual in zip(cls.dims, data.dims, strict=True):
+            if expected != actual:
+                raise ValueError(f"Expected dimension '{expected}', got '{actual}'")
+
+    @classmethod
+    def _validate_array_type(cls, data: DataArray) -> None:
+        if not isinstance(data.data, DaskArray):
+            raise ValueError(f"Expected array type {DaskArray}, got {type(data.data)}")
+
+    @classmethod
+    def _validate_attrs(cls, data: DataArray) -> None:
+        if "transform" not in data.attrs:
+            raise ValueError("Missing required attribute 'transform'")
 
-    def _check_transforms_present(self, data: DataArray | DataTree) -> None:
+    @classmethod
+    def _check_transforms_present(cls, data: DataArray | DataTree) -> None:
         parsed_transform = _get_transformations(data)
         if parsed_transform is None:
             raise ValueError(
@@ -297,7 +318,8 @@ def _check_chunk_size_not_too_large(self, data: DataArray | DataTree) -> None:
                 f"raster elements, e.g. images, labels."
             )
 
-    def _check_chunk_size_not_too_large(self, data: DataArray | DataTree) -> None:
+    @classmethod
+    def _check_chunk_size_not_too_large(cls, data: DataArray | DataTree) -> None:
         if isinstance(data, DataArray):
             try:
                 max_per_dimension: dict[int, int] = {}
@@ -339,23 +361,11 @@ def _check_chunk_size_not_too_large(self, data: DataArray | DataTree) -> None:
             assert len(name) == 1
             name = list(name)[0]
             for d in data:
-                super().validate(data[d][name])
-                self._check_chunk_size_not_too_large(data[d][name])
+                cls._check_chunk_size_not_too_large(data[d][name])
 
 
 class Labels2DModel(RasterSchema):
-    dims = DimsSchema((Y, X))
-    array_type = ArrayTypeSchema(DaskArray)
-    attrs = AttrsSchema({"transform": Transform_s})
-
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        super().__init__(
-            dims=self.dims,
-            array_type=self.array_type,
-            attrs=self.attrs,
-            *args,
-            **kwargs,
-        )
+    dims = (Y, X)
 
     @classmethod
     def parse(  # noqa: D102
@@ -372,18 +382,7 @@ def parse(  # noqa: D102
 
 
 class Labels3DModel(RasterSchema):
-    dims = DimsSchema((Z, Y, X))
-    array_type = ArrayTypeSchema(DaskArray)
-    attrs = AttrsSchema({"transform": Transform_s})
-
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        super().__init__(
-            dims=self.dims,
-            array_type=self.array_type,
-            attrs=self.attrs,
-            *args,
-            **kwargs,
-        )
+    dims = (Z, Y, X)
 
     @classmethod
     def parse(self, *args: Any, **kwargs: Any) -> DataArray | DataTree:  # noqa: D102
@@ -396,33 +395,11 @@ def parse(self, *args: Any, **kwargs: Any) -> DataArray | DataTree:  # noqa: D10
 
 
 class Image2DModel(RasterSchema):
-    dims = DimsSchema((C, Y, X))
-    array_type = ArrayTypeSchema(DaskArray)
-    attrs = AttrsSchema({"transform": Transform_s})
-
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        super().__init__(
-            dims=self.dims,
-            array_type=self.array_type,
-            attrs=self.attrs,
-            *args,
-            **kwargs,
-        )
+    dims = (C, Y, X)
 
 
 class Image3DModel(RasterSchema):
-    dims = DimsSchema((C, Z, Y, X))
-    array_type = ArrayTypeSchema(DaskArray)
-    attrs = AttrsSchema({"transform": Transform_s})
-
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        super().__init__(
-            dims=self.dims,
-            array_type=self.array_type,
-            attrs=self.attrs,
-            *args,
-            **kwargs,
-        )
+    dims = (C, Z, Y, X)
 
 
 class ShapesModel:
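With `xarray-schema` removed, `dims` is now a plain tuple on each model class and `validate` is a classmethod that checks dims order, dask backing, and the `transform` attribute (and `parse` now calls it before returning). A minimal usage sketch (array contents are arbitrary):

```python
import numpy as np
from spatialdata.models import Image2DModel

# dims is a plain tuple on the class, e.g. ("c", "y", "x") for Image2DModel.
print(Image2DModel.dims)

# parse() converts to a dask-backed DataArray, attaches a transform, and ends
# with cls.validate(data); calling validate() again by hand is also fine.
image = Image2DModel.parse(np.zeros((3, 16, 16), dtype=np.uint8))
Image2DModel.validate(image)

# A DataArray without the "transform" attr fails _validate_attrs
# ("Missing required attribute 'transform'"); a numpy-backed one fails
# _validate_array_type, and misordered dims fail _validate_dims.
```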
@@ -918,7 +895,8 @@ class TableModel:
     INSTANCE_KEY = "instance_key"
     ATTRS_KEY = ATTRS_KEY
 
-    def _validate_set_region_key(self, data: AnnData, region_key: str | None = None) -> None:
+    @classmethod
+    def _validate_set_region_key(cls, data: AnnData, region_key: str | None = None) -> None:
         """
         Validate the region key in table.uns or set a new region key as the region key column.
@@ -941,14 +919,16 @@ def _validate_set_region_key(self, data: AnnData, region_key: str | None = None)
         """
         attrs = data.uns.get(ATTRS_KEY)
         if attrs is None:
-            data.uns[ATTRS_KEY] = attrs = {}
-        table_region_key = attrs.get(self.REGION_KEY_KEY)
+            raise ValueError(
+                f"No '{ATTRS_KEY}' found in `adata.uns`. Please use TableModel.parse() to initialize the table."
+            )
         if not region_key:
+            table_region_key = attrs.get(cls.REGION_KEY_KEY)
             if not table_region_key:
                 raise ValueError(
                     "No region_key in table.uns and no region_key provided as argument. Please specify 'region_key'."
                 )
-            if data.obs.get(attrs[TableModel.REGION_KEY_KEY]) is None:
+            if data.obs.get(attrs[cls.REGION_KEY_KEY]) is None:
                 raise ValueError(
                     f"Specified region_key in table.uns '{table_region_key}' is not "
                     f"present as column in table.obs. Please specify region_key."
@@ -956,9 +936,10 @@ def _validate_set_region_key(self, data: AnnData, region_key: str | None = None)
         else:
             if region_key not in data.obs:
                 raise ValueError(f"'{region_key}' column not present in table.obs")
-            attrs[self.REGION_KEY_KEY] = region_key
+            attrs[cls.REGION_KEY_KEY] = region_key
 
-    def _validate_set_instance_key(self, data: AnnData, instance_key: str | None = None) -> None:
+    @classmethod
+    def _validate_set_instance_key(cls, data: AnnData, instance_key: str | None = None) -> None:
         """
         Validate the instance_key in table.uns or set a new instance_key as the instance_key column.
@@ -985,26 +966,27 @@ def _validate_set_instance_key(self, data: AnnData, instance_key: str | None = N
         """
         attrs = data.uns.get(ATTRS_KEY)
         if attrs is None:
-            data.uns[ATTRS_KEY] = {}
-
+            raise ValueError(
+                f"No '{ATTRS_KEY}' found in `adata.uns`. Please use TableModel.parse() to initialize the table."
+            )
         if not instance_key:
-            if not attrs.get(TableModel.INSTANCE_KEY):
+            if not attrs.get(cls.INSTANCE_KEY):
                 raise ValueError(
                     "No instance_key in table.uns and no instance_key provided as argument. Please "
                     "specify instance_key."
                 )
-            if data.obs.get(attrs[self.INSTANCE_KEY]) is None:
+            if data.obs.get(attrs[cls.INSTANCE_KEY]) is None:
                 raise ValueError(
-                    f"Specified instance_key in table.uns '{attrs.get(self.INSTANCE_KEY)}' is not present"
+                    f"Specified instance_key in table.uns '{attrs.get(cls.INSTANCE_KEY)}' is not present"
                     f" as column in table.obs. Please specify instance_key."
                 )
-        if instance_key:
-            if instance_key in data.obs:
-                attrs[self.INSTANCE_KEY] = instance_key
-            else:
+        else:
+            if instance_key not in data.obs:
                 raise ValueError(f"Instance key column '{instance_key}' not found in table.obs.")
+            attrs[cls.INSTANCE_KEY] = instance_key
 
-    def _validate_table_annotation_metadata(self, data: AnnData) -> None:
+    @classmethod
+    def _validate_table_annotation_metadata(cls, data: AnnData) -> None:
         """
         Validate annotation metadata.
@@ -1043,12 +1025,12 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None:
         if "instance_key" not in attr:
             raise ValueError(f"`instance_key` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION)
 
-        if attr[self.REGION_KEY_KEY] not in data.obs:
-            raise ValueError(f"`{attr[self.REGION_KEY_KEY]}` not found in `adata.obs`. Please create the column.")
-        if attr[self.INSTANCE_KEY] not in data.obs:
-            raise ValueError(f"`{attr[self.INSTANCE_KEY]}` not found in `adata.obs`. Please create the column.")
+        if attr[cls.REGION_KEY_KEY] not in data.obs:
+            raise ValueError(f"`{attr[cls.REGION_KEY_KEY]}` not found in `adata.obs`. Please create the column.")
+        if attr[cls.INSTANCE_KEY] not in data.obs:
+            raise ValueError(f"`{attr[cls.INSTANCE_KEY]}` not found in `adata.obs`. Please create the column.")
         if (
-            (dtype := data.obs[attr[self.INSTANCE_KEY]].dtype)
+            (dtype := data.obs[attr[cls.INSTANCE_KEY]].dtype)
             not in [
                 int,
                 np.int16,
@@ -1059,22 +1041,22 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None:
                 np.uint64,
                 "O",
             ]
-            and not pd.api.types.is_string_dtype(data.obs[attr[self.INSTANCE_KEY]])
-            or (dtype == "O" and (val_dtype := type(data.obs[attr[self.INSTANCE_KEY]].iloc[0])) is not str)
+            and not pd.api.types.is_string_dtype(data.obs[attr[cls.INSTANCE_KEY]])
+            or (dtype == "O" and (val_dtype := type(data.obs[attr[cls.INSTANCE_KEY]].iloc[0])) is not str)
         ):
             dtype = dtype if dtype != "O" else val_dtype
             raise TypeError(
                 f"Only int, np.int16, np.int32, np.int64, uint equivalents or string allowed as dtype for "
                 f"instance_key column in obs. Dtype found to be {dtype}"
             )
-        expected_regions = attr[self.REGION_KEY] if isinstance(attr[self.REGION_KEY], list) else [attr[self.REGION_KEY]]
-        found_regions = data.obs[attr[self.REGION_KEY_KEY]].unique().tolist()
+        expected_regions = attr[cls.REGION_KEY] if isinstance(attr[cls.REGION_KEY], list) else [attr[cls.REGION_KEY]]
+        found_regions = data.obs[attr[cls.REGION_KEY_KEY]].unique().tolist()
         if len(set(expected_regions).symmetric_difference(set(found_regions))) > 0:
-            raise ValueError(f"Regions in the AnnData object and `{attr[self.REGION_KEY_KEY]}` do not match.")
+            raise ValueError(f"Regions in the AnnData object and `{attr[cls.REGION_KEY_KEY]}` do not match.")
 
         # Warning for object/string columns with NaN in region_key or instance_key
-        instance_key = attr[self.INSTANCE_KEY]
-        region_key = attr[self.REGION_KEY_KEY]
+        instance_key = attr[cls.INSTANCE_KEY]
+        region_key = attr[cls.REGION_KEY_KEY]
         for key_name, key_value in [("region_key", region_key), ("instance_key", instance_key)]:
             if key_value in data.obs:
                 col = data.obs[key_value]
@@ -1087,8 +1069,9 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None:
                     "cycles."
                 )
 
+    @classmethod
     def validate(
-        self,
+        cls,
         data: AnnData,
     ) -> AnnData:
         """
@@ -1130,7 +1113,7 @@ def validate(
         if data.obs[instance_key].isnull().values.any():
             raise ValueError("`table.obs[instance_key]` must not contain null values, but it does.")
 
-        self._validate_table_annotation_metadata(data)
+        cls._validate_table_annotation_metadata(data)
 
         return data
@@ -1221,7 +1204,7 @@ def parse(
         }
         adata.uns[cls.ATTRS_KEY] = attr
         convert_region_column_to_categorical(adata)
-        cls().validate(adata)
+        cls.validate(adata)
         return adata
@@ -1256,7 +1239,7 @@ def _validate_and_return(
         schema: Schema_t,
         e: SpatialElement,
     ) -> Schema_t:
-        schema().validate(e)
+        schema.validate(e)
         return schema
 
     if isinstance(e, DataArray | DataTree):
diff --git a/tests/io/test_multi_table.py b/tests/io/test_multi_table.py
index 70974c8a..5c90109d 100644
--- a/tests/io/test_multi_table.py
+++ b/tests/io/test_multi_table.py
@@ -192,7 +192,7 @@ def test_single_table_multiple_elements(self, tmp_path: str):
         table = _get_table(region=["poly", "multipoly"])
         subset = table[table.obs.region == "multipoly"]
         with pytest.raises(ValueError, match="Regions in"):
-            TableModel().validate(subset)
+            TableModel.validate(subset)
 
         test_sdata = SpatialData(
             shapes={
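The stricter `TableModel` behavior above changes the contract: `_validate_set_region_key` / `_validate_set_instance_key` no longer silently create the missing `spatialdata_attrs` metadata, so tables should go through `TableModel.parse()` first. A sketch of the intended flow (column and region names are illustrative):

```python
import numpy as np
import pandas as pd
from anndata import AnnData
from spatialdata.models import TableModel

obs = pd.DataFrame({"region": ["r1", "r1"], "instance_id": [0, 1]}, index=["0", "1"])
adata = AnnData(np.zeros((2, 1)), obs=obs)

# Without TableModel.parse(), .uns has no "spatialdata_attrs", so the key
# validators now raise instead of silently creating the metadata:
# TableModel._validate_set_instance_key(adata, instance_key="instance_id")  # ValueError

# Parsing attaches the metadata; validation then passes.
parsed = TableModel.parse(adata, region="r1", region_key="region", instance_key="instance_id")
TableModel.validate(parsed)
```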
- if element_type == "image" or element_type == "labels": - model().validate(element_read) - else: - model.validate(element_read) + model.validate(element_read) @pytest.mark.parametrize("converter", [lambda _: _, from_array, DataArray, to_spatial_image]) @pytest.mark.parametrize("model", [Image2DModel, Labels2DModel, Labels3DModel, Image3DModel]) @@ -156,7 +151,7 @@ def test_raster_schema( permute: bool, kwargs: dict[str, str] | None, ) -> None: - dims = np.array(model.dims.dims).tolist() + dims = np.array(model.dims).tolist() if permute: RNG.shuffle(dims) n_dims = len(dims) @@ -164,7 +159,7 @@ def test_raster_schema( if converter is DataArray: converter = partial(converter, dims=dims) elif converter is to_spatial_image: - converter = partial(converter, dims=model.dims.dims) + converter = partial(converter, dims=model.dims) if n_dims == 2: image: ArrayLike = RNG.uniform(size=(10, 10)) elif n_dims == 3: @@ -246,7 +241,7 @@ def test_raster_models_parse_with_chunks_parameter(self, model, chunks, expected assert y_ms["scale0"]["image"].data.chunksize == expected # parse as DataArray - data_array = DataArray(image, dims=model.dims.dims) + data_array = DataArray(image, dims=model.dims) # single scale z_ss = model.parse(data_array, chunks=chunks) assert z_ss.data.chunksize == expected @@ -257,7 +252,7 @@ def test_raster_models_parse_with_chunks_parameter(self, model, chunks, expected @pytest.mark.parametrize("model", [Labels2DModel, Labels3DModel]) def test_labels_model_with_multiscales(self, model): # Passing "scale_factors" should generate multiscales with a "method" appropriate for labels - dims = np.array(model.dims.dims).tolist() + dims = np.array(model.dims).tolist() n_dims = len(dims) # A labels image with one label value 4, that partially covers 2×2 blocks. 
@@ -545,7 +540,7 @@ def test_table_model_invalid_names(self, key: str, attr: str, parse: bool):
                 if parse:
                     TableModel.parse(adata)
                 else:
-                    TableModel().validate(adata)
+                    TableModel.validate(adata)
         elif key != "_index":  # "_index" is only disallowed in obs/var
             if attr in ("obsm", "varm", "obsp", "varp", "layers"):
                 array = np.array([[0]])
@@ -557,7 +552,7 @@ def test_table_model_invalid_names(self, key: str, attr: str, parse: bool):
                 if parse:
                     TableModel.parse(adata)
                 else:
-                    TableModel().validate(adata)
+                    TableModel.validate(adata)
         elif attr == "uns":
             adata = AnnData(np.array([[0]]), **{attr: {key: {}}})
             with pytest.raises(
@@ -567,7 +562,7 @@ def test_table_model_invalid_names(self, key: str, attr: str, parse: bool):
                 if parse:
                     TableModel.parse(adata)
                 else:
-                    TableModel().validate(adata)
+                    TableModel.validate(adata)
 
     @pytest.mark.parametrize(
         "keys",
@@ -586,7 +581,43 @@ def test_table_model_not_unique_columns(self, keys: list[str], attr: str, parse:
             if parse:
                 TableModel.parse(adata)
             else:
-                TableModel().validate(adata)
+                TableModel.validate(adata)
+
+
+def test_validate_set_instance_key_missing_attrs():
+    """Test _validate_set_instance_key behavior when ATTRS_KEY is missing from uns."""
+    # When instance_key arg is provided and column exists, but attrs is missing, it should fail
+    adata = AnnData(np.array([[0]]), obs=pd.DataFrame({"instance_id": [1]}, index=["1"]))
+    with pytest.raises(ValueError, match="No 'spatialdata_attrs' found"):
+        TableModel._validate_set_instance_key(adata, instance_key="instance_id")
+
+    # When instance_key arg is provided but column doesn't exist, should raise about the column
+    adata2 = AnnData(np.array([[0]]))
+    adata2.uns[TableModel.ATTRS_KEY] = {}
+    with pytest.raises(ValueError, match="Instance key column 'missing' not found"):
+        TableModel._validate_set_instance_key(adata2, instance_key="missing")
+
+    # When no instance_key arg and no attrs, should raise about missing attrs
+    with pytest.raises(ValueError, match="No 'spatialdata_attrs' found"):
+        TableModel._validate_set_instance_key(adata)
+
+
+def test_validate_set_region_key_missing_attrs():
+    """Test _validate_set_region_key behavior when ATTRS_KEY is missing from uns."""
+    # When region_key arg is provided and column exists, but attrs is missing, it should fail
+    adata = AnnData(np.array([[0]]), obs=pd.DataFrame({"region": ["r1"]}, index=["1"]))
+    with pytest.raises(ValueError, match="No 'spatialdata_attrs' found"):
+        TableModel._validate_set_region_key(adata, region_key="region")
+
+    # When region_key arg is provided but column doesn't exist, should raise about the column
+    adata2 = AnnData(np.array([[0]]))
+    adata2.uns[TableModel.ATTRS_KEY] = {}
+    with pytest.raises(ValueError, match="column not present in table.obs"):
+        TableModel._validate_set_region_key(adata2, region_key="missing")
+
+    # When no region_key arg and no attrs, should raise about missing attrs
+    with pytest.raises(ValueError, match="No 'spatialdata_attrs' found"):
+        TableModel._validate_set_region_key(adata)
 
 
 def test_get_schema():
@@ -826,7 +857,7 @@ def test_warning_on_large_chunks():
         warnings.simplefilter("always")
         multiscale = Labels2DModel.parse(data_large, scale_factors=[2, 2], method="xarray_coarsen")
         multiscale = multiscale.chunk({"x": 50000, "y": 50000})
-        Labels2DModel().validate(multiscale)
+        Labels2DModel.validate(multiscale)
     assert len(w) == 1, "Warning should be raised for large chunk size"
     assert issubclass(w[-1].category, UserWarning)
     assert "Detected chunks larger than:" in str(w[-1].message)
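One contract encoded in `_validate_datatree` that these tests exercise only indirectly: multiscale levels must be keyed `scale0`, `scale1`, ... A sketch (shape and scale factors are illustrative; if I read `parse` correctly, `scale_factors=[2, 2]` yields three levels):

```python
import numpy as np
from spatialdata.models import Labels2DModel

# parse() with scale_factors builds a DataTree with levels scale0 (full
# resolution), scale1, and scale2; _validate_datatree rejects any other
# key scheme with "Wrong key for multiscale data".
multiscale = Labels2DModel.parse(np.zeros((64, 64), dtype=np.uint8), scale_factors=[2, 2])
assert list(multiscale.keys()) == ["scale0", "scale1", "scale2"]
Labels2DModel.validate(multiscale)
```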