diff --git a/cytetype/main.py b/cytetype/main.py index 55b577a..408f556 100644 --- a/cytetype/main.py +++ b/cytetype/main.py @@ -21,7 +21,10 @@ aggregate_cluster_metadata, extract_visualization_coordinates, ) -from .preprocessing.validation import materialize_canonical_gene_symbols_column +from .preprocessing.validation import ( + materialize_canonical_gene_symbols_column, + _generate_unique_na_label, +) from .core.payload import build_annotation_payload, save_query_to_file from .core.artifacts import ( _is_integer_valued, @@ -87,6 +90,7 @@ def __init__( max_metadata_categories: int = 500, api_url: str = "https://prod.cytetype.nygen.io", auth_token: str | None = None, + label_na: bool = False, ) -> None: """Initialize CyteType with AnnData object and perform data preparation. @@ -125,6 +129,11 @@ def __init__( deployment. Defaults to "https://prod.cytetype.nygen.io". auth_token (str | None, optional): Bearer token for API authentication. If provided, will be included in the Authorization header as "Bearer {auth_token}". Defaults to None. + label_na (bool, optional): If True, cells with NaN values in the + ``group_key`` column are assigned an ``'Unknown'`` cluster label + (or ``'Unknown 2'``, etc. if that label already exists). The original + AnnData object is not modified. If False (default), a ``ValueError`` + is raised instead. Raises: KeyError: If the required keys are missing in `adata.obs` or `adata.uns` @@ -152,8 +161,40 @@ def __init__( self._original_gene_symbols_column = self.gene_symbols_column self.coordinates_key = validate_adata( - adata, group_key, rank_key, self.gene_symbols_column, coordinates_key + adata, group_key, rank_key, self.gene_symbols_column, coordinates_key, + label_na=label_na, ) + + if label_na: + nan_mask = adata.obs[group_key].isna() + if nan_mask.any(): + n_nan = int(nan_mask.sum()) + pct = round(100 * n_nan / adata.n_obs, 1) + existing_labels = set( + str(v) for v in adata.obs[group_key].dropna().unique() + ) + na_label = _generate_unique_na_label(existing_labels) + logger.warning( + f"⚠️ Relabeling {n_nan} cells ({pct}%) with NaN values " + f"in '{group_key}' as '{na_label}'." + ) + adata = anndata.AnnData( + X=adata.X, + obs=adata.obs.copy(), + var=adata.var, + uns=adata.uns, + obsm=adata.obsm, + varm=adata.varm, + layers=adata.layers, + obsp=adata.obsp, + varp=adata.varp, + ) + col = adata.obs[group_key] + if hasattr(col, "cat"): + col = col.cat.add_categories(na_label) + adata.obs[group_key] = col.fillna(na_label) + self.adata = adata + ( self.gene_symbols_column, self._original_gene_symbols_column, diff --git a/cytetype/preprocessing/validation.py b/cytetype/preprocessing/validation.py index bbede86..c785daf 100644 --- a/cytetype/preprocessing/validation.py +++ b/cytetype/preprocessing/validation.py @@ -266,15 +266,43 @@ def _ur_sort_key(ur: float) -> float: return None +def _generate_unique_na_label(existing_labels: set[str]) -> str: + label = "Unknown" + if label not in existing_labels: + return label + n = 2 + while f"{label} {n}" in existing_labels: + n += 1 + return f"{label} {n}" + + def validate_adata( adata: anndata.AnnData, cell_group_key: str, rank_genes_key: str, gene_symbols_col: str | None, coordinates_key: str, + label_na: bool = False, ) -> str | None: if cell_group_key not in adata.obs: raise KeyError(f"Cell group key '{cell_group_key}' not found in `adata.obs`.") + + nan_mask = adata.obs[cell_group_key].isna() + n_nan = int(nan_mask.sum()) + if n_nan > 0: + pct = round(100 * n_nan / adata.n_obs, 1) + if n_nan == adata.n_obs: + raise ValueError( + f"All {n_nan} cells have NaN values in '{cell_group_key}'. " + f"Cannot proceed with annotation." + ) + if not label_na: + raise ValueError( + f"{n_nan} cells ({pct}%) have NaN values in '{cell_group_key}'. " + f"Either fix the data or set label_na=True to assign these cells " + f"an 'Unknown' cluster label." + ) + if adata.X is None: raise ValueError( "`adata.X` is required for ranking genes. Please ensure it contains log1p normalized data."