End-to-end reconstruction
This tutorial shows how to plug a reconstruction model into the ReconEval benchmark. We use an autoencoder (sc_reconstruction.models.reconae.ReconAE) as the reference implementation, but any object that satisfies the small Protocol below works. Your model does not need to inherit from anything we ship.
Outline

1. The contract a reconstruction model must satisfy (a two-method Protocol).
2. Train an AE on the tutorial-sized LuCA panel.
3. Reconstruct and wrap the prediction back into an AnnData.
4. Hand the (true, reconstructed) pair to the metrics API; see metrics.ipynb for the per-metric walkthrough.
1. The model contract
ReconEval needs exactly two things from your model:

- train(...): fit on some training data. The signature is left open; pass whatever your model needs.
- predict(X) -> X_recon: given an (n_cells, n_genes) numpy array, return a reconstruction with the same shape.

Anything implementing these two methods can be scored. We capture this as a runtime-checkable Protocol, so structural typing is enough. You do not need to import or inherit from a ReconEval base class.
```python
import sys
from pathlib import Path

# Prefer the public src over any pip-installed sc_reconstruction (which may
# come from a stale private clone in the active env).
sys.path.insert(0, str(Path("..").resolve() / "src"))

from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class ReconstructionModel(Protocol):
    """Minimum interface to score a model with ReconEval.

    ``sc_reconstruction.models.{ReconAE, ReconPCA, ReconSCVI, ReconNLSCVI,
    ReconMLSCVI}`` already satisfy this Protocol; convenience methods
    (``get_latent_representation``, ``save``, ``load``) are available on
    the reference implementations but are not required.
    """

    def train(self, *args, **kwargs) -> None: ...

    def predict(self, X: np.ndarray) -> np.ndarray: ...
```
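A runtime_checkable Protocol only verifies that the required methods exist by name. It does not check their signatures or what predict returns. The self-contained sketch below uses hypothetical toy classes (IdentityModel, TrainOnly, WrongSignature) to show what the isinstance check catches and what it misses. Because it misses signature and shape errors, it is worth checking the output shape of predict yourself.

```python
from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class ReconstructionModel(Protocol):
    def train(self, *args, **kwargs) -> None: ...
    def predict(self, X: np.ndarray) -> np.ndarray: ...


class IdentityModel:
    """Hypothetical toy model: 'reconstructs' by returning a copy of the input."""

    def train(self, *args, **kwargs) -> None:
        pass  # nothing to fit

    def predict(self, X: np.ndarray) -> np.ndarray:
        return np.array(X, copy=True)


class TrainOnly:
    """Hypothetical class without predict(), so it does not satisfy the Protocol."""

    def train(self, *args, **kwargs) -> None:
        pass


class WrongSignature:
    """Hypothetical class with both method names, but predict() takes no X."""

    def train(self) -> None:
        pass

    def predict(self) -> None:
        return None


print(isinstance(IdentityModel(), ReconstructionModel))   # True
print(isinstance(TrainOnly(), ReconstructionModel))       # False
print(isinstance(WrongSignature(), ReconstructionModel))  # True: signatures are not checked
```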
2. Train: AE as the reference implementation
ReconAE is a thin wrapper around a Lightning module. To train it, we need a LightningDataModule whose batches are dicts of the form {'X': tensor}. The in-memory data module below is deliberately minimal. For paper-scale runs, use the chunked dataloaders in sc_reconstruction.dataloaders.
```python
import warnings

import lightning as L
import numpy as np
import torch
from anndata import AnnData
from torch.utils.data import DataLoader, Dataset

warnings.filterwarnings("ignore")  # keep the notebook output readable


class _DictDataset(Dataset):
    """Tiny dataset that yields `{'X': row}`, which is what `Autoencoder.training_step` expects."""

    def __init__(self, X):
        # AnnData often stores .X as a scipy.sparse matrix; densify it first.
        if hasattr(X, "toarray"):
            X = X.toarray()
        self.X = torch.as_tensor(X, dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return {"X": self.X[i]}


class InMemoryDataModule(L.LightningDataModule):
    """Holds the whole matrix in memory and makes a seeded random train/validation split."""

    def __init__(self, X, batch_size=64, val_frac=0.1, seed=0):
        super().__init__()
        rng = np.random.default_rng(seed)
        perm = rng.permutation(X.shape[0])
        n_val = int(X.shape[0] * val_frac)
        self._train = _DictDataset(X[perm[n_val:]])
        self._val = _DictDataset(X[perm[:n_val]])
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self._train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self._val, batch_size=self.batch_size)
```
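Lightning passes each batch the DataLoader yields to training_step. The sketch below needs only PyTorch and uses a small random matrix as a stand-in for adata.X. It shows that PyTorch's default collation merges a list of {'X': row} dicts into a single dict, whose 'X' entry is a stacked (batch_size, n_genes) float32 tensor. That is the batch format described above.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class _DictDataset(Dataset):
    """Same idea as above: each item is a dict holding one cell's expression row."""

    def __init__(self, X):
        self.X = torch.as_tensor(X, dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return {"X": self.X[i]}


# Hypothetical toy matrix: 10 cells x 5 genes.
X = np.random.default_rng(0).random((10, 5))
loader = DataLoader(_DictDataset(X), batch_size=4, shuffle=False)

batches = list(loader)
first = batches[0]
print(type(first), first["X"].shape)        # <class 'dict'> torch.Size([4, 5])
print([b["X"].shape[0] for b in batches])   # [4, 4, 2]: the last batch is smaller
```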
Load the tutorial-sized LuCA panel (500 cells × 643 genes, log-normalised expression indexed by real gene symbols). To use your own data instead, load any AnnData whose .X holds a cells × genes expression matrix.
```python
import scanpy as sc

FROZEN = Path("../analysis/data/frozen")
adata = sc.read_h5ad(FROZEN / "luca_demo.h5ad")
n_cells, n_genes = adata.shape

# Reuse the var DataFrame so the recon AnnData shares gene names.
var = adata.var
print(adata)
```
```
AnnData object with n_obs × n_vars = 500 × 643
    obs: 'sample', 'uicc_stage', 'ever_smoker', 'age', 'donor_id', 'origin', 'dataset', 'ann_fine', 'cell_type_predicted', 'doublet_status', 'leiden', 'n_genes_by_counts', 'total_counts', 'total_counts_mito', 'pct_counts_mito', 'ann_coarse', 'cell_type_tumor', 'tumor_stage', 'TP53_mutation', 'ALK_mutation', 'BRAF_mutation', 'ERBB2_mutation', 'KRAS_mutation', 'ROS_mutation', 'origin_fine', 'study', 'platform', 'cell_type_major', 'suspension_type', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'organism_ontology_term_id', 'sex_ontology_term_id', 'tissue_ontology_term_id', 'tissue_type', 'EGFR_mutation', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid'
    var: 'feature_name'
```
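If luca_demo.h5ad is not available, or you want to see what "log-normalised" means concretely, the numpy-only sketch below builds a random count matrix of the same shape. It then scales each cell to a fixed total count and applies log1p. The Poisson rate, target_sum=1e4 and the log_normalise helper are illustrative assumptions, not the recipe used to build the demo panel. Wrap the result in an AnnData (with your own gene names as var_names) before using it in the steps below.

```python
import numpy as np


def log_normalise(counts: np.ndarray, target_sum: float = 1e4) -> np.ndarray:
    """Scale each cell to `target_sum` total counts, then apply log1p.

    Mirrors the common scanpy recipe (normalize_total followed by log1p) on a
    dense matrix; that the demo panel was prepared this way is an assumption.
    """
    library_size = counts.sum(axis=1, keepdims=True)
    # Guard against empty cells so we never divide by zero.
    library_size = np.where(library_size == 0, 1.0, library_size)
    return np.log1p(counts / library_size * target_sum).astype(np.float32)


# Hypothetical stand-in with the same shape as the tutorial panel.
rng = np.random.default_rng(0)
counts = rng.poisson(lam=0.2, size=(500, 643)).astype(np.float64)
X_demo = log_normalise(counts)
print(X_demo.shape, X_demo.dtype)  # (500, 643) float32
```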
A note on library size
ReconAE has three modes for handling library size (the total count per cell), set via
library_size_mode:
| Mode | What the decoder sees | Use when |
|---|---|---|
|  | Raw decoder output | Data are already log-normalised, or you don’t care about modeling count totals. |
|  |  | Raw count data where the empirical library size is reliable (Smart-seq, deep 10x). |
|  | A second encoder | You want the model to learn its own library normalisation (e.g. droplet data where library size is itself noisy). |
The choice affects the decoder’s output scale, not the latent z. Pair library_size_mode with a matching distribution: "normal", "huber" or "l1" for log-normalised data, and "nb_gene" for raw counts.
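To make the library-size idea concrete, the numpy sketch below shows one common way to put decoder output on a cell's count scale; scVI-style count models use this approach. The decoder output is normalised to per-gene proportions with a softmax and then multiplied by the cell's observed total count. Whether ReconAE's modes use exactly this parameterisation is an assumption, so check its source before relying on the details. The latent z plays no part in this scaling step, which matches the point above.

```python
import numpy as np


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
counts = rng.poisson(lam=1.0, size=(4, 6)).astype(np.float64)  # 4 cells x 6 genes
decoder_logits = rng.normal(size=counts.shape)  # stand-in for raw decoder output

# Observed library size: the empirical total count of each cell.
library = counts.sum(axis=1, keepdims=True)

# Decoder output -> per-cell gene proportions -> expected counts on that cell's scale.
proportions = softmax(decoder_logits, axis=1)
expected_counts = library * proportions

print(np.allclose(expected_counts.sum(axis=1), library.ravel()))  # True
```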
```python
from sc_reconstruction.models.reconae import ReconAE

ae = ReconAE(
    input_dim=n_genes,
    n_hidden=[128, 64],
    n_latent=16,
    distribution="normal",
    learning_rate=1e-3,
)

# `ReconAE` satisfies our 2-method Protocol: quick sanity check.
assert isinstance(ae, ReconstructionModel), "ReconAE should satisfy ReconstructionModel"

dm = InMemoryDataModule(adata.X, batch_size=64)
ae.train(datamodule=dm, max_epochs=200, accelerator="cpu",
         enable_progress_bar=False, logger=False)
```
```
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name              | Type        | Params | Mode
----------------------------------------------------------
0 | output_activation | Identity    | 0      | train
1 | encoder           | BaseEncoder | 92.1 K | train
2 | decoder           | BaseDecoder | 92.7 K | train
----------------------------------------------------------
184 K     Trainable params
0         Non-trainable params
184 K     Total params
0.739     Total estimated model params size (MB)
24        Modules in train mode
0         Modules in eval mode
SLURM auto-requeueing enabled. Setting signal handlers.
`Trainer.fit` stopped: `max_epochs=200` reached.
```
3. Reconstruct
predict returns a numpy array; wrap it in an AnnData with the same var_names so the metrics
API can score it against the truth.
```python
X_recon = ae.predict(adata.X)
recon = AnnData(X_recon, var=var)
print(recon)
print(f"\ntruth mean(X) = {adata.X.mean():.3f}, var = {adata.X.var():.3f}")
print(f"recon mean(X) = {recon.X.mean():.3f}, var = {recon.X.var():.3f}")
```
```
AnnData object with n_obs × n_vars = 500 × 643
    var: 'feature_name'

truth mean(X) = 0.095, var = 0.118
recon mean(X) = 0.085, var = 0.059
```
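The contract only requires predict to return an array with the same shape as its input, so a cheap check before scoring catches most integration mistakes. The helper below is a hypothetical, numpy-only sketch; check_reconstruction is not part of sc_reconstruction. You could call it as check_reconstruction(adata.X, X_recon) before building the AnnData.

```python
import numpy as np


def check_reconstruction(X, X_recon) -> None:
    """Hypothetical helper: fail early if a prediction cannot be scored.

    Checks the model contract (same shape as the input) plus finiteness,
    neither of which the Protocol's isinstance check can verify.
    """
    # Densify sparse inputs (e.g. scipy.sparse matrices from AnnData.X).
    if hasattr(X, "toarray"):
        X = X.toarray()
    X_recon = np.asarray(X_recon)
    if X_recon.shape != np.shape(X):
        raise ValueError(
            f"shape mismatch: input {np.shape(X)}, reconstruction {X_recon.shape}"
        )
    if not np.all(np.isfinite(X_recon)):
        raise ValueError("reconstruction contains NaN or inf values")


X = np.ones((3, 4))
check_reconstruction(X, X * 0.5)  # passes silently
try:
    check_reconstruction(X, X[:, :2])
except ValueError as err:
    print(err)  # shape mismatch: input (3, 4), reconstruction (3, 2)
```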
4. Score
Hand the pair to the metrics API. The metrics tutorial walks through the full suite (statistical, biological and perturbational metrics, plus rank-percentile aggregation); here we only show the headline statistical metrics.
```python
from sc_reconstruction.metrics import compute_statistical_metrics

compute_statistical_metrics(adata, recon)
```
```
{'r2': 0.9821816682815552,
 'mse': 0.05252765864133835,
 'energy_distance': 0.8585710525512695}
```
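The metrics tutorial documents the exact definitions behind compute_statistical_metrics, such as which space distances are computed in, whether the energy distance is squared and how r2 is pooled. We do not reproduce them here. For rough orientation only, the numpy sketch below implements textbook versions of the three quantities. It is not the library's implementation and will not necessarily reproduce the numbers above.

```python
import numpy as np


def mse(X: np.ndarray, Y: np.ndarray) -> float:
    """Mean squared error over all cells and genes."""
    return float(np.mean((X - Y) ** 2))


def r2(X: np.ndarray, Y: np.ndarray) -> float:
    """Coefficient of determination of Y as a prediction of X, pooled over all entries."""
    ss_res = np.sum((X - Y) ** 2)
    ss_tot = np.sum((X - X.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)


def energy_distance(X: np.ndarray, Y: np.ndarray) -> float:
    """Squared energy distance between two samples of cells (rows).

    D^2 = 2 E||x - y|| - E||x - x'|| - E||y - y'||, estimated over all pairs
    (V-statistic, self-pairs included). Memory grows as n^2, so subsample large data.
    """
    def mean_pairwise(A, B):
        diff = A[:, None, :] - B[None, :, :]
        return np.sqrt((diff ** 2).sum(axis=-1)).mean()

    return float(2 * mean_pairwise(X, Y) - mean_pairwise(X, X) - mean_pairwise(Y, Y))


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 10))
Y_same = rng.normal(size=(200, 10))             # same distribution as X
Y_shift = rng.normal(loc=1.0, size=(200, 10))   # shifted distribution

print(mse(X, X), r2(X, X))  # 0.0 1.0
print(energy_distance(X, Y_same) < energy_distance(X, Y_shift))  # True
```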
Bringing your own model
Any class that exposes the two methods of the Protocol above can be dropped in wherever ReconAE is used. A minimal template:
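```python
from sc_reconstruction.metrics import compute_statistical_metrics

class MyModel:
    def train(self, datamodule, **kwargs) -> None:
        ...  # fit your model, e.g. on batches from datamodule.train_dataloader()

    def predict(self, X: np.ndarray) -> np.ndarray:
        ...  # must return an ndarray with the same shape as X

model = MyModel()
model.train(datamodule=InMemoryDataModule(adata.X))
recon = AnnData(model.predict(adata.X), var=adata.var)
compute_statistical_metrics(adata, recon)
# Full suite: compute_all_metrics(adata, recon, ...); see the metrics tutorial.
```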
The reference implementations in sc_reconstruction.models (ReconAE, ReconPCA, ReconSCVI, ReconNLSCVI, ReconMLSCVI, and the foundation-model wrappers) all implement the same two-method interface.
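As a concrete instance of the template above, the self-contained sketch below implements a hypothetical per-gene mean baseline. MeanBaseline is not part of sc_reconstruction. It satisfies the Protocol without inheriting from anything, and its train takes a plain array rather than a datamodule, which the open train signature allows. A baseline like this is a useful floor to compare learned models against. In the notebook you would call model.train(adata.X) and then score it as in the template.

```python
from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class ReconstructionModel(Protocol):
    def train(self, *args, **kwargs) -> None: ...
    def predict(self, X: np.ndarray) -> np.ndarray: ...


class MeanBaseline:
    """Hypothetical baseline: every cell is reconstructed as the per-gene training mean."""

    def train(self, X_train: np.ndarray) -> None:
        if hasattr(X_train, "toarray"):  # accept scipy.sparse input
            X_train = X_train.toarray()
        self.gene_means_ = np.asarray(X_train, dtype=np.float64).mean(axis=0)

    def predict(self, X: np.ndarray) -> np.ndarray:
        # Same shape as X: one copy of the gene-mean profile per input cell.
        return np.tile(self.gene_means_, (X.shape[0], 1))


rng = np.random.default_rng(0)
X = rng.random((100, 20))  # stand-in for adata.X
model = MeanBaseline()
model.train(X)
X_recon = model.predict(X)

print(isinstance(model, ReconstructionModel), X_recon.shape)  # True (100, 20)
```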