diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 6a126b02..e148a986 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -1047,25 +1047,39 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None: raise ValueError(f"`{attr[self.REGION_KEY_KEY]}` not found in `adata.obs`. Please create the column.") if attr[self.INSTANCE_KEY] not in data.obs: raise ValueError(f"`{attr[self.INSTANCE_KEY]}` not found in `adata.obs`. Please create the column.") - if ( - (dtype := data.obs[attr[self.INSTANCE_KEY]].dtype) - not in [ - int, - np.int16, - np.uint16, - np.int32, - np.uint32, - np.int64, - np.uint64, - "O", - ] - and not pd.api.types.is_string_dtype(data.obs[attr[self.INSTANCE_KEY]]) - or (dtype == "O" and (val_dtype := type(data.obs[attr[self.INSTANCE_KEY]].iloc[0])) is not str) + dtype = data.obs[attr[self.INSTANCE_KEY]].dtype + + # Check if dtype is valid for instance_key column + is_valid_dtype = False + + # Check for integer types + if dtype in [int, np.int16, np.uint16, np.int32, np.uint32, np.int64, np.uint64] or isinstance( + dtype, pd.StringDtype ): - dtype = dtype if dtype != "O" else val_dtype + is_valid_dtype = True + # Check for CategoricalDtype with string categories + elif isinstance(dtype, pd.CategoricalDtype): + if pd.api.types.is_string_dtype(dtype.categories.dtype) or isinstance( + dtype.categories.dtype, pd.StringDtype + ): + is_valid_dtype = True + # Check for object dtype with string values + elif dtype == "O": + if len(data.obs[attr[self.INSTANCE_KEY]]) > 0: + val_dtype = type(data.obs[attr[self.INSTANCE_KEY]].iloc[0]) + if val_dtype is str: + is_valid_dtype = True + else: + # Empty column with object dtype is acceptable + is_valid_dtype = True + # Fallback check using pandas is_string_dtype + elif pd.api.types.is_string_dtype(dtype): + is_valid_dtype = True + + if not is_valid_dtype: raise TypeError( - f"Only int, np.int16, np.int32, np.int64, uint equivalents or string allowed as dtype for " - f"instance_key column in obs. Dtype found to be {dtype}" + f"Only int, np.int16, np.int32, np.int64, uint equivalents, pandas StringDtype, or string " + f"allowed as dtype for instance_key column in obs. Dtype found to be {dtype}" ) expected_regions = attr[self.REGION_KEY] if isinstance(attr[self.REGION_KEY], list) else [attr[self.REGION_KEY]] found_regions = data.obs[attr[self.REGION_KEY_KEY]].unique().tolist()