End-to-end reconstruction

This tutorial shows how to plug a reconstruction model into the ReconEval benchmark. We use an autoencoder (sc_reconstruction.models.reconae.ReconAE) as the reference implementation, but any model satisfying the small Protocol below works; it does not need to inherit from anything we ship.

Outline

  1. The contract a reconstruction model must satisfy (a 2-method Protocol).

  2. Train an AE on the tutorial-sized LuCA panel.

  3. Reconstruct and wrap the prediction back into an AnnData.

  4. Hand the (true, reconstructed) pair to the metrics API — see metrics.ipynb for the per-metric walkthrough.

1. The model contract

ReconEval needs exactly two things from your model:

  • train(...) — fit on some training data. Signature is left open; pass whatever your model needs.

  • predict(X) -> X_recon — given an (n_cells, n_genes) numpy array, return the reconstruction with the same shape.

Anything implementing these two methods can be scored. We capture this as a runtime-checkable Protocol, so an isinstance check succeeds for any object that has both methods; you do not need to import or inherit from a ReconEval base class.

import sys
from pathlib import Path

# Prefer the public src over any pip-installed sc_reconstruction (which may
# come from a stale private clone in the active env).
sys.path.insert(0, str(Path("..").resolve() / "src"))

from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class ReconstructionModel(Protocol):
    """Minimum interface to score a model with ReconEval.

    ``sc_reconstruction.models.{ReconAE, ReconPCA, ReconSCVI, ReconNLSCVI,
    ReconMLSCVI}`` already satisfy this Protocol; convenience methods
    (``get_latent_representation``, ``save``, ``load``) are available on
    the reference implementations but are not required.
    """

    def train(self, *args, **kwargs) -> None: ...

    def predict(self, X: np.ndarray) -> np.ndarray: ...
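
As a minimal sketch of what the runtime check does and does not catch: Python's isinstance against a runtime-checkable Protocol only verifies that the named methods exist, not their signatures or return shapes. The three toy classes below (IdentityModel, TrainOnly, WrongShape) are hypothetical and exist only for this illustration; the Protocol is redefined so the snippet runs on its own.

```python
from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class ReconstructionModel(Protocol):
    def train(self, *args, **kwargs) -> None: ...

    def predict(self, X: np.ndarray) -> np.ndarray: ...


class IdentityModel:
    """Hypothetical toy model: has both methods, so it passes the check."""

    def train(self, *args, **kwargs) -> None:
        pass  # nothing to fit

    def predict(self, X: np.ndarray) -> np.ndarray:
        return X.copy()


class TrainOnly:
    """Hypothetical incomplete model: no predict(), so it fails the check."""

    def train(self, *args, **kwargs) -> None:
        pass


class WrongShape:
    """Hypothetical broken model: passes isinstance() despite a wrong output shape."""

    def train(self, *args, **kwargs) -> None:
        pass

    def predict(self, X: np.ndarray) -> np.ndarray:
        return X[:, :1]  # drops genes; isinstance() cannot see this


X = np.zeros((4, 3))
passes = isinstance(IdentityModel(), ReconstructionModel)
fails = isinstance(TrainOnly(), ReconstructionModel)
sneaks_through = isinstance(WrongShape(), ReconstructionModel)
shape_ok = WrongShape().predict(X).shape == X.shape
```

Because WrongShape passes the isinstance check, a separate shape assertion on the output of predict is still worth running before scoring.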

2. Train: AE as the reference implementation

ReconAE is a thin wrapper around a Lightning module. To train it we need a LightningDataModule whose batches are dicts of the form {'X': tensor}. The wrapper below is a minimal stand-in that holds the whole matrix in memory; for paper-scale runs use the chunked dataloaders in sc_reconstruction.dataloaders.

import warnings

import lightning as L
import pandas as pd
import torch
from anndata import AnnData
from torch.utils.data import DataLoader, Dataset

warnings.filterwarnings("ignore")


class _DictDataset(Dataset):
    """Tiny dataset that yields `{'X': row}` — what `Autoencoder.training_step` expects."""
    def __init__(self, X):
        self.X = torch.as_tensor(X, dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return {"X": self.X[i]}


class InMemoryDataModule(L.LightningDataModule):
    def __init__(self, X, batch_size=64, val_frac=0.1, seed=0):
        super().__init__()
        rng = np.random.default_rng(seed)
        perm = rng.permutation(X.shape[0])
        n_val = int(X.shape[0] * val_frac)
        self._train = _DictDataset(X[perm[n_val:]])
        self._val   = _DictDataset(X[perm[:n_val]])
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self._train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self._val, batch_size=self.batch_size)
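
The train/validation split used by InMemoryDataModule can be inspected in isolation. The sketch below repeats the same permutation logic in plain numpy; split_indices is a hypothetical helper name introduced here, not part of sc_reconstruction.

```python
import numpy as np


def split_indices(n_rows: int, val_frac: float = 0.1, seed: int = 0):
    """Return (train_idx, val_idx) using the same permutation logic as above."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_rows)
    n_val = int(n_rows * val_frac)  # rounds down, so tiny inputs may get 0 rows
    return perm[n_val:], perm[:n_val]


# Same size as the tutorial panel: 500 cells, 10% held out for validation.
train_idx, val_idx = split_indices(500, val_frac=0.1, seed=0)
```

With a fixed seed the split is reproducible across runs, and the two index sets are disjoint and together cover every row.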

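Load the tutorial-sized LuCA panel (500 cells × 643 genes, log-normalised expression on real gene symbols). Replace this with your own data; anything with an .X expression matrix works. The code in this tutorial assumes .X is a dense numpy array, so if yours is a scipy sparse matrix, convert it first (for example with adata.X.toarray()).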

import scanpy as sc

FROZEN = Path("../analysis/data/frozen")
adata = sc.read_h5ad(FROZEN / "luca_demo.h5ad")
n_cells, n_genes = adata.shape

# Reuse the var DataFrame so the recon AnnData shares gene names.
var = adata.var

print(adata)
AnnData object with n_obs × n_vars = 500 × 643
    obs: 'sample', 'uicc_stage', 'ever_smoker', 'age', 'donor_id', 'origin', 'dataset', 'ann_fine', 'cell_type_predicted', 'doublet_status', 'leiden', 'n_genes_by_counts', 'total_counts', 'total_counts_mito', 'pct_counts_mito', 'ann_coarse', 'cell_type_tumor', 'tumor_stage', 'TP53_mutation', 'ALK_mutation', 'BRAF_mutation', 'ERBB2_mutation', 'KRAS_mutation', 'ROS_mutation', 'origin_fine', 'study', 'platform', 'cell_type_major', 'suspension_type', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'organism_ontology_term_id', 'sex_ontology_term_id', 'tissue_ontology_term_id', 'tissue_type', 'EGFR_mutation', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid'
    var: 'feature_name'

A note on library size

ReconAE has three modes for handling library size (the total count per cell), set via library_size_mode:

| Mode | What the decoder sees | Use when |
|---|---|---|
| "none" (default) | Raw decoder output decoder(z), with no library-size term. | Data are already log-normalised, or you don’t care about modeling count totals. |
| "observed" | decoder(z, library_size = X.sum(axis=1)): passes the cell’s observed total directly to the decoder, which scales the predicted mean by it. | Raw count data where the empirical library size is reliable (Smart-seq, deep 10x). |
| "modeled" | A second encoder l_encoder learns a one-dim library factor; the decoder gets decoder(z, library_size = exp(l_encoder(x))). | You want the model to learn its own library normalisation (e.g. droplet data where library size is itself noisy). |

The choice affects the decoder’s output scale, not the latent z. Pair library_size_mode with a matching distribution: "normal" / "huber" / "l1" for log-normalised data; "nb_gene" for raw counts.
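
To make "scales the predicted mean" concrete, here is a small numpy sketch of one common way a library-size term enters a count decoder: per-gene proportions from a softmax, multiplied by a per-cell library size. The softmax head is an assumption borrowed from scVI-style decoders, not something confirmed for ReconAE, and the random logits and log_lib arrays stand in for decoder(z) and l_encoder(x).

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax, shifted for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def scaled_mean(logits: np.ndarray, library_size: np.ndarray) -> np.ndarray:
    """Assumed decoder head: per-gene proportions times a per-cell library size."""
    return softmax(logits) * library_size[:, None]


rng = np.random.default_rng(0)
X_counts = rng.poisson(lam=2.0, size=(5, 8)).astype(np.float64)  # toy raw counts
logits = rng.normal(size=(5, 8))  # stand-in for decoder(z)

# "observed": use each cell's empirical total count.
observed_lib = X_counts.sum(axis=1)
mean_observed = scaled_mean(logits, observed_lib)

# "modeled": exponentiate a learned log-library factor (random stand-in here).
log_lib = rng.normal(loc=2.0, scale=0.1, size=5)  # stand-in for l_encoder(x)
mean_modeled = scaled_mean(logits, np.exp(log_lib))
```

Under this assumed head, each row of the predicted mean sums to that cell's library size, which is why the mode changes the output scale while leaving the proportions, and hence what z has to encode, untouched.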

from sc_reconstruction.models.reconae import ReconAE

ae = ReconAE(
    input_dim=n_genes,
    n_hidden=[128, 64],
    n_latent=16,
    distribution="normal",
    learning_rate=1e-3,
)

# `ReconAE` satisfies our 2-method Protocol — quick sanity check.
assert isinstance(ae, ReconstructionModel), "ReconAE should satisfy ReconstructionModel"

dm = InMemoryDataModule(adata.X, batch_size=64)
ae.train(datamodule=dm, max_epochs=200, accelerator="cpu",
         enable_progress_bar=False, logger=False)
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
  | Name              | Type        | Params | Mode 
----------------------------------------------------------
0 | output_activation | Identity    | 0      | train
1 | encoder           | BaseEncoder | 92.1 K | train
2 | decoder           | BaseDecoder | 92.7 K | train
----------------------------------------------------------
184 K     Trainable params
0         Non-trainable params
184 K     Total params
0.739     Total estimated model params size (MB)
24        Modules in train mode
0         Modules in eval mode
SLURM auto-requeueing enabled. Setting signal handlers.
`Trainer.fit` stopped: `max_epochs=200` reached.

3. Reconstruct

predict returns a numpy array; wrap it in an AnnData with the same var_names so the metrics API can score it against the truth.

X_recon = ae.predict(adata.X)
recon = AnnData(X_recon, var=var)

print(recon)
print(f"\ntruth  mean(X) = {adata.X.mean():.3f}, var = {adata.X.var():.3f}")
print(f"recon  mean(X) = {recon.X.mean():.3f}, var = {recon.X.var():.3f}")
AnnData object with n_obs × n_vars = 500 × 643
    var: 'feature_name'

truth  mean(X) = 0.095, var = 0.118
recon  mean(X) = 0.085, var = 0.059
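
Before wrapping a prediction in an AnnData, it can help to verify the predict contract explicitly. The helper below is a hypothetical sketch (check_reconstruction is not part of ReconEval) that checks the shape and rejects non-finite values.

```python
import numpy as np


def check_reconstruction(X_true: np.ndarray, X_recon: np.ndarray) -> np.ndarray:
    """Return X_recon as a float array after basic contract checks."""
    X_recon = np.asarray(X_recon, dtype=np.float64)
    if X_recon.shape != X_true.shape:
        raise ValueError(f"expected shape {X_true.shape}, got {X_recon.shape}")
    if not np.isfinite(X_recon).all():
        raise ValueError("reconstruction contains NaN or inf values")
    return X_recon


# Toy usage: a (3 cells, 2 genes) truth and a same-shaped reconstruction.
checked = check_reconstruction(np.zeros((3, 2)), np.ones((3, 2)))
```

In the tutorial flow this would be called as check_reconstruction(adata.X, X_recon) right after predict.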

4. Score

Hand the pair to the metrics API. See the metrics tutorial for the full walkthrough (statistical + biological + perturbational + rank-percentile aggregation); we only show the headline statistical metrics here.

from sc_reconstruction.metrics import compute_statistical_metrics

compute_statistical_metrics(adata, recon)
{'r2': 0.9821816682815552,
 'mse': 0.05252765864133835,
 'energy_distance': 0.8585710525512695}
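
For orientation only: how compute_statistical_metrics defines each quantity is not shown in this tutorial, so the sketch below uses generic textbook definitions of mean squared error and the coefficient of determination over all matrix entries. The function names are illustrative, and the numbers they produce should not be expected to match the values reported above.

```python
import numpy as np


def mse(X_true: np.ndarray, X_recon: np.ndarray) -> float:
    """Mean squared error over every (cell, gene) entry."""
    return float(np.mean((X_true - X_recon) ** 2))


def r2(X_true: np.ndarray, X_recon: np.ndarray) -> float:
    """Textbook coefficient of determination: 1 - SS_res / SS_tot."""
    ss_res = np.sum((X_true - X_recon) ** 2)
    ss_tot = np.sum((X_true - X_true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)


rng = np.random.default_rng(0)
X_true = rng.normal(size=(50, 10))
X_noisy = X_true + 0.1 * rng.normal(size=X_true.shape)  # mildly corrupted copy

noisy_mse = mse(X_true, X_noisy)
noisy_r2 = r2(X_true, X_noisy)
```

Under these definitions a perfect reconstruction gives an MSE of 0 and an R² of 1, while predicting the global mean everywhere gives an R² of 0.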

Bringing your own model

Write a class that exposes the two-method Protocol above and you can drop it in anywhere ReconAE is used:

class MyModel:
    def train(self, datamodule): ...
    def predict(self, X): ...  # returns ndarray with the same shape as X

model = MyModel()
model.train(datamodule=InMemoryDataModule(adata.X))
recon = AnnData(model.predict(adata.X), var=adata.var)
compute_all_metrics(adata, recon, ...)  # see the metrics tutorial

The reference implementations in sc_reconstruction.models (ReconAE, ReconPCA, ReconSCVI, ReconNLSCVI, ReconMLSCVI, and the foundation-model wrappers) all implement the same two-method interface.
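
As a concrete, dependency-light illustration of a bring-your-own model, here is a sketch of a rank-k PCA reconstructor written with numpy only. TruncatedPCAModel is a hypothetical class made up for this example and is unrelated to the shipped ReconPCA; its train takes a plain array rather than a datamodule, which the contract allows because the train signature is left open.

```python
import numpy as np


class TruncatedPCAModel:
    """Hypothetical bring-your-own model: rank-k PCA reconstruction."""

    def __init__(self, n_components: int = 2):
        self.n_components = n_components
        self.mean_ = None
        self.components_ = None

    def train(self, X: np.ndarray) -> None:
        X = np.asarray(X, dtype=np.float64)
        self.mean_ = X.mean(axis=0)
        # Rows of Vt are principal axes, ordered by decreasing singular value.
        _, _, Vt = np.linalg.svd(X - self.mean_, full_matrices=False)
        self.components_ = Vt[: self.n_components]

    def predict(self, X: np.ndarray) -> np.ndarray:
        if self.components_ is None:
            raise RuntimeError("call train() before predict()")
        # Project onto the top-k axes, then map back to gene space.
        Z = (np.asarray(X, dtype=np.float64) - self.mean_) @ self.components_.T
        return Z @ self.components_ + self.mean_


rng = np.random.default_rng(0)
X_toy = rng.normal(size=(100, 2)) @ rng.normal(size=(2, 6))  # exactly rank 2
model = TruncatedPCAModel(n_components=2)
model.train(X_toy)
X_recon = model.predict(X_toy)
```

The output of predict has the same shape as its input, so it can be wrapped in an AnnData and scored exactly like the ReconAE reconstruction above.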