"""Source code for ``sc_reconstruction.models.reconscgpt``."""

from __future__ import annotations

import scanpy as sc
import numpy as np
import scipy.sparse as sp
import warnings
from pathlib import Path
import scgpt as scg
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from anndata import AnnData
import anndata as ad


class ReconPretrainedscGPT:
    """Pre-trained scGPT foundation model :cite:`cui:24scgpt`.

    Wraps the scGPT encoder for ReconEval's foundation-model reconstruction
    task. Requires the ``scgpt`` package and runs in the ``scgpt`` conda env.
    """

    # ``gene_col`` values meaning "gene names live in the AnnData var index"
    # rather than in a named ``adata.var`` column.
    _INDEX_GENE_COLS = ("var_names", "index")

    def __init__(
        self,
        checkpoint_path: str,
        gene_col: str = "var_names",
        *,
        batch_size: int = 64,
        **kwargs,
    ):
        """Create the wrapper and load the scGPT gene vocabulary.

        Args:
            checkpoint_path: Path to the pre-trained model checkpoint
                directory; it must contain ``vocab.json``.
            gene_col: Where gene names are read from. ``"var_names"`` or
                ``"index"`` use the AnnData var index; any other value names
                a column of ``adata.var``.
            batch_size: Batch size forwarded to ``scgpt.tasks.embed_data``.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        self.checkpoint_path = checkpoint_path
        self.gene_col = gene_col
        self.batch_size = batch_size
        self.adata: AnnData | None = None
        self.genes: list[str] | None = None
        self.overlap_genes = None
        # Loaded once here; reused by every vocab-coverage check below.
        self._vocab = GeneVocab.from_file(Path(checkpoint_path) / "vocab.json")

    def prepare(self, adata: AnnData | None = None, **kwargs):
        """Cache the adata reference and its gene list on the wrapper.

        Args:
            adata: Reference AnnData. When ``None`` the wrapper is unchanged.
            **kwargs: Ignored.
        """
        # NOTE(review): this caches var_names even when gene_col names a
        # different adata.var column; confirm callers pass genes in the
        # gene_col namespace via set_genes() in that case.
        if adata is not None:
            self.adata = adata
            self.genes = adata.var_names.tolist()

    def get_latent_representation(
        self, X: np.ndarray | ad.AnnData
    ) -> np.ndarray:
        """Embed cells with the pre-trained scGPT encoder.

        Cells with no non-zero value in any vocabulary gene are not sent to
        scGPT; they receive all-zero embedding rows and a warning is issued.

        Args:
            X: Either an AnnData (used as-is) or a cells x genes matrix whose
                columns follow ``self.genes`` (set via :meth:`prepare` or
                :meth:`set_genes`).

        Returns:
            Embedding matrix with one row per cell of ``X``.

        Raises:
            ValueError: If ``X`` is a matrix and no gene list is set or its
                length differs from the number of columns, or if every cell
                has zero expression in the scGPT vocabulary genes.
        """
        if isinstance(X, ad.AnnData):
            temp_adata = X
        else:
            temp_adata = self._matrix_to_adata(X)

        gene_names = self._gene_names(temp_adata)
        # dtype=bool keeps the mask usable for indexing even with zero genes.
        vocab_mask = np.array([g in self._vocab for g in gene_names], dtype=bool)
        has_expr = self._count_vocab_nonzero(temp_adata.X, vocab_mask) > 0

        if not has_expr.any():
            raise ValueError("All cells have zero expression in scGPT vocab genes")
        if has_expr.all():
            return self._embed(temp_adata)

        n_skip = int((~has_expr).sum())
        warnings.warn(
            f"scGPT: {n_skip}/{temp_adata.n_obs} cells have zero expression "
            f"in all vocab genes — they will get zero embeddings."
        )
        result = self._embed(temp_adata[has_expr].copy())
        # Rows of skipped cells stay zero.
        full = np.zeros((temp_adata.n_obs, result.shape[1]), dtype=np.float32)
        full[has_expr] = result
        return full

    def set_genes(self, genes: list[str]):
        """Set the gene names that label the columns of raw-matrix input."""
        self.genes = genes

    def _matrix_to_adata(self, X) -> AnnData:
        """Wrap a raw cells x genes matrix in an AnnData labelled with ``self.genes``."""
        if self.genes is None:
            raise ValueError(
                "scGPT: gene names are unknown; call prepare() or set_genes() "
                "before passing a raw matrix."
            )
        temp_adata = AnnData(X=X)
        if temp_adata.n_vars != len(self.genes):
            raise ValueError(
                f"scGPT: matrix has {temp_adata.n_vars} columns but "
                f"{len(self.genes)} genes are set."
            )
        # The var index must carry the gene names: the vocab-coverage check
        # and index-based gene_col values read genes from var_names.
        temp_adata.var_names = list(self.genes)
        if self.gene_col not in self._INDEX_GENE_COLS:
            temp_adata.var[self.gene_col] = list(self.genes)
        return temp_adata

    def _gene_names(self, adata: AnnData) -> list[str]:
        """Return the gene identifiers of ``adata`` selected by ``gene_col``."""
        if self.gene_col in self._INDEX_GENE_COLS:
            return adata.var_names.tolist()
        return adata.var[self.gene_col].tolist()

    @staticmethod
    def _count_vocab_nonzero(mat, vocab_mask: np.ndarray) -> np.ndarray:
        """Count, per row (cell), the non-zero entries in the masked (vocab) columns."""
        if sp.issparse(mat):
            # Convert before indexing (COO cannot be indexed) and compare with
            # 0 rather than counting stored entries, so explicitly stored
            # zeros are not mistaken for expression.
            sub = sp.csc_matrix(mat)[:, vocab_mask]
            return np.asarray((sub != 0).sum(axis=1)).ravel()
        return np.count_nonzero(np.asarray(mat)[:, vocab_mask], axis=1)

    def _embed(self, adata: AnnData) -> np.ndarray:
        """Run ``scgpt.tasks.embed_data`` on ``adata`` and return its ``.X``."""
        # embed_data reads gene names from adata.var.index when given "index";
        # any other value must be an existing adata.var column.
        gene_col = (
            "index" if self.gene_col in self._INDEX_GENE_COLS else self.gene_col
        )
        return scg.tasks.embed_data(
            adata,
            self.checkpoint_path,
            gene_col=gene_col,
            batch_size=self.batch_size,
            return_new_adata=True,
        ).X