Foundation-model reconstruction

Embed cells with a frozen pretrained FM, train a lightweight MLP decoder on top, score the reconstruction. The pipeline is identical for SE (STATE), scGPT, scConcept, SCimilarity, or your own FM — only step 1 swaps the wrapper class.

Pipeline

  0. Data: load expression, split train / test.

  1. Embed: run the FM on the train, test, and control sets, store at adata.obsm["X_fm"].

  2. Train a lightweight MLP decoder on (Z_train, X_train).

  3. Reconstruct: decode the held-out embeddings into adata_pred.

  4. Score: compute_all_metrics(adata_test, adata_pred).

  5. Compare multiple FMs: stack scores and render the funky heatmap.

We use SE as the concrete FM (it runs in the same cstm_scvi_env as the rest of the tutorial). The other FMs each need their own conda env; swap the wrapper class in step 1 and the rest of the pipeline is unchanged.

import os, sys, warnings
from pathlib import Path

# Disable tqdm progress bars: in nbconvert, tqdm + state's dataloader
# workers stall the kernel via the ZMQ stdout pipe (see fm.ipynb timeouts).
# Has to be set BEFORE any module imports tqdm.
os.environ["TQDM_DISABLE"] = "1"

# Prefer the public src over any pip-installed sc_reconstruction (which may
# come from a stale private clone in the active env).
sys.path.insert(0, str(Path("..").resolve() / "src"))

import numpy as np
import pandas as pd
import scanpy as sc
import torch
import lightning as L
from anndata import AnnData
from torch.utils.data import DataLoader, Dataset

warnings.filterwarnings("ignore")

# Pick ONE FM. The others are listed for reference — uncomment the one whose
# weights you have and whose env you are running in.

from sc_reconstruction.models.reconse import ReconPretrainedStateModel
# from sc_reconstruction.models.reconscgpt       import ReconPretrainedscGPT
# from sc_reconstruction.models.reconscconcept   import ReconPretrainedscConcept
# from sc_reconstruction.models.reconscimilarity import ReconPretrainedscimilarity

0. Data: perturbed / control split

We use the tutorial-sized LuCA panel that ships with the repo (500 cells, 643 genes, already normalised + log1p, gene symbols match the cell-cycle / PROGENy / cytokine resource lists).

LuCA’s obs['origin'] separates the cells into matched tumor_primary (perturbed) and normal_adjacent (control) groups — the same split used by the paper’s eval scripts so the DEG metric can compare predicted-vs-true differential genes. We hold out 20% of the tumor cells for evaluation; the control cells are used in full as the reference.

from sc_reconstruction.metrics import (
    load_cell_cycle_genes, load_progeny, load_cytokine_dict_from_csv,
)

FROZEN   = Path("../analysis/data/frozen")
CYTO_CSV = FROZEN / "cytokine_act_merged.csv"

adata = sc.read_h5ad(FROZEN / "luca_demo.h5ad")

# Perturbed = tumor_primary, control = normal_adjacent. The split column
# is `origin`; in the paper's experiments/03_latent_shift eval scripts the
# control key is `<study>-<tissue>-normal_adjacent`.
pert_mask = adata.obs["origin"] == "tumor_primary"
ctrl_mask = adata.obs["origin"] == "normal_adjacent"
adata_pert = adata[pert_mask].copy()
adata_ctrl = adata[ctrl_mask].copy()
print(f"perturbed (tumor_primary):  {adata_pert.shape}")
print(f"control   (normal_adjacent):{adata_ctrl.shape}")

# Hold out 20% of the tumor cells for evaluation.
rng = np.random.default_rng(0)
perm = rng.permutation(adata_pert.n_obs)
split = int(0.8 * adata_pert.n_obs)
adata_train = adata_pert[perm[:split]].copy()
adata_test  = adata_pert[perm[split:]].copy()
print(f"train tumor: {adata_train.shape}  test tumor: {adata_test.shape}")

# Resources for the biological metrics later on.
s_genes, g2m_genes = load_cell_cycle_genes(FROZEN / "regev_lab_cell_cycle_genes.txt")
progeny   = load_progeny(organism="human")
cytokines = load_cytokine_dict_from_csv(CYTO_CSV, celltype="B_cell")
print(f"{len(s_genes)} S genes, {len(g2m_genes)} G2M genes, "
      f"{progeny['source'].nunique()} PROGENy pathways, {len(cytokines)} cytokines")

1. Embed with a foundation model

Instantiate the FM and embed all three sets: train (tumor), test (held-out tumor), and the control reference. The decoder (step 2) trains on Z_train and is scored against adata_test; the control embeddings are decoded too, so the DEG metric in step 4 can compare predicted DEGs (decoded held-out tumor vs decoded control) against true DEGs (held-out tumor vs control).

Note: SE-600M takes ~15–20 min of one-time import + checkpoint load before the first forward pass; the actual encoding of 500 cells is ~20 s on an A100/H100. Run this notebook on a GPU node.

# Point at your SE checkpoint files. We read from env vars so the rendered
# tutorial doesn't bake in machine-specific paths — set them before running:
#   export SE_CKPT=/path/to/se600m_epoch16.ckpt
#   export SE_PROTEIN=/path/to/protein_embeddings.pt
SE_CKPT    = os.environ.get("SE_CKPT",    "your-ckpt-path/se600m_epoch16.ckpt")
SE_PROTEIN = os.environ.get("SE_PROTEIN", "your-ckpt-path/protein_embeddings.pt")

fm = ReconPretrainedStateModel(
    checkpoint_path=SE_CKPT,
    protein_embeds_path=SE_PROTEIN,
    emb_key="X_fm",
    encode_batch_size=64,
)
fm.set_genes(adata_train.var_names.tolist())

adata_train.obsm["X_fm"] = fm.get_latent_representation(adata_train)
adata_test .obsm["X_fm"] = fm.get_latent_representation(adata_test)
adata_ctrl .obsm["X_fm"] = fm.get_latent_representation(adata_ctrl)
print(f"train embedding: {adata_train.obsm['X_fm'].shape}")
print(f"test  embedding: {adata_test .obsm['X_fm'].shape}")
print(f"ctrl  embedding: {adata_ctrl .obsm['X_fm'].shape}")
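
Because the defaults above are placeholder paths and the model load is slow, it can help to check the two files before instantiating the FM. The helper below is a minimal sketch of such a guard; require_file is a hypothetical name introduced here, not part of sc_reconstruction.

```python
import os
from pathlib import Path


def require_file(env_var: str) -> Path:
    """Return the path stored in env_var, or raise if it is unset or missing.

    Hypothetical helper for this tutorial; not part of sc_reconstruction.
    """
    value = os.environ.get(env_var)
    if not value:
        raise KeyError(f"{env_var} is not set; export it before running step 1")
    path = Path(value).expanduser()
    if not path.is_file():
        raise FileNotFoundError(f"{env_var} points at {path}, which is not a file")
    return path
```

With this in place, SE_CKPT = require_file("SE_CKPT") and SE_PROTEIN = require_file("SE_PROTEIN") would fail within milliseconds on a wrong path instead of after the long import.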

2. Train a lightweight MLP decoder

The MLP decoder maps an FM embedding back to gene expression. We use the package’s ReconMLPDecoder here; any module with a decode(Z) -> X method works in its place. The data module just yields {'x': expression, 'z': embedding} batches.

class _ZXDataset(Dataset):
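    def __init__(self, Z, X):
        # AnnData often stores .X as a SciPy sparse matrix; densify it first,
        # because torch.as_tensor expects a dense array-like.
        if hasattr(X, "toarray"):
            X = X.toarray()
        self.Z = torch.as_tensor(Z, dtype=torch.float32)
        self.X = torch.as_tensor(np.asarray(X), dtype=torch.float32)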

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return {"z": self.Z[i], "x": self.X[i]}


class EmbeddingDataModule(L.LightningDataModule):
    def __init__(self, Z, X, batch_size=64, val_frac=0.1, seed=0):
        super().__init__()
        rng = np.random.default_rng(seed)
        perm = rng.permutation(X.shape[0])
        n_val = int(X.shape[0] * val_frac)
        self._train = _ZXDataset(Z[perm[n_val:]], X[perm[n_val:]])
        self._val   = _ZXDataset(Z[perm[:n_val]], X[perm[:n_val]])
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self._train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self._val, batch_size=self.batch_size)

from sc_reconstruction.decoders.reconmlp import ReconMLPDecoder

embed_dim = adata_train.obsm["X_fm"].shape[1]
dec = ReconMLPDecoder(
    n_input=embed_dim,
    n_output=adata_train.n_vars,
    n_hidden=512,
    n_layers=2,
    distribution="normal",
    learning_rate=1e-3,
)

dm = EmbeddingDataModule(adata_train.obsm["X_fm"], adata_train.X, batch_size=64)
dec.train(datamodule=dm, max_epochs=100, accelerator="auto",
          enable_progress_bar=False, logger=False)
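
As noted above, any object with a decode(Z) -> X method can stand in for the MLP. As an illustration, here is a minimal sketch of a closed-form ridge-regression decoder. RidgeDecoder and its fit method are hypothetical names introduced for this example; they are not part of the package, and the sketch assumes dense NumPy inputs.

```python
import numpy as np


class RidgeDecoder:
    """Closed-form linear decoder; a hypothetical stand-in, not a package class."""

    def __init__(self, alpha: float = 1.0):
        self.alpha = alpha
        self.coef_ = None  # shape (d + 1, n_genes); the last row is the intercept

    def fit(self, Z: np.ndarray, X: np.ndarray) -> "RidgeDecoder":
        Z1 = np.hstack([Z, np.ones((Z.shape[0], 1))])  # append a bias column
        penalty = self.alpha * np.eye(Z1.shape[1])
        penalty[-1, -1] = 0.0  # do not shrink the intercept
        # Solve the regularised normal equations (Z1'Z1 + penalty) W = Z1'X.
        self.coef_ = np.linalg.solve(Z1.T @ Z1 + penalty, Z1.T @ X)
        return self

    def decode(self, Z: np.ndarray) -> np.ndarray:
        Z1 = np.hstack([Z, np.ones((Z.shape[0], 1))])
        return (Z1 @ self.coef_).astype(np.float32)
```

A linear baseline like this trains in well under a second on the tutorial data, which makes it a quick sanity check on how much of the expression signal is linearly recoverable from the embedding before spending epochs on the MLP.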

3. Reconstruct the held-out cells

Decode the test (held-out tumor) and control embeddings back to gene space. The DEG metric in step 4 needs all four AnnDatas: (adata_test, adata_pred) for the predicted-vs-true tumor side, plus (adata_ctrl, adata_ctrl_pred) as the reference pair.

X_pred      = dec.decode(adata_test.obsm["X_fm"])
X_ctrl_pred = dec.decode(adata_ctrl.obsm["X_fm"])

adata_pred      = AnnData(X_pred     .astype(np.float32), var=adata_test.var, obs=adata_test.obs.copy())
adata_ctrl_pred = AnnData(X_ctrl_pred.astype(np.float32), var=adata_ctrl.var, obs=adata_ctrl.obs.copy())
print("pred (held-out tumor) :", adata_pred.shape)
print("pred (control)        :", adata_ctrl_pred.shape)

4. Score

Hand the (true, predicted) held-out pair to the metrics API, plus deg_refs=(adata_ctrl, adata_ctrl_pred) so the DEG metric can compute differential genes on both sides. min_cells=5 lowers the per-gene expression cutoff so the cell-cycle and pathway metrics work on the small tutorial slice (the paper uses the default of 20 on its full test splits).

from sc_reconstruction.metrics import compute_all_metrics

scores_se = compute_all_metrics(
    adata_test, adata_pred,
    s_genes=s_genes, g2m_genes=g2m_genes,
    progeny_model=progeny,
    cytokine_dict=cytokines,
    deg_refs=(adata_ctrl, adata_ctrl_pred),
    min_cells=5,
)
pd.Series(scores_se, name="SE").to_frame()
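
To make the effect of min_cells concrete, the sketch below shows one common form of such a filter: keep a gene only if it is expressed (value > 0) in at least min_cells cells. This is an assumption for illustration; the package's own cutoff may be defined differently, and genes_passing_min_cells is a hypothetical helper.

```python
import numpy as np


def genes_passing_min_cells(X: np.ndarray, min_cells: int) -> np.ndarray:
    """Boolean mask over genes expressed (value > 0) in at least min_cells cells.

    Illustrative only: the package's own filter may be defined differently.
    """
    n_expressing = (np.asarray(X) > 0).sum(axis=0)
    return n_expressing >= min_cells


# Toy matrix: 8 cells x 3 genes, expressed in 8, 5, and 2 cells respectively.
X_toy = np.zeros((8, 3))
X_toy[:, 0] = 1.0
X_toy[:5, 1] = 1.0
X_toy[:2, 2] = 1.0

mask_5 = genes_passing_min_cells(X_toy, min_cells=5)    # first two genes pass
mask_20 = genes_passing_min_cells(X_toy, min_cells=20)  # no gene passes
```

On a slice this small, a cutoff of 20 can remove every gene in a signature, which is why the tutorial lowers it.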

Note on deg_dice_at_100 = 0: the metric intersects the top 100 of 643 genes (≈15.6%) from a Wilcoxon ranking, and with only 42 tumor vs 76 control cells that ranking is not stable enough for the two top-100 sets to overlap. The PCA-128 baseline from step 5 (R² = 0.97) also returns 0, while deg_logfc_spearman stays at ~0.3–0.5, so the gene-wise direction of change is preserved and only the top-100 cutoff collapses. The paper’s full splits, with far more cells per group, stabilize it.
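
The gap between a top-k overlap and a rank correlation can be reproduced on toy numbers. The sketch below uses a generic top-k Dice and a tie-free Spearman written in NumPy; both are illustrative definitions, not the package's implementations of deg_dice_at_100 or deg_logfc_spearman.

```python
import numpy as np


def dice_at_k(true_scores, pred_scores, k):
    """Dice overlap of the top-k genes under two score vectors."""
    top_true = set(np.argsort(-np.asarray(true_scores))[:k])
    top_pred = set(np.argsort(-np.asarray(pred_scores))[:k])
    return 2 * len(top_true & top_pred) / (len(top_true) + len(top_pred))


def spearman(a, b):
    """Spearman correlation as the Pearson correlation of ranks (assumes no ties)."""
    rank_a = np.argsort(np.argsort(a))
    rank_b = np.argsort(np.argsort(b))
    return float(np.corrcoef(rank_a, rank_b)[0, 1])


# Ten genes; the prediction swaps the top two genes with the next two.
true_scores = np.arange(10, dtype=float)
pred_scores = np.array([0, 1, 2, 3, 4, 5, 8, 9, 6, 7], dtype=float)

dice = dice_at_k(true_scores, pred_scores, k=2)  # 0.0: the top-2 sets are disjoint
rho = spearman(true_scores, pred_scores)         # about 0.90: global order mostly kept
```

A small reshuffle near the cutoff is enough to zero the overlap while leaving the overall ranking largely intact, which is the pattern described above.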

5. Compare multiple foundation models

To compare FMs, stack the per-method score dicts into a wide DataFrame (rows = method, columns = metric) and feed it to funky_heatmap.

We can’t run every FM in the same Python process — each lives in its own conda env — so the realistic workflow is:

  1. Run steps 0–4 once per FM, in the matching env, saving the compute_all_metrics output as a JSON / CSV.

  2. Load all of them into one DataFrame and call funky_heatmap once.
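
The two steps above can be sketched with the standard library alone. The sketch assumes compute_all_metrics returns a flat dict of numeric scalars (the tutorial wraps it in a pd.Series, which suggests so); save_scores, load_scores, and the output directory are hypothetical names for illustration.

```python
import json
from pathlib import Path


def save_scores(scores: dict, method: str, out_dir: Path) -> Path:
    """Write one method's metric dict to <out_dir>/<method>.json."""
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / f"{method}.json"
    # float() converts NumPy scalars, which json cannot serialise directly.
    path.write_text(json.dumps({k: float(v) for k, v in scores.items()}, indent=2))
    return path


def load_scores(out_dir: Path) -> dict:
    """Read every *.json in out_dir into {method: {metric: value}}."""
    return {p.stem: json.loads(p.read_text()) for p in sorted(out_dir.glob("*.json"))}
```

In each FM's env you would call something like save_scores(scores_se, "SE", Path("scores")); in the comparison notebook, pd.DataFrame(load_scores(Path("scores"))).T gives the same rows = method, columns = metric layout that funky_heatmap receives below.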

To show the pattern end-to-end inside a single notebook we substitute two PCA baselines for the other FMs. Replace those rows with the score dicts you saved from your scGPT / scConcept / SCimilarity runs.

from sklearn.decomposition import PCA
from sc_reconstruction.metrics import funky_heatmap

scored = {"SE": scores_se}

# Two PCA baselines as stand-ins for the other FMs.
for d in (32, 128):
    pca = PCA(n_components=d, random_state=0).fit(adata_train.X)
    X_pca_test = pca.inverse_transform(pca.transform(adata_test.X)).astype(np.float32)
    X_pca_ctrl = pca.inverse_transform(pca.transform(adata_ctrl.X)).astype(np.float32)
    adata_pca_test = AnnData(X_pca_test, var=adata_test.var, obs=adata_test.obs.copy())
    adata_pca_ctrl = AnnData(X_pca_ctrl, var=adata_ctrl.var, obs=adata_ctrl.obs.copy())
    scored[f"PCA-{d}"] = compute_all_metrics(
        adata_test, adata_pca_test,
        s_genes=s_genes, g2m_genes=g2m_genes,
        progeny_model=progeny,
        cytokine_dict=cytokines,
        deg_refs=(adata_ctrl, adata_pca_ctrl),
        min_cells=5,
    )

raw = pd.DataFrame(scored).T
print(raw.to_string(float_format=lambda x: f"{x:.4f}"))

ax = funky_heatmap(raw, title="FM × decoder reconstruction")

Bringing your own FM

Implement two methods and the rest of the pipeline is unchanged:

class MyFM:
    def set_genes(self, genes: list[str]) -> None:
        """Record the gene panel — needed by some FMs for tokenization."""
        ...

    def get_latent_representation(self, adata: AnnData) -> np.ndarray:
        """Return an (n_cells, d) embedding for adata."""
        ...

Then plug it into step 1:

fm = MyFM()
fm.set_genes(adata_train.var_names.tolist())
adata_train.obsm["X_fm"] = fm.get_latent_representation(adata_train)
adata_test .obsm["X_fm"] = fm.get_latent_representation(adata_test)
adata_ctrl .obsm["X_fm"] = fm.get_latent_representation(adata_ctrl)  # DEG reference for steps 3–4
# steps 2–5 run unchanged.

The four reference wrappers in sc_reconstruction.models follow exactly this contract.
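
For a concrete, if trivial, instance of the contract, the sketch below implements both methods with a fixed random linear projection. RandomProjectionFM is a hypothetical toy class, not a pretrained model and not part of the package; it only needs an object with an .X matrix, so it can be tried without any checkpoint.

```python
import numpy as np


class RandomProjectionFM:
    """Toy 'FM': a fixed random linear projection of the gene panel.

    Hypothetical example of the contract only; it is not a pretrained model.
    """

    def __init__(self, n_latent: int = 16, seed: int = 0):
        self.n_latent = n_latent
        self.seed = seed
        self.genes = None
        self.weights = None

    def set_genes(self, genes: list[str]) -> None:
        """Record the gene panel and draw one projection column set for it."""
        self.genes = list(genes)
        rng = np.random.default_rng(self.seed)
        self.weights = rng.standard_normal((len(self.genes), self.n_latent))
        self.weights /= np.sqrt(len(self.genes))

    def get_latent_representation(self, adata) -> np.ndarray:
        """Return an (n_cells, n_latent) float32 embedding for adata.X."""
        if self.weights is None:
            raise RuntimeError("call set_genes() before embedding")
        # Accept both dense arrays and SciPy sparse matrices.
        X = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)
        if X.shape[1] != len(self.genes):
            raise ValueError("adata does not match the gene panel passed to set_genes()")
        return (X @ self.weights).astype(np.float32)
```

Running the pipeline once with a stand-in like this is a cheap way to confirm that shapes, keys, and the metrics call line up before waiting on a real checkpoint load.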