Source code for sc_reconstruction.models.reconscconcept
from __future__ import annotations
from concept import scConcept
import scanpy as sc
import numpy as np
import warnings
from anndata import AnnData
import anndata as ad
import lightning as L
# This wrapper cannot inherit from the baserecon class because of a version
# conflict, so it is a largely simplified, standalone version.
[docs]
class ReconPretrainedscConcept:
    """Frozen, pre-trained scConcept wrapper that produces cell embeddings."""

    def __init__(
        self,
        checkpoint_path: str,
        emb_key: str = 'X_scConcept',
        **kwargs
    ):
        """Pre-trained scConcept foundation model wrapper.

        Wraps the scConcept encoder for ReconEval's foundation-model
        reconstruction task. Requires the ``concept`` package
        (FlashAttention, Ampere+ GPU) and runs in the
        ``scconcept_env`` conda env.

        Args:
            checkpoint_path: Path to the pre-trained model checkpoint
                directory (config, model weights, gene mapping).
            emb_key: Key name stored on the wrapper and in
                ``model_params``; it is not otherwise used inside this
                class (presumably the slot callers store embeddings
                under -- confirm against callers).
            **kwargs: Accepted and ignored.
        """
        self.model_params = {
            'checkpoint_path': checkpoint_path,
            'emb_key': emb_key,
        }
        self.emb_key = emb_key
        # Initialised here so the attribute exists before ``prepare`` runs.
        self.adata = None
        self.genes = None
        self.overlap_genes = None
        self.inferer = self._load_pretrained_model(checkpoint_path)

    def _load_pretrained_model(self, checkpoint_path):
        """Load the pre-trained scConcept model.

        Expects ``config.yaml``, ``model.ckpt`` and
        ``pc_gene_token_mapping.pkl`` directly under ``checkpoint_path``
        and passes ``<checkpoint_path>/cache/`` as the cache directory.

        Args:
            checkpoint_path: Directory holding the checkpoint files.

        Returns:
            The ``scConcept`` object after ``load_config_and_model``.
        """
        concept = scConcept(cache_dir=f'{checkpoint_path}/cache/')
        concept.load_config_and_model(
            config=f"{checkpoint_path}/config.yaml",
            model_path=f"{checkpoint_path}/model.ckpt",
            gene_mapping_path=f"{checkpoint_path}/pc_gene_token_mapping.pkl",
        )
        return concept

    def prepare(self, adata: AnnData | None = None, **kwargs):
        """Cache the adata reference and its gene list on the wrapper.

        Args:
            adata: Object whose ``var_names`` become ``self.genes``.
                When ``None`` nothing is changed.
            **kwargs: Accepted and ignored.
        """
        if adata is None:
            return
        self.adata = adata
        self.genes = adata.var_names.tolist()

    def get_latent_representation(
        self,
        X: np.ndarray | ad.AnnData
    ) -> np.ndarray:
        """Embed cells with the frozen scConcept encoder.

        Note:
            When ``X`` is an ``AnnData`` it is used directly, so its
            ``var['gene_id']`` column is overwritten in place with
            ``self.genes``.

        Args:
            X: Either an ``AnnData`` or a matrix that is wrapped in a new
                ``AnnData``; its variables must line up with ``self.genes``.

        Returns:
            The ``'cls_cell_emb'`` entry of the ``extract_embeddings``
            result.

        Raises:
            ValueError: If no gene list has been set via ``prepare`` or
                ``set_genes``.
        """
        # Fail early with a clear message instead of writing a column of
        # ``None`` gene ids and handing it to the model.
        if self.genes is None:
            raise ValueError(
                "Gene list is not set; call prepare(adata) or "
                "set_genes(genes) before get_latent_representation."
            )

        self.inferer.model.eval()

        temp_adata = X if isinstance(X, ad.AnnData) else AnnData(X=X)
        temp_adata.var['gene_id'] = self.genes

        cell_embs = self.inferer.extract_embeddings(
            adata=temp_adata,
            gene_id_column='gene_id'
        )
        return cell_embs['cls_cell_emb']

    def train(
        self,
        datamodule: L.LightningDataModule | None = None,
        **train_kwargs
    ):
        """No-op: scConcept is used frozen.

        Args:
            datamodule: Must be ``None``.
            **train_kwargs: Accepted and ignored.

        Raises:
            NotImplementedError: If a datamodule is passed.
        """
        if datamodule is not None:
            raise NotImplementedError(
                "Fine-tuning is not supported for the pretrained scConcept wrapper."
            )

    def set_genes(self, genes: list[str]):
        """Set the genes for reconstruction.

        Args:
            genes: Gene identifiers stored as ``self.genes`` and later
                written to ``var['gene_id']`` when embedding.
        """
        self.genes = genes

    def get_overlap_genes(self, genes) -> list[str]:
        """Not implemented for scConcept; reserved for parity with other FM wrappers.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "get_overlap_genes is not implemented for the scConcept wrapper."
        )

    def load(self, path: str, map_location=None) -> None:
        """Reload the pre-trained model from another checkpoint directory.

        Args:
            path: Checkpoint directory, laid out as for ``__init__``.
            map_location: Currently unused.
        """
        self.inferer = self._load_pretrained_model(path)
        # Record the new location only after a successful load so the stored
        # parameters always describe the model actually held.
        self.model_params['checkpoint_path'] = path