Latent-shift reconstruction (perturbation prediction)
Latent-shift reconstruction predicts the effect of a perturbation in embedding space and then decodes the prediction back to gene expression. It has two stages:

1. Predict the perturbed embedding: a conditional model is trained on `(control_embedding, perturbed_embedding, perturbation_descriptor)` triples. The same control paired with different perturbations should produce different predictions. The paper uses CellFlow (JAX flow matching) and STATE (a PyTorch transformer over cell sets) for this step.
2. Decode: a frozen, pretrained MLP decoder maps the predicted embedding back to gene expression. This is the same decoder as in `fm.ipynb`.
The model contract
A perturbation predictor needs two methods, both conditioned on a perturbation descriptor (e.g. an ESM2 cytokine embedding, a drug SMILES embedding, or an sgRNA one-hot encoding):

- `train(ctrl_emb, pert_emb, pert_cov, ...)`: fit on paired control/perturbed embeddings, with one perturbation descriptor per pair.
- `predict(ctrl_emb, pert_cov) -> pred_emb`: given held-out control embeddings and a perturbation descriptor, return the predicted perturbed embeddings.

The descriptor is what makes the model conditional. Without it, the model can only learn a single global shift applied to every cell regardless of perturbation, which is exactly what an unconditional predictor does.
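The following minimal NumPy sketch shows why. It uses two hypothetical perturbations whose latent shifts point in opposite directions. Under MSE, the best single (unconditional) shift is the average of the per-perturbation shifts, which here nearly cancels to zero. A shift looked up per descriptor recovers both. The data and names are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
# Two hypothetical perturbations with opposite latent shifts in 2-D.
shifts = np.array([[1.0, 0.0], [-1.0, 0.0]])
labels = rng.integers(0, 2, size=1000)          # which perturbation each cell received
ctrl = rng.normal(size=(1000, 2))
pert = ctrl + shifts[labels]                    # noiseless perturbed embeddings
delta = pert - ctrl

# Unconditional: one shift for every cell; the MSE-optimal constant is the mean delta.
global_shift = delta.mean(axis=0)
# Conditional: one shift per descriptor (here the descriptor is simply the label).
per_pert_shift = np.stack([delta[labels == k].mean(axis=0) for k in range(2)])

mse_uncond = ((ctrl + global_shift - pert) ** 2).mean()
mse_cond = ((ctrl + per_pert_shift[labels] - pert) ** 2).mean()
print(global_shift, mse_uncond, mse_cond)
```

The unconditional error stays near 0.5 (a unit error on one of the two coordinates). The conditional error is numerically zero.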
```python
import sys
from pathlib import Path

# Prefer the public src over any pip-installed sc_reconstruction (which may
# come from a stale private clone in the active env).
sys.path.insert(0, str(Path("..").resolve() / "src"))

from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class PerturbationPredictor(Protocol):
    """Minimum interface for a latent-shift model in ReconEval."""

    def train(self, ctrl_emb: np.ndarray, pert_emb: np.ndarray,
              pert_cov: np.ndarray, **kwargs) -> None: ...

    def predict(self, ctrl_emb: np.ndarray, pert_cov: np.ndarray) -> np.ndarray: ...
```
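`@runtime_checkable` makes `isinstance` check only that the protocol's method names exist. It does not check signatures, and it does not check whether `pert_cov` is actually used. The minimal sketch below uses three hypothetical classes to show this. As a result, an `isinstance(..., PerturbationPredictor)` assertion is a smoke test of the interface, not proof that a model is conditional.

```python
from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class PerturbationPredictor(Protocol):
    def train(self, ctrl_emb: np.ndarray, pert_emb: np.ndarray,
              pert_cov: np.ndarray, **kwargs) -> None: ...
    def predict(self, ctrl_emb: np.ndarray, pert_cov: np.ndarray) -> np.ndarray: ...


class Unconditional:  # hypothetical: accepts pert_cov but ignores it
    def train(self, ctrl_emb, pert_emb, pert_cov, **kwargs):
        self.shift = (pert_emb - ctrl_emb).mean(axis=0)

    def predict(self, ctrl_emb, pert_cov):
        return ctrl_emb + self.shift


class TrainOnly:  # hypothetical: predict() is missing
    def train(self, ctrl_emb, pert_emb, pert_cov, **kwargs): ...


class WrongSignature:  # hypothetical: right names, wrong arguments
    def train(self): ...
    def predict(self): ...


ok_uncond = isinstance(Unconditional(), PerturbationPredictor)      # True
ok_train_only = isinstance(TrainOnly(), PerturbationPredictor)      # False
ok_wrong_sig = isinstance(WrongSignature(), PerturbationPredictor)  # True
```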
Reference implementation: conditional MLP
The reference model concatenates `(ctrl_emb, pert_cov)`, passes the result through an MLP, and outputs `pred_emb`. This is the smallest model that satisfies the conditional protocol. The paper's alternatives are listed at the end of the cell below but commented out: each lives in its own environment and needs about 500K training iterations to converge.
```python
import warnings

import torch
import torch.nn as nn

warnings.filterwarnings("ignore")


class MLPPerturbationPredictor:
    """Conditional `(ctrl_emb, pert_cov) -> pert_emb` MLP. Tiny stand-in for CellFlow/STATE."""

    def __init__(self, embed_dim: int, cov_dim: int, n_hidden: int = 256, lr: float = 1e-3):
        in_dim = embed_dim + cov_dim  # control embedding and descriptor are concatenated
        self.net = nn.Sequential(
            nn.Linear(in_dim, n_hidden), nn.GELU(),
            nn.Linear(n_hidden, n_hidden), nn.GELU(),
            nn.Linear(n_hidden, embed_dim),
        )
        self.lr = lr

    def _cat(self, ctrl_emb, pert_cov):
        ctrl = torch.as_tensor(ctrl_emb, dtype=torch.float32)
        cov = torch.as_tensor(pert_cov, dtype=torch.float32)
        return torch.cat([ctrl, cov], dim=1)

    def train(self, ctrl_emb, pert_emb, pert_cov, *, epochs=300, batch_size=128):
        x = self._cat(ctrl_emb, pert_cov)
        y = torch.as_tensor(pert_emb, dtype=torch.float32)
        opt = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        loss_fn = nn.MSELoss()
        self.net.train()
        for _ in range(epochs):
            perm = torch.randperm(len(x))  # reshuffle every epoch
            for i in range(0, len(x), batch_size):
                idx = perm[i:i + batch_size]
                opt.zero_grad()
                loss = loss_fn(self.net(x[idx]), y[idx])
                loss.backward()
                opt.step()

    def predict(self, ctrl_emb, pert_cov):
        self.net.eval()
        with torch.no_grad():
            return self.net(self._cat(ctrl_emb, pert_cov)).cpu().numpy()


# --- the paper alternatives (each lives in its own conda env / repo): ---
#
# from cellflow.model import CellFlow
# # JAX flow matching, OT solver. Takes (adata, perturbation_covariates dict) — see
# # experiments/03_latent_shift/codes/train_cf.py for the full pipeline (ESM2 cytokine
# # embeddings, donor + cell_type covariates, ~500 K iters on GPU).
#
# from state.tx.models.state_transition import StateTransitionPerturbationModel
# # PyTorch transformer over cell sets (cell_set_len=512). Takes a frozen MLP
# # decoder injected via FrozenMLPDecoderAdapter — see
# # experiments/03_latent_shift/codes/train_st.py.
```
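In `MLPPerturbationPredictor`, all of the conditioning comes from the concatenation in `_cat`. The self-contained toy check below uses an untrained network and made-up sizes (8-d embeddings, 4-d descriptors) instead of the tutorial's 128 and 16. It shows that the same control embedding paired with two different descriptors already maps to different outputs. A network fed only `ctrl_emb` would return the same output in both cases.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, cov_dim = 8, 4  # toy sizes for illustration only
net = nn.Sequential(nn.Linear(embed_dim + cov_dim, 32), nn.GELU(), nn.Linear(32, embed_dim))

ctrl = torch.randn(1, embed_dim)  # one control embedding
cov_a = torch.randn(1, cov_dim)   # descriptor for hypothetical perturbation A
cov_b = torch.randn(1, cov_dim)   # descriptor for hypothetical perturbation B

with torch.no_grad():
    out_a = net(torch.cat([ctrl, cov_a], dim=1))  # same control, descriptor A
    out_b = net(torch.cat([ctrl, cov_b], dim=1))  # same control, descriptor B

print(out_a.shape, (out_a - out_b).abs().max().item())
```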
Synthetic data with multiple perturbations
We synthesise five perturbations, each producing a different latent direction, with a small perturbation descriptor per condition (a 16-d vector standing in for ESM2 / drug embeddings). For the tutorial we split cells within each perturbation into train (80%) and test (20%): the model sees every perturbation during training, but is evaluated on held-out cells.
The paper's harder setup holds out whole perturbations; see `experiments/03_latent_shift/codes/eval_cf.py` for that pipeline. With only K = 5 random descriptors, the conditional MLP cannot generalise to an unseen perturbation. We therefore use the cell-level split here to show that the conditioning mechanism works.
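The two split strategies differ only in which cells are masked out. The minimal NumPy sketch below works on a per-cell perturbation label array. `labels` and the held-out id are illustrative, and this is not the logic of `eval_cf.py`. The cell-level split keeps every perturbation in training, while the perturbation-level split removes one perturbation entirely.

```python
import numpy as np

rng = np.random.default_rng(0)
K, n_per_pert, val_frac = 5, 500, 0.2
labels = np.repeat(np.arange(K), n_per_pert)  # perturbation id of each cell

# Cell-level split (this tutorial): hold out 20% of the cells of every perturbation.
test_mask_cells = np.zeros(len(labels), dtype=bool)
for k in range(K):
    idx = np.flatnonzero(labels == k)
    test_mask_cells[rng.choice(idx, size=int(len(idx) * val_frac), replace=False)] = True

# Perturbation-level split (the paper's setup): hold out every cell of one perturbation.
held_out_pert = 4  # hypothetical choice
test_mask_perts = labels == held_out_pert

for name, mask in [("cell-level", test_mask_cells), ("perturbation-level", test_mask_perts)]:
    seen = set(labels[~mask].tolist())
    tested = set(labels[mask].tolist())
    print(f"{name}: test perturbations unseen in training = {sorted(tested - seen)}")
```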
```python
import pandas as pd
from anndata import AnnData

rng = np.random.default_rng(0)
embed_dim, cov_dim, n_genes = 128, 16, 600
K = 5             # number of perturbations
n_per_pert = 500  # cells per perturbation
val_frac = 0.2    # 20% of each perturbation's cells held out

# A descriptor per perturbation (random for tutorial; the paper uses ESM2 embeddings).
cov_book = rng.normal(0, 1, size=(K, cov_dim)).astype(np.float32)

# Each perturbation's true effect: a linear function of its covariate so the model has
# something learnable. shift_k = W @ cov_book[k]
W = rng.normal(0, 0.3, size=(embed_dim, cov_dim)).astype(np.float32)


def _pair(k, n):
    ctrl = rng.normal(0, 1, size=(n, embed_dim)).astype(np.float32)
    shift = cov_book[k] @ W.T
    pert = ctrl + shift + 0.1 * rng.normal(size=ctrl.shape).astype(np.float32)
    cov = np.broadcast_to(cov_book[k], (n, cov_dim)).copy()
    return ctrl, pert, cov


# 80/20 split inside each perturbation: every pert is seen at training, evaluation uses
# held-out cells of the same perts.
train_ctrl_l, train_pert_l, train_cov_l = [], [], []
test_ctrl_l, test_pert_l, test_cov_l = [], [], []
n_val = int(n_per_pert * val_frac)
for k in range(K):
    ctrl, pert, cov = _pair(k, n_per_pert)
    perm = rng.permutation(n_per_pert)
    tr, te = perm[n_val:], perm[:n_val]
    train_ctrl_l.append(ctrl[tr])
    train_pert_l.append(pert[tr])
    train_cov_l.append(cov[tr])
    test_ctrl_l.append(ctrl[te])
    test_pert_l.append(pert[te])
    test_cov_l.append(cov[te])

train_ctrl = np.vstack(train_ctrl_l)
train_pert = np.vstack(train_pert_l)
train_cov = np.vstack(train_cov_l)
test_ctrl = np.vstack(test_ctrl_l)
test_pert = np.vstack(test_pert_l)
test_cov = np.vstack(test_cov_l)

print(f"train: {train_ctrl.shape} ctrl, {train_pert.shape} pert, {train_cov.shape} cov ({K} perturbations)")
print(f"test : {test_ctrl.shape} ctrl, {test_pert.shape} pert, {test_cov.shape} cov (held-out cells)")
```

```text
train: (2000, 128) ctrl, (2000, 128) pert, (2000, 16) cov (5 perturbations)
test : (500, 128) ctrl, (500, 128) pert, (500, 16) cov (held-out cells)
```
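Here is one way to see why holding out whole perturbations is hard in this toy setup. The true shift is linear in the descriptor (`shift_k = W @ cov_book[k]`), so the training pairs only determine how `W` acts on the subspace spanned by the K = 5 descriptors. The other 11 of the 16 descriptor directions are left unconstrained. The sketch below reuses the tutorial's seed, so `cov_book` matches the tutorial's first draw up to float32 rounding. It then measures how much of a fresh random descriptor (a hypothetical unseen perturbation) falls outside that span.

```python
import numpy as np

rng = np.random.default_rng(0)
K, cov_dim = 5, 16
# Same seed and first draw as the tutorial's cov_book, kept in float64 here.
cov_book = rng.normal(0, 1, size=(K, cov_dim))
new_cov = rng.normal(0, 1, size=cov_dim)  # hypothetical unseen perturbation

# Orthonormal basis Q for the subspace spanned by the K training descriptors.
Q, _ = np.linalg.qr(cov_book.T)  # shape (cov_dim, K)


def frac_outside_span(v):
    """Fraction of v's norm orthogonal to span(cov_book)."""
    return np.linalg.norm(v - Q @ (Q.T @ v)) / np.linalg.norm(v)


print(f"seen descriptor:   {frac_outside_span(cov_book[0]):.3f}")  # ~0
print(f"unseen descriptor: {frac_outside_span(new_cov):.3f}")
```

A training descriptor lies entirely inside the span. A random new one has most of its norm outside it, in directions about which the training data say nothing.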
Train and predict on held-out cells
These are the same `.train(...)` and `.predict(...)` calls that any conditional predictor must support.
```python
pred = MLPPerturbationPredictor(embed_dim=embed_dim, cov_dim=cov_dim, n_hidden=256, lr=1e-3)
assert isinstance(pred, PerturbationPredictor)  # structural check: both methods exist
pred.train(train_ctrl, train_pert, train_cov, epochs=300, batch_size=128)
pred_pert_emb = pred.predict(test_ctrl, test_cov)
print(f"held-out predicted embeddings: {pred_pert_emb.shape}")
```

```text
held-out predicted embeddings: (500, 128)
```
Decode to gene space
Each embedding is mapped back to expression by a frozen pretrained decoder. The paper plugs in
FrozenMLPDecoderAdapter from sc_reconstruction.adapters.state_decoder_adapter (the MLP
decoder trained in fm.ipynb). For tutorial brevity we train a tiny MLP decoder on the same
synthetic data so the gene-space numbers are meaningful.
```python
# A small MLP decoder: emb -> expression. In the paper this is FrozenMLPDecoderAdapter loaded
# from a pretrained MLP-decoder checkpoint produced by fm.ipynb (step 2).
class TinyDecoder(nn.Module):
    def __init__(self, emb_dim, n_genes, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(emb_dim, hidden), nn.GELU(),
            nn.Linear(hidden, n_genes), nn.ReLU(),  # ReLU = non-negative expression
        )

    def forward(self, x):
        return self.net(x)


# Build (emb, X) training pairs from the existing training data; control and perturbed cells both work.
all_emb = np.vstack([train_ctrl, train_pert])
decoder_W_true = rng.normal(0, 1 / np.sqrt(embed_dim), size=(embed_dim, n_genes)).astype(np.float32)
all_X = np.maximum(0, all_emb @ decoder_W_true + rng.normal(0, 0.1, size=(len(all_emb), n_genes))).astype(np.float32)

dec = TinyDecoder(embed_dim, n_genes)
opt = torch.optim.Adam(dec.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
x = torch.as_tensor(all_emb, dtype=torch.float32)
y = torch.as_tensor(all_X, dtype=torch.float32)
for _ in range(100):
    perm = torch.randperm(len(x))
    for i in range(0, len(x), 256):
        idx = perm[i:i + 256]
        opt.zero_grad()
        loss = loss_fn(dec(x[idx]), y[idx])
        loss.backward()
        opt.step()
dec.eval()  # used only for inference from here on


def decode(emb):
    with torch.no_grad():
        return dec(torch.as_tensor(emb, dtype=torch.float32)).cpu().numpy()


# Decode every relevant cell pool, wrap in AnnData.
var = pd.DataFrame(index=[f"G{i:03d}" for i in range(n_genes)])
ctrl = AnnData(decode(test_ctrl), var=var)
pert_true = AnnData(decode(test_pert), var=var)
pert_pred = AnnData(decode(pred_pert_emb), var=var)

# Keep embeddings on .obsm so KNN purity runs in latent space.
ctrl.obsm["X_fm"] = test_ctrl
pert_true.obsm["X_fm"] = test_pert
pert_pred.obsm["X_fm"] = pred_pert_emb
```
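In the paper, the decoder is frozen: whatever model sits upstream cannot update its weights. The hedged illustration below uses plain PyTorch, with a generic `nn.Sequential` standing in for the decoder. It is not `FrozenMLPDecoderAdapter`'s code, whose internals are not shown here. Freezing with `requires_grad_(False)` keeps the decoder's parameters out of backpropagation, while gradients still flow back to the input embedding. That is what lets a fixed decoder sit behind a trainable predictor.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
decoder = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 32), nn.ReLU())
decoder.requires_grad_(False)  # freeze every decoder parameter
decoder.eval()                 # no dropout/batch-norm here, but the usual habit for inference

emb = torch.randn(4, 8, requires_grad=True)  # stands in for an upstream model's output
loss = decoder(emb).pow(2).mean()            # any gene-space loss
loss.backward()

frozen = all(p.grad is None for p in decoder.parameters())          # no decoder grads
input_grad_ok = emb.grad is not None and emb.grad.shape == emb.shape  # grads reach the input
print(frozen, input_grad_ok)
```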
Score the held-out cells
KNN purity is the headline metric. It measures the fraction of a predicted cell's k nearest neighbours, within the pool of true-perturbed and control cells, that are true-perturbed. A score of 1.0 is perfect. A score of 0.5 is chance level when the two pools are the same size, as they are here. A score of 0.0 means every neighbour is a control cell. We measure it in latent space (`use_rep="X_fm"`) because KNN in raw expression is dominated by differences in mean and variance.

The test cells were held out during training, so this measures generalisation to new cells of perturbations the model has already seen. The harder held-out-perturbation test (the paper's setup) is described in the synthetic-data section above.
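The minimal NumPy sketch below follows the purity definition above. It is written from this description rather than from `sc_reconstruction`'s source. `knn_purity` is a hypothetical name, and details such as tie-breaking or whether the predicted cells also join the pool may differ from `metric_knn_purity`.

```python
import numpy as np


def knn_purity(pred, pert_true, ctrl, k=10):
    """Hypothetical re-implementation: mean fraction of each predicted cell's k nearest
    neighbours in the pool (true-perturbed + control) that are true-perturbed."""
    pool = np.vstack([pert_true, ctrl])
    is_pert = np.r_[np.ones(len(pert_true)), np.zeros(len(ctrl))]
    # Squared Euclidean distances, shape (n_pred, n_pool).
    d2 = ((pred[:, None, :] - pool[None, :, :]) ** 2).sum(axis=-1)
    nn_idx = np.argpartition(d2, kth=k - 1, axis=1)[:, :k]  # k nearest, unordered
    return float(is_pert[nn_idx].mean())


rng = np.random.default_rng(0)
ctrl = rng.normal(0.0, 1.0, size=(200, 8))
pert_true = rng.normal(5.0, 1.0, size=(200, 8))
good = rng.normal(5.0, 1.0, size=(50, 8))  # predictions that land among perturbed cells
bad = rng.normal(0.0, 1.0, size=(50, 8))   # predictions that stay in the control cloud
print(knn_purity(good, pert_true, ctrl), knn_purity(bad, pert_true, ctrl))
```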
```python
from sc_reconstruction.metrics import (
    metric_knn_purity, metric_r2, metric_mse, metric_energy_distance,
)

print("Gene-space metrics (predicted held-out vs. true held-out):")
print(f" R^2 = {metric_r2(pert_true, pert_pred):.4f}")
print(f" MSE = {metric_mse(pert_true, pert_pred):.4f}")
print(f" Energy distance = {metric_energy_distance(pert_true, pert_pred):.4f}")

print("\nKNN purity (latent space — held-out cells):")
for k in (10, 20, 50):
    p = metric_knn_purity(adata_pred=pert_pred,
                          adata_pert_true=pert_true,
                          adata_ctrl=ctrl,
                          k=k, use_rep="X_fm")
    print(f" k={k:2d} : {p:.4f}")
```

```text
Gene-space metrics (predicted held-out vs. true held-out):
 R^2 = 0.9798
 MSE = 0.0126
 Energy distance = 0.0384

KNN purity (latent space — held-out cells):
 k=10 : 0.9332
 k=20 : 0.9615
 k=50 : 0.9734
```
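Energy distance compares the predicted and true distributions as a whole instead of matching cells one to one. The minimal NumPy sketch below uses the standard squared form, D² = 2·E‖X − Y‖ − E‖X − X′‖ − E‖Y − Y′‖, with the V-statistic estimator and Euclidean norm. It is an assumption that `metric_energy_distance` uses this exact estimator; it may instead take a square root or normalise differently.

```python
import numpy as np


def energy_distance_sq(X, Y):
    """Squared energy distance between samples X and Y (V-statistic, Euclidean norm)."""
    def mean_dist(A, B):
        return np.sqrt(((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)).mean()
    return 2 * mean_dist(X, Y) - mean_dist(X, X) - mean_dist(Y, Y)


rng = np.random.default_rng(0)
X = rng.normal(0.0, 1.0, size=(300, 5))
Y_close = rng.normal(0.0, 1.0, size=(300, 5))  # same distribution as X
Y_shift = rng.normal(1.0, 1.0, size=(300, 5))  # every gene's mean shifted by 1
print(energy_distance_sq(X, X), energy_distance_sq(X, Y_close), energy_distance_sq(X, Y_shift))
```

Identical samples give exactly zero. Two samples from the same distribution give a small positive value, and a mean shift gives a clearly larger one.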
Bringing your own perturbation predictor
Any class with the two conditional methods slots in:
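```python
class MyPredictor:
    def train(self, ctrl_emb, pert_emb, pert_cov, **kw) -> None:
        ...  # fit your model on paired (control, perturbed, descriptor) rows

    def predict(self, ctrl_emb, pert_cov) -> np.ndarray:
        ...  # return an array shaped like ctrl_emb


model = MyPredictor()
assert isinstance(model, PerturbationPredictor)
model.train(train_ctrl, train_pert, train_cov)
pred_emb = model.predict(test_ctrl, test_cov)
```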
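As a concrete instance of the contract, the hypothetical non-parametric baseline below stores one mean latent shift per distinct descriptor. It is conditional and suits the cell-level split used here. However, it raises on any descriptor it did not see in training, so it works best as a floor for comparison against a model that actually learns from descriptor features. The class name and behaviour are illustrative, not part of ReconEval.

```python
import numpy as np


class MeanShiftPerDescriptor:
    """Hypothetical baseline: one mean latent shift per distinct descriptor row."""

    def train(self, ctrl_emb, pert_emb, pert_cov, **kwargs) -> None:
        delta = np.asarray(pert_emb) - np.asarray(ctrl_emb)
        # Each unique descriptor row is treated as one perturbation.
        uniq, inverse = np.unique(np.asarray(pert_cov), axis=0, return_inverse=True)
        inverse = inverse.reshape(-1)  # inverse shape differs across NumPy versions
        self._covs = uniq
        self._shifts = np.stack([delta[inverse == j].mean(axis=0) for j in range(len(uniq))])

    def predict(self, ctrl_emb, pert_cov) -> np.ndarray:
        pert_cov = np.asarray(pert_cov)
        # Exact match of each query descriptor against the training descriptors.
        match = (pert_cov[:, None, :] == self._covs[None, :, :]).all(axis=-1)
        if not match.any(axis=1).all():
            raise KeyError("descriptor not seen during training")
        return np.asarray(ctrl_emb) + self._shifts[match.argmax(axis=1)]


# Usage on toy data: 3 perturbations, 4-d descriptors, 6-d embeddings.
rng = np.random.default_rng(1)
covs = rng.normal(size=(3, 4))
shifts = rng.normal(size=(3, 6))
labels = np.repeat(np.arange(3), 20)
ctrl = rng.normal(size=(60, 6))
pert = ctrl + shifts[labels]

m = MeanShiftPerDescriptor()
m.train(ctrl, pert, covs[labels])
pred = m.predict(ctrl, covs[labels])

try:
    m.predict(ctrl[:1], rng.normal(size=(1, 4)))  # unseen descriptor
    raised = False
except KeyError:
    raised = True
```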
For the paper's full pipelines, see `experiments/03_latent_shift/codes/train_cf.py` (CellFlow) and `experiments/03_latent_shift/codes/train_st.py` (STATE).