# Source code for sc_reconstruction.models.reconscimilarity
from __future__ import annotations
import scanpy as sc
import numpy as np
import scipy.sparse as sp
import warnings
import anndata as ad
import lightning as L
from anndata import AnnData
from scimilarity import CellEmbedding
from scimilarity.utils import align_dataset, lognorm_counts
import pandas as pd
def _is_raw_counts(X):
"""Heuristic: raw counts are non-negative integers."""
vals = X.data if sp.issparse(X) else X.ravel()
sample = vals[:min(100_000, len(vals))]
if len(sample) == 0:
return True
return bool(np.all(sample >= 0) and np.allclose(sample, np.round(sample)))
def _normalize_total(X, target_sum=1e4):
rs = np.asarray(X.sum(axis=1)).ravel()
scale = np.zeros_like(rs, dtype=np.float64)
nz = rs > 0
scale[nz] = target_sum / rs[nz]
return X * scale[:, None]
def preprocess_for_scimilarity(X):
    """Map ``X`` into ``log1p(normalize_total(counts, 1e4))`` space.

    The input kind is detected with ``_is_raw_counts``:

    - raw counts are used as they are;
    - anything else is assumed to be log1p data and is first mapped back
      with ``expm1``.

    In both cases the result is then row-normalised to a total of 1e4 and
    passed through ``log1p``.
    """
    counts = X if _is_raw_counts(X) else np.expm1(X)
    normalised = _normalize_total(counts, target_sum=1e4)
    return np.log1p(normalised)
# [docs] (documentation link marker left over from the rendered source page)
class ReconPretrainedscimilarity():
    """Frozen, pre-trained SCimilarity ``CellEmbedding`` wrapper.

    The wrapper loads a ``CellEmbedding`` from a checkpoint path, keeps an
    optional reference to an AnnData object and its gene list, and exposes
    ``get_latent_representation`` to embed expression data.
    """

    def __init__(
        self,
        checkpoint_path: str,
        emb_key: str = 'X_scimilarity',
        **kwargs
    ):
        """Pre-trained SCimilarity foundation model wrapper.

        Wraps the SCimilarity encoder for ReconEval's foundation-model
        reconstruction task. Requires the ``scimilarity`` package and
        runs in the ``scimilarity_env`` conda env.

        Args:
            checkpoint_path: Path to the pre-trained model checkpoint
                (config, model weights, gene mapping).
            emb_key: Key name stored on the wrapper and in
                ``model_params``; it is not otherwise used by this class.
            **kwargs: Accepted and ignored.
        """
        self.model_params = {
            'checkpoint_path': checkpoint_path,
            'emb_key': emb_key,
        }
        self.inferer = self._load_pretrained_model(checkpoint_path)
        self.emb_key = emb_key
        # Defined up front so the attribute exists before ``prepare`` runs.
        self.adata = None
        self.genes = None
        self.overlap_genes = None

    def _load_pretrained_model(self, checkpoint_path):
        """Load the pre-trained SCimilarity CellEmbedding model."""
        return CellEmbedding(model_path=checkpoint_path)

    def prepare(self, adata: AnnData | None = None, **kwargs):
        """Cache the adata reference and its gene list on the wrapper.

        Does nothing when ``adata`` is None. Extra keyword arguments are
        accepted and ignored.
        """
        if adata is not None:
            self.adata = adata
            self.genes = adata.var_names.tolist()

    def get_latent_representation(
        self,
        X: np.ndarray | ad.AnnData
    ) -> np.ndarray:
        """Embed expression data with the loaded SCimilarity model.

        The input is aligned to ``self.inferer.gene_order`` with
        ``align_dataset``, converted to a dense array, passed through
        ``preprocess_for_scimilarity`` and handed to
        ``self.inferer.get_embeddings``.

        Args:
            X: Either an AnnData object (used as is) or an array whose
                columns correspond to ``self.genes``.

        Returns:
            Whatever ``self.inferer.get_embeddings`` returns for the
            preprocessed matrix.

        Raises:
            ValueError: If ``X`` is not an AnnData and no gene list has
                been set via ``prepare`` or ``set_genes``.
        """
        if isinstance(X, ad.AnnData):
            temp_adata = X
        else:
            if self.genes is None:
                # Without gene names the array columns cannot be aligned
                # to the model's gene order.
                raise ValueError(
                    "Gene names are not set; call prepare() or set_genes() "
                    "before passing a raw array."
                )
            temp_adata = AnnData(X=X, var=pd.DataFrame(index=self.genes))
        temp_adata = align_dataset(temp_adata, self.inferer.gene_order)
        if sp.issparse(temp_adata.X):
            X_dense = temp_adata.X.toarray()
        else:
            X_dense = np.asarray(temp_adata.X)
        temp_X = preprocess_for_scimilarity(X_dense)
        return self.inferer.get_embeddings(temp_X)

    def train(self, datamodule: L.LightningDataModule = None, **train_kwargs):
        """No-op: SCimilarity is used frozen. Raises if a datamodule is passed."""
        if datamodule is not None:
            raise NotImplementedError(
                "Fine-tuning is not supported for the pretrained SCimilarity wrapper."
            )

    def set_genes(self, genes: list[str]):
        """Set the genes for reconstruction"""
        self.genes = genes

    def get_overlap_genes(self, genes) -> list[str]:
        """Not implemented for SCimilarity; reserved for parity with other FM wrappers."""
        raise NotImplementedError(
            "get_overlap_genes is not implemented for the SCimilarity wrapper."
        )

    def load(self, path: str, map_location=None) -> None:
        """Reload the pre-trained model from ``path``.

        Args:
            path: Checkpoint path handed to ``_load_pretrained_model``.
            map_location: Accepted but not used by this wrapper.
        """
        self.inferer = self._load_pretrained_model(path)
        # Keep the recorded checkpoint path in step with the loaded model.
        self.model_params['checkpoint_path'] = path