from __future__ import annotations
import lightning as L
import torch
import torch.nn as nn
import numpy as np
from anndata import AnnData
import os
import sys
import logging
from tqdm import tqdm
from typing import Optional
import anndata as ad
import warnings
log = logging.getLogger(__name__)
from sc_reconstruction.models._base_model import BaseReconstructionModel
repo = "/lustre/groups/ml01/code/xiaotong.fu/state/src"
sys.path.insert(0, repo)
from state.emb.nn.model import StateEmbeddingModel
from omegaconf import OmegaConf
from state.emb.train.trainer import get_embeddings
from state.emb.data import create_dataloader
from state.emb.utils import get_embedding_cfg, get_precision_config
from state.emb import Inference
class ReconPretrainedStateModel(BaseReconstructionModel):
    """Frozen, pre-trained State Embedding (SE) model behind the reconstruction-model interface."""

    def __init__(
        self,
        checkpoint_path: str,
        protein_embeds_path: str,
        emb_key: str = 'X_state',
        read_depth: float = 4.0,
        encode_batch_size: int = 64,
        decode_batch_size: int = 64,
        **kwargs
    ):
        """Pre-trained State Embedding (SE) foundation model :cite:`adduri:25`.

        Wraps the SE encoder for ReconEval's foundation-model
        reconstruction task. Requires the ``state`` package
        and runs in the ``cstm_scvi_env`` conda env.

        Args:
            checkpoint_path: Path to the pre-trained model checkpoint.
            protein_embeds_path: Path to the protein embeddings file
                (loaded with ``torch.load``).
            emb_key: Key for cell embeddings in ``adata.obsm``.
            read_depth: Default read depth passed to the decoder.
            encode_batch_size: Batch size used by
                :meth:`get_latent_representation`.
            decode_batch_size: Batch size used by :meth:`predict`.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        # NOTE(review): ``super().__init__()`` is not called here -- confirm
        # that ``BaseReconstructionModel`` needs no initialisation of its own.
        self.model_params = {
            'checkpoint_path': checkpoint_path,
            'protein_embeds_path': protein_embeds_path,
            'emb_key': emb_key,
            'read_depth': read_depth,
            'encode_batch_size': encode_batch_size,
            'decode_batch_size': decode_batch_size
        }
        self.inferer = self._load_pretrained_model(checkpoint_path, protein_embeds_path)
        self.emb_key = emb_key
        self.read_depth = read_depth
        self.encode_batch_size = encode_batch_size
        self.decode_batch_size = decode_batch_size
        self.genes = None
        self.overlap_genes = None

    def _load_pretrained_model(self, checkpoint_path, protein_embeds_path):
        """Load the protein embeddings and the pre-trained State checkpoint.

        Args:
            checkpoint_path: Checkpoint handed to ``ReconInference.load_model``.
            protein_embeds_path: File read with ``torch.load`` onto the CPU.

        Returns:
            A ``ReconInference`` instance with the checkpoint loaded.
        """
        print(f"Loading protein embeddings from {protein_embeds_path}")
        print(f"Loading model checkpoint from {checkpoint_path}")
        # SECURITY: ``weights_only=False`` lets torch unpickle arbitrary
        # objects, so only point this at files from a trusted source.
        protein_embeds = torch.load(protein_embeds_path, weights_only=False, map_location="cpu")
        # Create inference instance
        inferer = ReconInference(cfg=None, protein_embeds=protein_embeds)
        inferer.load_model(checkpoint_path)
        return inferer

    def prepare(self, adata: AnnData | None = None, **kwargs):
        """Cache the adata reference and its gene list on the wrapper.

        Args:
            adata: When given, stored as ``self.adata`` and its ``var_names``
                become ``self.genes``. ``None`` leaves the wrapper untouched.
            **kwargs: Ignored.
        """
        if adata is not None:
            self.adata = adata
            self.genes = adata.var_names.tolist()

    def get_latent_representation(
        self,
        X: np.ndarray | ad.AnnData
    ) -> np.ndarray:
        """Encode expression data into embeddings with the frozen model.

        Args:
            X: Either an AnnData (used as is) or an expression matrix whose
                columns correspond to ``self.genes``.

        Returns:
            The array produced by ``ReconInference.encode_adata``.

        Raises:
            ValueError: If ``X`` is an array and no gene list has been set.
        """
        self.inferer.model.eval()
        if isinstance(X, ad.AnnData):
            temp_adata = X
        else:
            if self.genes is None:
                raise ValueError(
                    "Gene names are unknown: call `prepare` or `set_genes` "
                    "before encoding a raw array."
                )
            temp_adata = AnnData(X=X)
            temp_adata.var_names = self.genes
        return self.inferer.encode_adata(
            adata=temp_adata,
            batch_size=self.encode_batch_size
        )

    def train(self, datamodule: L.LightningDataModule | None = None, **train_kwargs):
        """No-op: SE is used frozen. Raises if a datamodule is passed.

        Raises:
            NotImplementedError: If ``datamodule`` is not ``None``.
        """
        if datamodule is not None:
            raise NotImplementedError(
                "Fine-tuning is not supported for the pretrained SE wrapper."
            )

    def set_genes(self, genes: list[str]):
        """Set the genes for reconstruction"""
        self.genes = genes

    def get_overlap_genes(self, genes) -> list[str]:
        """Get overlapping genes between ``genes`` and the protein embeddings.

        Args:
            genes: Gene names to check.

        Returns:
            The genes that have a protein embedding, de-duplicated and in
            input order. Empty when no protein embeddings are loaded.
        """
        # A one-row dummy AnnData is enough: only ``var`` is inspected.
        adata = AnnData(X=np.zeros((1, len(genes))))
        adata.var_names = genes
        detected = ReconInference._auto_detect_gene_column(self.inferer, adata)
        if detected is None:
            # Returned when the inferer has no protein embeddings.
            return []
        _, overlap_genes = detected
        # Iterating the raw set gives an order that changes between runs
        # (string hash randomisation); follow the input order instead.
        return [g for g in dict.fromkeys(adata.var_names) if g in overlap_genes]

    def predict(
        self,
        X: np.ndarray | ad.AnnData,
        target_genes: Optional[list[str]] = None,
        read_depth: Optional[float] = None,
    ) -> np.ndarray:
        """Predict reconstruction using pre-trained model.

        Args:
            X: Cell embeddings as an array, or an AnnData that holds them
                in ``obsm[self.emb_key]`` (``ReconInference.decode_generator``
                falls back to ``.X`` when that key is missing).
            target_genes: Genes to decode; defaults to ``self.genes``.
            read_depth: Read depth for decoding; defaults to
                ``self.read_depth``.

        Returns:
            The array produced by ``ReconInference.decode_adata``.

        Raises:
            ValueError: If neither ``target_genes`` nor ``self.genes`` is set.
        """
        genes = target_genes if target_genes is not None else self.genes
        if genes is None:
            raise ValueError(
                "No target genes: pass `target_genes` or call "
                "`prepare`/`set_genes` first."
            )
        self.inferer.model.eval()
        if isinstance(X, ad.AnnData):
            # Use the caller's AnnData as is. Writing ``X`` into its own
            # ``obsm`` (as was done before) stored the AnnData object itself
            # under ``emb_key`` and mutated the caller's data.
            temp_adata = X
        else:
            temp_adata = AnnData(X=X)
            temp_adata.obsm[self.emb_key] = X
        return self.inferer.decode_adata(
            adata=temp_adata,
            genes=genes,
            emb_key=self.emb_key,
            read_depth=read_depth if read_depth is not None else self.read_depth,
            batch_size=self.decode_batch_size
        )

    def predict_relu(self, X: np.ndarray) -> np.ndarray:
        """Predict reconstruction with ReLU activation (negatives clipped to 0)."""
        reconstruction = self.predict(X)
        return np.maximum(reconstruction, 0)

    def forward(self, x):
        """Forward pass for compatibility.

        Args:
            x: Embeddings as a ``torch.Tensor`` or anything :meth:`predict`
                accepts.

        Returns:
            The reconstruction as a tensor on ``x``'s device (CPU when ``x``
            is not a tensor).
        """
        if isinstance(x, torch.Tensor):
            target_device = x.device
            # ``detach`` is required: ``.numpy()`` raises on tensors that
            # require grad.
            x_np = x.detach().cpu().numpy()
        else:
            target_device = 'cpu'
            x_np = x
        reconstruction = self.predict(x_np)
        return torch.from_numpy(reconstruction).to(target_device)

    def save(self, path: str):
        """Save model configuration (not the actual pre-trained weights).

        Args:
            path: Destination file; ``.json`` is appended when missing.
        """
        import json
        if not path.endswith('.json'):
            path = path + '.json'
        parent = os.path.dirname(path)
        if parent:
            # ``os.makedirs('')`` raises, so skip it for bare filenames.
            os.makedirs(parent, exist_ok=True)
        # Save the configuration
        config = {
            'model_params': self.model_params,
            'emb_key': self.emb_key,
            'read_depth': self.read_depth,
            'encode_batch_size': self.encode_batch_size,
            'decode_batch_size': self.decode_batch_size,
            # ``set_genes`` accepts any sequence; JSON needs a plain list.
            'genes': list(self.genes) if self.genes is not None else None
        }
        with open(path, 'w') as f:
            json.dump(config, f, indent=2)
        print(f"Model configuration saved to {path}")

    def load(self, path: str, map_location=None) -> None:
        """Load model configuration and re-create the pre-trained inferer.

        Args:
            path: File written by :meth:`save`. When ``path`` does not exist
                and lacks the ``.json`` suffix, the suffix is appended,
                mirroring :meth:`save`.
            map_location: Unused; kept for interface compatibility.
        """
        import json
        if not path.endswith('.json') and not os.path.exists(path):
            path = path + '.json'
        with open(path, 'r') as f:
            config = json.load(f)
        self.model_params = config['model_params']
        self.emb_key = config['emb_key']
        self.read_depth = config['read_depth']
        self.encode_batch_size = config['encode_batch_size']
        self.decode_batch_size = config['decode_batch_size']
        self.genes = config['genes']
        # Reload the pre-trained model
        self.inferer = self._load_pretrained_model(
            self.model_params['checkpoint_path'],
            self.model_params['protein_embeds_path']
        )
        print(f"Model configuration loaded from {path}")
class ReconInference(Inference):
    """``Inference`` variant that encodes and decodes in-memory AnnData objects."""

    def _auto_detect_gene_column(self, adata):
        """Auto-detect the gene column with highest overlap with protein embeddings.

        ``adata.var.index`` is checked first, then every column of
        ``adata.var``; a later candidate wins only with a strictly larger
        overlap.

        Args:
            adata: AnnData whose ``var`` table is inspected.

        Returns:
            A ``(best_column, overlap_genes)`` tuple. ``best_column`` is
            ``None`` when the index is the best match. ``overlap_genes`` is
            the set of names shared with the protein embeddings. Without
            protein embeddings the result is ``(None, set())``.
        """
        if self.protein_embeds is None:
            log.warning("No protein embeddings available for auto-detection, using index")
            # Always return a pair: every caller in this file unpacks two values.
            return None, set()
        protein_genes = set(self.protein_embeds.keys())

        def _candidates():
            # Index first so that it wins ties against the columns.
            if hasattr(adata.var, "index"):
                yield None, set(adata.var.index)
            for col in adata.var.columns:
                yield col, set(adata.var[col].dropna().astype(str))

        best_column = None  # None means use index
        overlap_genes = set()
        best_overlap_pct = 0
        for column, names in _candidates():
            shared = protein_genes.intersection(names)
            if len(shared) > len(overlap_genes):
                best_column = column
                overlap_genes = shared
                # ``names`` is non-empty here because ``shared`` is non-empty.
                best_overlap_pct = len(shared) / len(names)
        best_overlap = len(overlap_genes)

        if best_column is None:
            log.info(
                f"Auto-detected gene column: var.index (overlap: {best_overlap}/{len(protein_genes)} protein embeddings, {best_overlap_pct:.1%} of genes)"
            )
        else:
            log.info(
                f"Auto-detected gene column: var.{best_column} (overlap: {best_overlap}/{len(protein_genes)} protein embeddings, {best_overlap_pct:.1%} of genes)"
            )
        return best_column, overlap_genes

    def get_overlap_genes(self, adata):
        """Get overlapping genes between adata and protein embeddings.

        Returns:
            The set of shared gene names (empty without protein embeddings).
        """
        _, overlap_genes = self._auto_detect_gene_column(adata)
        return overlap_genes

    def __load_dataset_meta(self, adata):
        """Return the ``{"inference": (num_cells, num_genes)}`` shape mapping."""
        num_cells, num_genes = adata.shape
        return {"inference": (num_cells, num_genes)}

    def encode_adata(
        self,
        adata,
        dataset_name: str | None = None,
        batch_size: int | None = None,
    ) -> np.ndarray:
        """Encode every cell of ``adata`` and return the embeddings.

        Args:
            adata: AnnData to encode.
            dataset_name: Name handed to the dataloader; defaults to
                ``"inference"``.
            batch_size: Optional override of the configured batch size.

        Returns:
            A float32 array of cell embeddings. When the encoder also yields
            dataset embeddings they are appended along the last axis.

        Raises:
            ValueError: If the dataloader yields no batches.
        """
        # NOTE(review): the shape dict is always keyed "inference", even when
        # a different ``dataset_name`` is passed -- confirm that
        # ``create_dataloader`` tolerates the mismatch.
        shape_dict = self.__load_dataset_meta(adata)
        if dataset_name is None:
            dataset_name = "inference"
        # Convert to CSR format if needed
        adata = self._convert_to_csr(adata)
        # Auto-detect the best gene column
        gene_column, _ = self._auto_detect_gene_column(adata)
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        precision = get_precision_config(device_type=device_type)

        # Allow overriding batch size for faster inference if more VRAM is available
        dataloader_cfg = self._vci_conf
        if batch_size is not None:
            try:
                # Work on a copy so the shared config keeps its batch size.
                dataloader_cfg = OmegaConf.create(OmegaConf.to_container(self._vci_conf, resolve=True))
                # Ensure nested structure exists
                if not hasattr(dataloader_cfg, "model"):
                    dataloader_cfg["model"] = {}
                dataloader_cfg.model.batch_size = int(batch_size)
                log.info(f"Using override batch size: {batch_size}")
            except Exception:
                log.debug("Copy-based batch size override failed", exc_info=True)
                # Fallback: attempt direct set; if it fails, proceed with original config
                try:
                    dataloader_cfg.model.batch_size = int(batch_size)
                    log.info(f"Using override batch size: {batch_size}")
                except Exception:
                    log.warning("Failed to override batch size; using config default")

        dataloader = create_dataloader(
            dataloader_cfg,
            adata=adata,
            adata_name=dataset_name or "inference",
            shape_dict=shape_dict,
            shuffle=False,
            protein_embeds=self.protein_embeds,
            precision=precision,
            gene_column=gene_column,
        )

        cell_chunks = []
        ds_chunks = []
        for embeddings, ds_embeddings in tqdm(self.encode(dataloader), total=len(dataloader), desc="Encoding"):
            cell_chunks.append(embeddings)
            if ds_embeddings is not None:
                ds_chunks.append(ds_embeddings)
        if not cell_chunks:
            # np.concatenate would fail here with a far less helpful message.
            raise ValueError("No embeddings were produced: the dataloader yielded no batches.")

        all_embeddings = np.concatenate(cell_chunks, axis=0).astype(np.float32)
        if ds_chunks:
            all_ds_embeddings = np.concatenate(ds_chunks, axis=0).astype(np.float32)
            # concatenate along axis -1 with all embeddings
            all_embeddings = np.concatenate([all_embeddings, all_ds_embeddings], axis=-1)
        return all_embeddings

    @torch.no_grad()
    def decode_generator(self, adata, genes, emb_key: str, read_depth=None, batch_size=64):
        """Decode cell embeddings batch by batch.

        Args:
            adata: AnnData holding the embeddings in ``obsm[emb_key]``;
                ``adata.X`` is used when that key is missing.
            genes: Genes passed to ``self.get_gene_embedding``.
            emb_key: ``obsm`` key of the cell embeddings.
            read_depth: Value filled into the per-cell ``task_counts``.
                ``None`` becomes ``4.0`` when the model config enables
                ``rda``; otherwise no ``task_counts`` are passed.
            batch_size: Number of cells decoded per step (must be >= 1).

        Yields:
            One float32 array per batch, always two-dimensional with one row
            per cell.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        try:
            cell_embs = adata.obsm[emb_key]
        except (KeyError, AttributeError):
            cell_embs = adata.X
        if hasattr(cell_embs, "toarray"):
            # ``adata.X`` may be a scipy sparse matrix, which torch cannot ingest.
            cell_embs = cell_embs.toarray()

        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        precision = get_precision_config(device_type=device_type)
        # Keep the full matrix on the CPU; only one batch at a time is moved
        # to the model device.
        cell_embs = torch.Tensor(cell_embs).to('cpu', dtype=precision)

        use_rda = getattr(self.model.cfg.model, "rda", False)
        if use_rda and read_depth is None:
            read_depth = 4.0
        print('Decoding with read depth:', read_depth)
        gene_embeds = self.get_gene_embedding(genes)
        print('Gene embeddings shape:', gene_embeds.shape)

        n_cells = cell_embs.size(0)
        n_batches = (n_cells + batch_size - 1) // batch_size  # ceil, so the bar reaches 100%
        z_dim_ds = self.model.z_dim_ds
        for start in tqdm(range(0, n_cells, batch_size), total=n_batches):
            cell_embeds_batch = cell_embs[start : start + batch_size].to(self.model.device, dtype=precision)
            # ``torch.full`` needs a numeric fill value, so only build the
            # counts when a read depth is available.
            # NOTE(review): assumes ``resize_batch`` accepts ``task_counts=None``
            # and ``ds_emb=None`` -- confirm against the ``state`` package.
            task_counts = None
            if read_depth is not None:
                task_counts = torch.full(
                    (cell_embeds_batch.shape[0],), read_depth, device=self.model.device, dtype=precision
                )
            if z_dim_ds:
                # The trailing ``z_dim_ds`` columns are the dataset embeddings.
                ds_emb = cell_embeds_batch[:, -z_dim_ds:]
                cell_part = cell_embeds_batch[:, :-z_dim_ds]
            else:
                # ``x[:, -0:]`` would select every column, so treat a zero
                # width as "no dataset embedding".
                ds_emb = None
                cell_part = cell_embeds_batch
            merged_embs = StateEmbeddingModel.resize_batch(
                cell_part, gene_embeds, task_counts=task_counts, ds_emb=ds_emb
            )
            logprobs_batch = self.model.binary_decoder(merged_embs)
            logprobs_batch = logprobs_batch.detach().cpu().float().numpy()
            del merged_embs, ds_emb, cell_part, cell_embeds_batch, task_counts
            torch.cuda.empty_cache()
            # A bare ``squeeze()`` also collapses the cell axis of a one-cell
            # batch and the gene axis of a one-gene request. Keep the leading
            # axis and flatten the rest.
            # NOTE(review): assumes the decoder emits one value per
            # (cell, gene) pair with cells on axis 0 -- confirm.
            yield logprobs_batch.reshape(logprobs_batch.shape[0], -1)

    def decode_adata(self, adata, genes, emb_key: str, read_depth=None, batch_size=64):
        """Decode all cells of ``adata`` into a single array.

        Arguments are forwarded unchanged to :meth:`decode_generator`.

        Returns:
            The batches stacked row-wise; an empty ``(0, len(genes))``
            float32 array when there are no cells.
        """
        decoded_batches = list(
            self.decode_generator(
                adata, genes, emb_key=emb_key, read_depth=read_depth, batch_size=batch_size
            )
        )
        if not decoded_batches:
            # ``np.vstack`` raises on an empty sequence.
            return np.empty((0, len(genes)), dtype=np.float32)
        return np.vstack(decoded_batches)