Foundation-model reconstruction
Embed cells with a frozen pretrained foundation model (FM), train a lightweight MLP decoder on the embeddings, and score the reconstruction. The pipeline is the same for SE (STATE), scGPT, scConcept, SCimilarity, or your own FM; only step 1 changes, where you swap the wrapper class.
Pipeline
0. Data: load expression and split it into train / test, keeping the control cells as a reference.
1. Embed: run the FM on every split and store the result at adata.obsm["X_fm"].
2. Train: fit a lightweight MLP decoder on (Z_train, X_train).
3. Reconstruct: decode the held-out embeddings into adata_pred.
4. Score: compute_all_metrics(adata_test, adata_pred).
5. Compare multiple FMs: stack the score dicts and render them with funky_heatmap.
We use SE as the concrete FM (it runs in the same cstm_scvi_env
as the rest of the tutorial). The other FMs each need their own conda
env; swap the wrapper class in step 1 and the rest of the pipeline is
unchanged.
import os, sys, warnings
from pathlib import Path
# Disable tqdm progress bars: in nbconvert, tqdm + state's dataloader
# workers stall the kernel via the ZMQ stdout pipe (see fm.ipynb timeouts).
# Has to be set BEFORE any module imports tqdm.
os.environ["TQDM_DISABLE"] = "1"
# Prefer the public src over any pip-installed sc_reconstruction (which may
# come from a stale private clone in the active env).
sys.path.insert(0, str(Path("..").resolve() / "src"))
import numpy as np
import pandas as pd
import scanpy as sc
import torch
import lightning as L
from anndata import AnnData
from torch.utils.data import DataLoader, Dataset
warnings.filterwarnings("ignore")
# Pick ONE FM. The others are listed for reference — uncomment the one whose
# weights you have and whose env you are running in.
from sc_reconstruction.models.reconse import ReconPretrainedStateModel
# from sc_reconstruction.models.reconscgpt import ReconPretrainedscGPT
# from sc_reconstruction.models.reconscconcept import ReconPretrainedscConcept
# from sc_reconstruction.models.reconscimilarity import ReconPretrainedscimilarity
0. Data: perturbed / control split
We use the tutorial-sized LuCA panel that ships with the repo: 500 cells × 643 genes, already normalised and log1p-transformed, with gene symbols that match the cell-cycle, PROGENy, and cytokine resource lists.
LuCA's obs['origin'] column labels the matched tumor_primary (perturbed)
and normal_adjacent (control) groups. This is the same split the paper's
eval scripts use, so the DEG metric can compare predicted against true
differential genes. We hold out 20% of the tumor cells for evaluation;
the control cells are used in full as the reference.
from sc_reconstruction.metrics import (
load_cell_cycle_genes, load_progeny, load_cytokine_dict_from_csv,
)
FROZEN = Path("../analysis/data/frozen")
CYTO_CSV = FROZEN / "cytokine_act_merged.csv"
adata = sc.read_h5ad(FROZEN / "luca_demo.h5ad")
# Perturbed = tumor_primary, control = normal_adjacent. The split column
# is `origin`; in the paper's experiments/03_latent_shift eval scripts the
# control key is `<study>-<tissue>-normal_adjacent`.
pert_mask = adata.obs["origin"] == "tumor_primary"
ctrl_mask = adata.obs["origin"] == "normal_adjacent"
adata_pert = adata[pert_mask].copy()
adata_ctrl = adata[ctrl_mask].copy()
print(f"perturbed (tumor_primary): {adata_pert.shape}")
print(f"control (normal_adjacent): {adata_ctrl.shape}")
# Hold out 20% of the tumor cells for evaluation.
rng = np.random.default_rng(0)
perm = rng.permutation(adata_pert.n_obs)
split = int(0.8 * adata_pert.n_obs)
adata_train = adata_pert[perm[:split]].copy()
adata_test = adata_pert[perm[split:]].copy()
print(f"train tumor: {adata_train.shape} test tumor: {adata_test.shape}")
# Resources for the biological metrics later on.
s_genes, g2m_genes = load_cell_cycle_genes(FROZEN / "regev_lab_cell_cycle_genes.txt")
progeny = load_progeny(organism="human")
cytokines = load_cytokine_dict_from_csv(CYTO_CSV, celltype="B_cell")
print(f"{len(s_genes)} S genes, {len(g2m_genes)} G2M genes, "
f"{progeny['source'].nunique()} PROGENy pathways, {len(cytokines)} cytokines")
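The 80/20 holdout in the cell above can be factored into a small reusable helper. This is a sketch, not part of sc_reconstruction, and `holdout_split` is a hypothetical name. It mirrors the cell exactly: it uses a seeded `default_rng` and cuts at `int(train_frac * n)`. Because `int()` floors, any rounding remainder goes to the test set.

```python
import numpy as np


def holdout_split(n: int, train_frac: float = 0.8, seed: int = 0):
    """Return disjoint (train_idx, test_idx) index arrays covering range(n).

    Same recipe as the tutorial cell: shuffle with a seeded Generator, then
    cut at int(train_frac * n). int() floors, so the test set absorbs any
    rounding remainder.
    """
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    cut = int(train_frac * n)
    return perm[:cut], perm[cut:]


train_idx, test_idx = holdout_split(210)
print(len(train_idx), len(test_idx))  # 168 42
```

For the same seed, `adata_pert[train_idx].copy()` and `adata_pert[test_idx].copy()` reproduce `adata_train` and `adata_test`.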
1. Embed with a foundation model
Instantiate the FM and embed all three sets: train (tumor), test
(held-out tumor), and the control reference. The decoder (step 2) trains
on Z_train and is scored against adata_test. The control embeddings
are decoded too, so the DEG metric in step 4 can compare predicted
DEGs (held-out tumor vs decoded control) against true DEGs (held-out
tumor vs observed control).
Note: SE-600M needs ~15–20 min for the one-time import and checkpoint load before the first forward pass. Encoding the 500 cells then takes only ~20 s on an A100/H100. Run this notebook on a GPU node.
# Point at your SE checkpoint files. We read from env vars so the rendered
# tutorial doesn't bake in machine-specific paths — set them before running:
# export SE_CKPT=/path/to/se600m_epoch16.ckpt
# export SE_PROTEIN=/path/to/protein_embeddings.pt
SE_CKPT = os.environ.get("SE_CKPT", "your-ckpt-path/se600m_epoch16.ckpt")
SE_PROTEIN = os.environ.get("SE_PROTEIN", "your-ckpt-path/protein_embeddings.pt")
fm = ReconPretrainedStateModel(
checkpoint_path=SE_CKPT,
protein_embeds_path=SE_PROTEIN,
emb_key="X_fm",
encode_batch_size=64,
)
fm.set_genes(adata_train.var_names.tolist())
adata_train.obsm["X_fm"] = fm.get_latent_representation(adata_train)
adata_test .obsm["X_fm"] = fm.get_latent_representation(adata_test)
adata_ctrl .obsm["X_fm"] = fm.get_latent_representation(adata_ctrl)
print(f"train embedding: {adata_train.obsm['X_fm'].shape}")
print(f"test embedding: {adata_test .obsm['X_fm'].shape}")
print(f"ctrl embedding: {adata_ctrl .obsm['X_fm'].shape}")
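Step 2 reads the decoder's input width from the training embedding (`embed_dim`). A test or control embedding with a different width, or with a row count that does not match its AnnData, would therefore only fail at decode time. The hypothetical pre-flight check below is not part of the package. It catches those cases, plus NaN or inf values, before any training starts.

```python
import numpy as np


def check_embeddings(embs: dict, n_obs: dict) -> int:
    """Validate FM embeddings before training the decoder.

    embs  : split name -> (n_cells, d) embedding array
    n_obs : split name -> expected number of cells in that split
    Returns the embedding dimension d shared by all splits.
    """
    dims = set()
    for name, Z in embs.items():
        Z = np.asarray(Z)
        if Z.ndim != 2:
            raise ValueError(f"{name}: expected a 2-D array, got shape {Z.shape}")
        if Z.shape[0] != n_obs[name]:
            raise ValueError(f"{name}: {Z.shape[0]} rows but {n_obs[name]} cells")
        if not np.isfinite(Z).all():
            raise ValueError(f"{name}: embedding contains NaN or inf")
        dims.add(Z.shape[1])
    if len(dims) != 1:
        raise ValueError(f"embedding widths differ across splits: {sorted(dims)}")
    return dims.pop()
```

In this notebook, call it with `{"train": adata_train.obsm["X_fm"], "test": adata_test.obsm["X_fm"], "ctrl": adata_ctrl.obsm["X_fm"]}` and the matching `n_obs` of each AnnData.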
2. Train a lightweight MLP decoder
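The MLP decoder maps an FM embedding back to gene expression. We use the
package's ReconMLPDecoder here. Any module that can be trained on (Z, X)
pairs and exposes a decode(Z) -> X method works in its place. The data
module yields {'z': embedding, 'x': expression} batches. It also holds back
10% of the training cells (val_frac=0.1) as a validation split.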
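The sketch below illustrates the "any module with decode(Z) -> X" contract with a hypothetical closed-form ridge-regression decoder in NumPy. It is a linear stand-in, not the package's ReconMLPDecoder. It is also fit through its own `fit(Z, X)` method rather than the Lightning-style `dec.train(datamodule=...)` call used below, so that call would need adapting if you swap it in.

```python
import numpy as np


class RidgeDecoder:
    """Hypothetical linear decoder: X ≈ [Z, 1] @ W, fit by ridge regression."""

    def __init__(self, alpha: float = 1.0):
        self.alpha = alpha
        self.W = None

    @staticmethod
    def _with_bias(Z):
        # Append a constant column so the decoder learns a per-gene offset.
        Z = np.asarray(Z, dtype=np.float64)
        return np.hstack([Z, np.ones((Z.shape[0], 1))])

    def fit(self, Z, X):
        A = self._with_bias(Z)
        reg = self.alpha * np.eye(A.shape[1])
        reg[-1, -1] = 0.0  # leave the bias column unpenalised
        # Closed-form ridge solution: W = (A^T A + reg)^-1 A^T X
        self.W = np.linalg.solve(A.T @ A + reg, A.T @ np.asarray(X, dtype=np.float64))
        return self

    def decode(self, Z):
        # Same contract step 3 relies on: embeddings in, expression out.
        return (self._with_bias(Z) @ self.W).astype(np.float32)


rng = np.random.default_rng(0)
Z_toy = rng.normal(size=(200, 16))
X_toy = Z_toy @ rng.normal(size=(16, 30)) + 0.5
ridge = RidgeDecoder(alpha=1e-3).fit(Z_toy, X_toy)
print(ridge.decode(Z_toy[:5]).shape)  # (5, 30)
```

A linear decoder like this can also serve as a simple baseline next to the MLP.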
class _ZXDataset(Dataset):
    """Pairs each cell's FM embedding (z) with its expression vector (x)."""
    def __init__(self, Z, X):
        self.Z = torch.as_tensor(Z, dtype=torch.float32)
        self.X = torch.as_tensor(X, dtype=torch.float32)
    def __len__(self):
        return len(self.X)
    def __getitem__(self, i):
        return {"z": self.Z[i], "x": self.X[i]}
class EmbeddingDataModule(L.LightningDataModule):
    """Seeded train / validation split of (Z, X); val_frac of the cells go to validation."""
    def __init__(self, Z, X, batch_size=64, val_frac=0.1, seed=0):
        super().__init__()
        rng = np.random.default_rng(seed)
        perm = rng.permutation(X.shape[0])
        n_val = int(X.shape[0] * val_frac)
        self._train = _ZXDataset(Z[perm[n_val:]], X[perm[n_val:]])
        self._val = _ZXDataset(Z[perm[:n_val]], X[perm[:n_val]])
        self.batch_size = batch_size
    def train_dataloader(self):
        return DataLoader(self._train, batch_size=self.batch_size, shuffle=True)
    def val_dataloader(self):
        return DataLoader(self._val, batch_size=self.batch_size)
from sc_reconstruction.decoders.reconmlp import ReconMLPDecoder
embed_dim = adata_train.obsm["X_fm"].shape[1]
dec = ReconMLPDecoder(
n_input=embed_dim,
n_output=adata_train.n_vars,
n_hidden=512,
n_layers=2,
distribution="normal",
learning_rate=1e-3,
)
dm = EmbeddingDataModule(adata_train.obsm["X_fm"], adata_train.X, batch_size=64)
dec.train(datamodule=dm, max_epochs=100, accelerator="auto",
enable_progress_bar=False, logger=False)
3. Reconstruct the held-out cells
Decode the test (held-out tumor) and control embeddings back to gene
space. The DEG metric in step 4 needs all four AnnDatas:
(adata_test, adata_pred) for the predicted-vs-true tumor side, plus
(adata_ctrl, adata_ctrl_pred) as the reference pair.
X_pred = dec.decode(adata_test.obsm["X_fm"])
X_ctrl_pred = dec.decode(adata_ctrl.obsm["X_fm"])
adata_pred = AnnData(X_pred .astype(np.float32), var=adata_test.var, obs=adata_test.obs.copy())
adata_ctrl_pred = AnnData(X_ctrl_pred.astype(np.float32), var=adata_ctrl.var, obs=adata_ctrl.obs.copy())
print("pred (held-out tumor) :", adata_pred.shape)
print("pred (control) :", adata_ctrl_pred.shape)
4. Score
Hand the (true, predicted) held-out pair to the metrics API, plus
deg_refs=(adata_ctrl, adata_ctrl_pred) so the DEG metric can compute
differential genes on both sides. min_cells=5 lowers the per-gene
expression cutoff so the cell-cycle and pathway metrics work on the
small tutorial slice (the paper uses the default of 20 on its full
test splits).
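A count shows why the default cutoff breaks on the tutorial slice. This minimal sketch assumes `min_cells` follows the usual scanpy-style rule: keep a gene only if it is detected (nonzero) in at least `min_cells` cells. We have not checked the exact rule inside compute_all_metrics. On the 42-cell held-out tumor split, the default of 20 requires detection in nearly half of the cells. `genes_passing_min_cells` is a hypothetical helper and expects a dense matrix.

```python
import numpy as np


def genes_passing_min_cells(X, genes, min_cells: int = 20):
    """Return the genes detected (value > 0) in at least `min_cells` cells.

    Assumed scanpy-style rule; X must be a dense (n_cells, n_genes) array.
    """
    n_detected = (np.asarray(X) > 0).sum(axis=0)
    return [g for g, n in zip(genes, n_detected) if n >= min_cells]


# Toy split of 42 cells: gene "A" is detected in 10 cells, gene "B" in 4.
X_toy = np.zeros((42, 2))
X_toy[:10, 0] = 1.0
X_toy[:4, 1] = 1.0
print(genes_passing_min_cells(X_toy, ["A", "B"], min_cells=20))  # []
print(genes_passing_min_cells(X_toy, ["A", "B"], min_cells=5))   # ['A']
```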
from sc_reconstruction.metrics import compute_all_metrics
scores_se = compute_all_metrics(
adata_test, adata_pred,
s_genes=s_genes, g2m_genes=g2m_genes,
progeny_model=progeny,
cytokine_dict=cytokines,
deg_refs=(adata_ctrl, adata_ctrl_pred),
min_cells=5,
)
pd.Series(scores_se, name="SE").to_frame()
deg_dice_at_100 comes out as 0 here. With only 42 tumor vs 76 control
cells, the Wilcoxon ranking is too noisy for the top 100 of 643 genes
(≈15.6%) to overlap stably. PCA-128 (R² = 0.97) also scores 0, while
deg_logfc_spearman stays at ~0.3–0.5. The gene-wise direction of change is
therefore preserved; only the top-k cutoff collapses. The paper's full
splits, with much larger N per group, stabilize the metric.
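A toy example reproduces the gap between a zero Dice score and a positive Spearman correlation. This minimal sketch makes two assumptions: that deg_dice_at_100 is the Dice overlap of the true and predicted top-100 DEG sets, and that deg_logfc_spearman is the rank correlation of per-gene logFCs. The package's exact definitions may differ. For two top-k sets of equal size, Dice reduces to |A ∩ B| / k. The example is constructed data over 643 genes, not the tutorial's results. Swapping the true top 100 with the next 100 drives Dice to 0 while Spearman stays near 0.95.

```python
import numpy as np


def dice_at_k(score_true, score_pred, k):
    """Dice overlap of the top-k genes ranked by each score vector."""
    top_true = set(np.argsort(-np.asarray(score_true))[:k].tolist())
    top_pred = set(np.argsort(-np.asarray(score_pred))[:k].tolist())
    return 2 * len(top_true & top_pred) / (len(top_true) + len(top_pred))


def spearman(a, b):
    """Spearman correlation as Pearson on ranks (assumes no ties)."""
    ra = np.argsort(np.argsort(a)).astype(float)
    rb = np.argsort(np.argsort(b)).astype(float)
    return float(np.corrcoef(ra, rb)[0, 1])


true_lfc = np.arange(643, dtype=float)
pred_lfc = true_lfc.copy()
pred_lfc[543:] -= 100     # the true top-100 genes fall to predicted ranks 101-200
pred_lfc[443:543] += 100  # genes ranked 101-200 move into the predicted top 100
print(dice_at_k(true_lfc, pred_lfc, 100), round(spearman(true_lfc, pred_lfc), 3))  # 0.0 0.955
```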
5. Compare multiple foundation models
To compare FMs, stack the per-method score dicts into a wide DataFrame
(rows = method, columns = metric) and feed it to funky_heatmap.
We can't run every FM in the same Python process (each lives in its own conda env), so the realistic workflow is:
1. Run steps 0–4 once per FM, in the matching env, and save the compute_all_metrics output as JSON or CSV.
2. Load all of them into one DataFrame and call funky_heatmap once.
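To show the pattern end-to-end inside a single notebook, we substitute two PCA baselines (32 and 128 components) for the other FMs. They reconstruct through PCA's own inverse transform rather than through an FM embedding plus an MLP decoder. Replace those rows with the score dicts you saved from your scGPT / scConcept / SCimilarity runs.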
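The save-then-load workflow above can be sketched with the standard library alone. The one-file-per-method layout (`<method>.json`) and the helper names are assumptions, not a format the package defines. Values are coerced with `float()`, which assumes every metric is numeric.

```python
import json
import tempfile
from pathlib import Path


def save_scores(scores: dict, method: str, out_dir: Path) -> Path:
    """Write one method's compute_all_metrics() dict to <out_dir>/<method>.json."""
    path = Path(out_dir) / f"{method}.json"
    # float() turns numpy scalars into plain floats so json can serialise them.
    path.write_text(json.dumps({k: float(v) for k, v in scores.items()}, indent=2))
    return path


def load_scores(out_dir: Path) -> dict:
    """Read every <method>.json back into {method: {metric: value}}."""
    return {p.stem: json.loads(p.read_text()) for p in sorted(Path(out_dir).glob("*.json"))}


with tempfile.TemporaryDirectory() as d:
    save_scores({"r2": 0.9, "deg_dice_at_100": 0.0}, "scGPT", Path(d))
    save_scores({"r2": 0.8}, "SE", Path(d))
    loaded = load_scores(Path(d))
print(sorted(loaded))  # ['SE', 'scGPT']
```

In the comparison notebook, `pd.DataFrame(load_scores(out_dir)).T` then gives the same rows = method, columns = metric layout as `pd.DataFrame(scored).T` in the next cell, ready for funky_heatmap.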
from sklearn.decomposition import PCA
from sc_reconstruction.metrics import funky_heatmap
scored = {"SE": scores_se}
# Two PCA baselines as stand-ins for the other FMs.
for d in (32, 128):
    pca = PCA(n_components=d, random_state=0).fit(adata_train.X)
    X_pca_test = pca.inverse_transform(pca.transform(adata_test.X)).astype(np.float32)
    X_pca_ctrl = pca.inverse_transform(pca.transform(adata_ctrl.X)).astype(np.float32)
    adata_pca_test = AnnData(X_pca_test, var=adata_test.var, obs=adata_test.obs.copy())
    adata_pca_ctrl = AnnData(X_pca_ctrl, var=adata_ctrl.var, obs=adata_ctrl.obs.copy())
    scored[f"PCA-{d}"] = compute_all_metrics(
        adata_test, adata_pca_test,
        s_genes=s_genes, g2m_genes=g2m_genes,
        progeny_model=progeny,
        cytokine_dict=cytokines,
        deg_refs=(adata_ctrl, adata_pca_ctrl),
        min_cells=5,
    )
raw = pd.DataFrame(scored).T
print(raw.to_string(float_format=lambda x: f"{x:.4f}"))
ax = funky_heatmap(raw, title="FM × decoder reconstruction")
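For reference, here is a NumPy sketch of the PCA round trip each baseline performs. `pca_reconstruct` is a hypothetical helper. It follows sklearn's `inverse_transform(transform(X))` with the default `whiten=False`: centre with the training mean, project onto the top components, map back, and add the mean. It uses an exact SVD, whereas scikit-learn's `svd_solver="auto"` may choose an approximate randomized solver, so the numbers can differ slightly.

```python
import numpy as np


def pca_reconstruct(X_fit, X_apply, n_components):
    """Project X_apply onto X_fit's top principal axes and map back to gene space."""
    X_fit = np.asarray(X_fit, dtype=np.float64)
    mean = X_fit.mean(axis=0)
    # Rows of Vt are the principal axes, ordered by explained variance.
    _, _, Vt = np.linalg.svd(X_fit - mean, full_matrices=False)
    V = Vt[:n_components].T  # (n_genes, n_components)
    Xc = np.asarray(X_apply, dtype=np.float64) - mean
    return (Xc @ V @ V.T + mean).astype(np.float32)
```

For example, `pca_reconstruct(adata_train.X, adata_test.X, 128)` corresponds to the PCA-128 test reconstruction above, assuming a dense `X`.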
Bringing your own FM
Implement two methods and the rest of the pipeline is unchanged:
class MyFM:
    def set_genes(self, genes: list[str]) -> None:
        """Record the gene panel; some FMs need it for tokenization."""
        ...
    def get_latent_representation(self, adata: AnnData) -> np.ndarray:
        """Return an (n_cells, d) embedding for adata."""
        ...
Then plug it into step 1:
fm = MyFM()
fm.set_genes(adata_train.var_names.tolist())
adata_train.obsm["X_fm"] = fm.get_latent_representation(adata_train)
adata_test .obsm["X_fm"] = fm.get_latent_representation(adata_test)
adata_ctrl .obsm["X_fm"] = fm.get_latent_representation(adata_ctrl)
# steps 2–5 run unchanged.
The four reference wrappers in sc_reconstruction.models follow
exactly this contract.
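To check a new wrapper against that contract, the two methods can be written down as a `typing.Protocol`. The sketch below does that and adds a toy random-projection "FM". `FMWrapper` and `RandomProjectionFM` are hypothetical names, not classes in sc_reconstruction. The toy reads only `.X` and `.var_names`, so a `SimpleNamespace` can stand in for an AnnData. Note that `isinstance` on a runtime-checkable Protocol only checks that the methods exist, not their signatures or return shapes.

```python
from types import SimpleNamespace
from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class FMWrapper(Protocol):
    """The two-method FM contract, written as a Protocol."""

    def set_genes(self, genes: list[str]) -> None: ...
    def get_latent_representation(self, adata) -> np.ndarray: ...


class RandomProjectionFM:
    """Toy 'FM': a fixed random linear projection of expression to d dims."""

    def __init__(self, d: int = 16, seed: int = 0):
        self.d, self.seed = d, seed
        self.genes, self.W = None, None

    def set_genes(self, genes: list[str]) -> None:
        # A real FM would build its gene tokenisation here.
        self.genes = list(genes)
        rng = np.random.default_rng(self.seed)
        self.W = rng.normal(size=(len(self.genes), self.d)) / np.sqrt(len(self.genes))

    def get_latent_representation(self, adata) -> np.ndarray:
        if self.W is None:
            raise RuntimeError("call set_genes() first")
        if list(adata.var_names) != self.genes:
            raise ValueError("adata genes do not match the panel given to set_genes()")
        return np.asarray(adata.X, dtype=np.float64) @ self.W


# Usage with a duck-typed stand-in for AnnData.
toy = SimpleNamespace(X=np.ones((5, 3)), var_names=["A", "B", "C"])
toy_fm = RandomProjectionFM(d=4)
toy_fm.set_genes(toy.var_names)
print(isinstance(toy_fm, FMWrapper), toy_fm.get_latent_representation(toy).shape)  # True (5, 4)
```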