# Source code for sc_reconstruction.decoders.recontransformer

import math
import os
import warnings
from collections.abc import Iterable

import lightning as L
import numpy as np
import torch
import torch.nn as nn

import argparse
from pathlib import Path

import anndata as ad
from torch.utils.data import DataLoader, Dataset, random_split
from sc_reconstruction.decoders._base_decoder import BaseReconstructionDecoder
from concept.decoder.decoder_model import TransformerDecoderModel
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.plugins.environments import LightningEnvironment


class ReconTransformerModule(L.LightningModule):
    """Lightning wrapper that adapts TransformerDecoderModel to the ReconMLPDecoder API.

    - ``forward(z, *cat_list) -> (B, n_output)``
    - ``training_step`` expects batch dict with keys like ReconMLPDecoder:
      ``batch['x']``, ``batch['z']``, optional ``batch['batch_onehot']``.
    - Automatically constructs gene_indices ``(0..n_output-1)`` and expands to batch.
    - When a :class:`WandbLogger` is attached, its run id / name / project are
      written into every checkpoint and restored on load.
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_cat_list: Iterable[int] | None = None,
        dim_model: int = 128,
        num_head: int = 8,
        dim_hid: int = 256,
        nlayers: int = 6,
        dropout: float = 0.1,
        distribution: str = "normal",
        learning_rate: float = 1e-4,
        reduce_lr_on_plateau: bool = False,
        lr_factor: float = 0.6,
        lr_patience: int = 5,
        lr_threshold: float = 1e-3,
        lr_min: float = 0.0,
        library_size_mode: str = "none",
        # optional optimizer knobs
        weight_decay: float = 0.0,
        use_adamw: bool = False,
        **kwargs,
    ):
        """Build the wrapped transformer and store training configuration.

        Args:
            n_input: Width of the latent ``z`` passed to :meth:`forward`.
            n_output: Number of reconstructed features; also the number of
                gene indices handed to the transformer.
            n_cat_list: Sizes of the conditional one-hot blocks concatenated to
                ``z``; their sum is added to the transformer's input width.
            dim_model, num_head, dim_hid, nlayers, dropout: Forwarded unchanged
                to ``TransformerDecoderModel``.
            distribution: Loss selector, one of ``"normal"``, ``"huber"``,
                ``"l1"``, ``"normal_mle_fixed"``, ``"normal_mle_gene"``.
            learning_rate: Optimizer learning rate.
            reduce_lr_on_plateau: Attach a ``ReduceLROnPlateau`` scheduler
                monitoring ``val/loss_epoch``.
            lr_factor, lr_patience, lr_threshold, lr_min: Scheduler settings.
            library_size_mode: Accepted for API parity; any value other than
                ``"none"`` only triggers a warning in :meth:`forward`.
            weight_decay: Optimizer weight decay.
            use_adamw: Use ``AdamW`` instead of ``Adam``.
            **kwargs: Accepted and ignored by this constructor.
        """
        super().__init__()
        self.save_hyperparameters()

        self.n_input = n_input
        self.n_output = n_output
        self.n_cat_list = n_cat_list
        self.distribution = distribution
        self.learning_rate = learning_rate
        self.reduce_lr_on_plateau = reduce_lr_on_plateau
        self.lr_factor = lr_factor
        self.lr_patience = lr_patience
        self.lr_threshold = lr_threshold
        self.lr_min = lr_min
        self.library_size_mode = library_size_mode
        self.weight_decay = weight_decay
        self.use_adamw = use_adamw

        # W&B bookkeeping. All three are initialised here so they exist even
        # when the module was never restored from a checkpoint.
        self.wandb_run_id = None
        self.wandb_name = None
        self.wandb_project = None
        self.resume_from_checkpoint = None

        # Conditional one-hots are concatenated to z, so they widen the input.
        total_input_dim = n_input
        if n_cat_list is not None:
            total_input_dim += sum(n_cat_list)

        self.transformer = TransformerDecoderModel(
            num_genes=n_output,
            cell_emb_dim=total_input_dim,
            dim_model=dim_model,
            num_head=num_head,
            dim_hid=dim_hid,
            nlayers=nlayers,
            dropout=dropout,
            lr=learning_rate,
            weight_decay=weight_decay,
        )

        # 0..n_output-1, stored with the weights and expanded per batch.
        self.register_buffer(
            "gene_index_base",
            torch.arange(n_output, dtype=torch.long),
            persistent=True,
        )

        if self.distribution == "normal_mle_gene":
            # Learnable per-feature log-scale; exponentiated in compute_loss.
            initial = torch.zeros(n_output).normal_(mean=0.0, std=0.1)
            self.px_r = nn.Parameter(initial)

    def _expand_gene_indices(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Return the gene index vector broadcast to ``(batch_size, n_output)``."""
        base = self.gene_index_base.to(device=device)
        return base.unsqueeze(0).expand(batch_size, -1)

    def forward(self, z: torch.Tensor, *cat_list: torch.Tensor) -> torch.Tensor:
        """Decode latents, optionally conditioned on one-hot covariates.

        Args:
            z: Latent batch; concatenated along ``dim=1`` with ``cat_list``.
            *cat_list: Optional conditional tensors with the same batch size.

        Returns:
            The output of the wrapped transformer
            (``(B, n_output)`` per the class contract).
        """
        # Match ReconMLPDecoder: concatenate conditional one-hots if present
        z_cat = torch.cat([z, *cat_list], dim=1) if cat_list else z

        if self.library_size_mode != "none":
            warnings.warn(
                "library_size_mode is set but not implemented for decoding only. "
                "Ignoring library size mode.",
                stacklevel=2,
            )

        gene_indices = self._expand_gene_indices(z_cat.shape[0], z_cat.device)
        return self.transformer(z_cat, gene_indices)  # (B, n_output)

    def _shared_step(self, batch, stage: str = "train"):
        """Run forward + loss for one batch and log ``{stage}/loss``."""
        x = batch["x"]
        z = batch["z"]

        if "batch_onehot" in batch and batch["batch_onehot"] is not None:
            reconstruction = self.forward(z, batch["batch_onehot"])
        else:
            reconstruction = self.forward(z)

        loss = self.compute_loss(x, reconstruction)
        self.log_metrics({"loss": loss}, stage)
        return loss

    def training_step(self, batch, batch_idx):
        """Log the current learning rate, then compute the training loss."""
        opt = self.optimizers()
        if opt is not None:
            lr = opt.param_groups[0].get("lr", self.learning_rate)
            self.log("lr", lr, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss."""
        return self._shared_step(batch, "test")

    def compute_loss(self, x: torch.Tensor, reconstruction: torch.Tensor) -> torch.Tensor:
        """Return the scalar reconstruction loss selected by ``self.distribution``.

        Raises:
            ValueError: If ``self.distribution`` is not a supported name.
        """
        if self.distribution == "normal":
            return nn.functional.mse_loss(reconstruction, x)
        if self.distribution == "huber":
            return nn.functional.huber_loss(reconstruction, x)
        if self.distribution == "l1":
            return nn.functional.l1_loss(reconstruction, x)

        if self.distribution == "normal_mle_fixed":
            # Gaussian NLL with a fixed scale of 0.1 for every feature.
            scale = torch.tensor(1e-1, device=x.device)
        elif self.distribution == "normal_mle_gene":
            # Learned per-feature scale; clamp the log-scale before exp so the
            # scale stays finite and strictly positive.
            sigma = torch.exp(torch.clamp(self.px_r, min=-10.0, max=10.0))  # (n_output,)
            scale = sigma.unsqueeze(0)  # (1, n_output)
        else:
            raise ValueError(f"Unsupported distribution: {self.distribution}")

        dist = torch.distributions.Normal(loc=reconstruction, scale=scale)
        log_prob = dist.log_prob(x)
        # Sum over features, average over the batch.
        return (-log_prob.sum(dim=1)).mean()

    def configure_optimizers(self):
        """Create Adam/AdamW and, if enabled, a ReduceLROnPlateau scheduler."""
        optimizer_cls = torch.optim.AdamW if self.use_adamw else torch.optim.Adam
        optimizer = optimizer_cls(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )

        if not self.reduce_lr_on_plateau:
            return optimizer

        scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                patience=self.lr_patience,
                factor=self.lr_factor,
                threshold=self.lr_threshold,
                min_lr=self.lr_min,
                threshold_mode="abs",
            ),
            # log_metrics logs "val/loss" with on_step and on_epoch both set;
            # the scheduler watches the epoch-level variant.
            "monitor": "val/loss_epoch",
            "interval": "epoch",
            "frequency": 1,
            "name": "lr_scheduler",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def log_metrics(self, loss_dict, stage: str = "train"):
        """Log every entry of ``loss_dict`` under the ``{stage}/`` prefix."""
        for key, value in loss_dict.items():
            self.log(f"{stage}/{key}", value, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)

    def on_save_checkpoint(self, checkpoint):
        """Save W&B run ID when checkpoint is saved.

        The first attached :class:`WandbLogger` with a live experiment supplies
        the run id, name and project. If none is attached but W&B metadata was
        restored from an earlier checkpoint, that metadata is written back
        unchanged so it survives further save/load cycles.
        """
        super().on_save_checkpoint(checkpoint)

        # Keep an explicit reference to the matching logger. Relying on the
        # loop variable after the loop would pick up an unrelated logger (or
        # be undefined when there are no loggers at all).
        wandb_logger = None
        for logger in self.trainer.loggers:
            if not isinstance(logger, WandbLogger):
                continue
            if getattr(logger, "experiment", None) is not None:
                wandb_logger = logger
                break

        if wandb_logger is not None:
            experiment = wandb_logger.experiment
            self.wandb_run_id = experiment.id
            self.wandb_name = getattr(wandb_logger, "name", None)
            self.wandb_project = getattr(experiment, "project", None)

        if self.wandb_run_id:
            checkpoint["wandb_run_id"] = self.wandb_run_id
            checkpoint["wandb_name"] = getattr(self, "wandb_name", None)
            checkpoint["wandb_project"] = getattr(self, "wandb_project", None)

    def on_load_checkpoint(self, checkpoint):
        """Restore W&B run metadata (``None`` for keys absent from the checkpoint)."""
        super().on_load_checkpoint(checkpoint)

        self.wandb_run_id = checkpoint.get("wandb_run_id")
        self.wandb_name = checkpoint.get("wandb_name")
        self.wandb_project = checkpoint.get("wandb_project")

class ReconTransformerDecoder(BaseReconstructionDecoder):
    """Recon-style wrapper mirroring :class:`ReconMLPDecoder`.

    - ``self.module`` is a :class:`~lightning.pytorch.LightningModule`.
    - ``train(datamodule, **trainer_kwargs)`` — fit on a datamodule.
    - ``decode(z_numpy, *cat_numpy) -> numpy`` — invert latent to expression.
    - ``save / load`` — checkpoint round-trip.
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_cat_list: Iterable[int] | None = None,
        # transformer config
        dim_model: int = 128,
        num_head: int = 8,
        dim_hid: int = 256,
        nlayers: int = 6,
        dropout: float = 0.1,
        # training/loss config
        distribution: str = "normal",
        learning_rate: float = 1e-4,
        reduce_lr_on_plateau: bool = False,
        lr_factor: float = 0.6,
        lr_patience: int = 5,
        lr_threshold: float = 1e-3,
        lr_min: float = 0.0,
        library_size_mode: str = "none",
        weight_decay: float = 0.0,
        use_adamw: bool = False,
        **kwargs,
    ):
        """Record the configuration and build the Lightning module.

        All named arguments are stored in ``self.model_params`` and forwarded
        to :class:`ReconTransformerModule`. Extra ``**kwargs`` are accepted
        and ignored.
        """
        self.model_params = {
            "n_input": n_input,
            "n_output": n_output,
            "n_cat_list": n_cat_list,
            "dim_model": dim_model,
            "num_head": num_head,
            "dim_hid": dim_hid,
            "nlayers": nlayers,
            "dropout": dropout,
            "distribution": distribution,
            "learning_rate": learning_rate,
            "reduce_lr_on_plateau": reduce_lr_on_plateau,
            "lr_factor": lr_factor,
            "lr_patience": lr_patience,
            "lr_threshold": lr_threshold,
            "lr_min": lr_min,
            "library_size_mode": library_size_mode,
            "weight_decay": weight_decay,
            "use_adamw": use_adamw,
        }
        print(
            "ReconTransformerDecoder params: "
            f"n_input={n_input}, n_output={n_output}, "
            f"dim_model={dim_model}, nlayers={nlayers}, num_head={num_head}"
        )
        self.module = ReconTransformerModule(**self.model_params)

    def decode(self, z: np.ndarray, *cat_list: np.ndarray, decode_batch_size: int = 256) -> np.ndarray:
        """Decode latents to reconstructions in mini-batches.

        Args:
            z: Latent array; sliced along axis 0 in chunks.
            *cat_list: Optional conditional arrays, sliced with the same
                row ranges as ``z``.
            decode_batch_size: Rows per forward pass; must be positive.

        Returns:
            float32 array with all chunk outputs concatenated along axis 0.
            An empty ``z`` yields an array of shape ``(0, n_output)``.

        Raises:
            ValueError: If ``decode_batch_size`` is not positive.
        """
        if decode_batch_size < 1:
            raise ValueError(
                f"decode_batch_size must be a positive integer, got {decode_batch_size}"
            )

        device = next(self.module.parameters()).device
        self.module.eval()
        # flash_attn requires fp16/bf16 regardless of stored weight dtype
        # (Lightning bf16-mixed trains with autocast but saves weights as fp32)
        autocast_ctx = torch.autocast(device_type=device.type, dtype=torch.bfloat16)

        parts = []
        with torch.no_grad(), autocast_ctx:
            for start in range(0, len(z), decode_batch_size):
                stop = start + decode_batch_size
                z_chunk = torch.from_numpy(z[start:stop]).float().to(device)
                cat_chunks = [
                    torch.from_numpy(cat[start:stop]).float().to(device)
                    for cat in cat_list
                ]
                out = self.module(z_chunk, *cat_chunks)
                # Cast back to fp32 on the host so the result is a plain
                # float32 numpy array whatever the autocast dtype was.
                parts.append(out.float().cpu())

        if not parts:
            # torch.cat rejects an empty list; return an empty result instead.
            return np.empty((0, self.model_params["n_output"]), dtype=np.float32)
        return torch.cat(parts, dim=0).numpy()

    def train(self, datamodule: L.LightningDataModule | None = None, **train_kwargs) -> None:
        """Fit ``self.module`` with a Lightning ``Trainer``.

        Recognised ``train_kwargs`` (all optional): ``max_epochs`` (400),
        ``logger``, ``callbacks``, ``precision`` (``"bf16-mixed"``),
        ``strategy`` (``"auto"``), ``devices`` (``"auto"``), ``num_nodes`` (1)
        and ``resume_from_checkpoint``. Anything else is passed straight to
        ``L.Trainer`` and overrides the defaults assembled here.

        With ``strategy="ddp"`` and an integer ``devices > 1``, per-split batch
        limits are derived from the datamodule's chunk layout and passed as
        ``limit_train_batches`` / ``limit_val_batches``.
        """
        max_epochs = train_kwargs.pop("max_epochs", 400)
        logger = train_kwargs.pop("logger", None)
        callbacks = train_kwargs.pop("callbacks", [])
        precision = train_kwargs.pop("precision", "bf16-mixed")
        strategy = train_kwargs.pop("strategy", "auto")
        devices = train_kwargs.pop("devices", "auto")
        num_nodes = train_kwargs.pop("num_nodes", 1)
        self.resume_from_checkpoint = train_kwargs.pop("resume_from_checkpoint", None)

        if precision == "bf16-mixed" and not torch.cuda.is_bf16_supported():
            print("Warning: bfloat16 not supported on this hardware, using fp16 mixed precision")
            precision = "16-mixed"

        plugins = []
        if strategy == "ddp":
            plugins.append(LightningEnvironment())

        default_trainer_kwargs = {
            "precision": precision,
            "max_epochs": max_epochs,
            "logger": logger,
            "callbacks": callbacks,
            "strategy": strategy,
            "devices": devices,
            "num_nodes": num_nodes,
            "plugins": plugins,
        }

        # Only an explicit integer device count enables the DDP limit logic.
        num_devices = devices if isinstance(devices, int) else 1
        if strategy == "ddp" and num_devices > 1 and datamodule is not None:
            datamodule.prepare_data()
            datamodule.setup(stage="fit")
            batch_size = getattr(datamodule, "minibatch_size", 256)
            chunks_per_worker = getattr(datamodule, "chunks_per_worker", 5)
            # num_workers=0 means loading happens in the main process; treat it
            # as a single worker so the arithmetic below cannot divide by zero.
            num_workers = max(getattr(datamodule, "num_workers", 1), 1)

            for split_name, data_attr, limit_key in [
                ("train", "train_data", "limit_train_batches"),
                ("val", "val_data", "limit_val_batches"),
            ]:
                data = getattr(datamodule, data_attr, None)
                if data is None:
                    continue
                n_samples = data.shape[0]
                n_chunks = data.numblocks[0]
                n_groups = math.ceil(n_chunks / chunks_per_worker)
                groups_per_rank = n_groups // num_devices
                # Round down to a multiple of the worker count.
                effective_groups = (groups_per_rank // num_workers) * num_workers
                effective_chunks = effective_groups * chunks_per_worker
                avg_chunk_size = n_samples / n_chunks
                effective_samples = effective_chunks * avg_chunk_size
                # Estimated batches per rank, with a 2% safety margin.
                limit = int(effective_samples / batch_size * 0.98)
                default_trainer_kwargs[limit_key] = limit
                print(
                    f"DDP: {limit_key}={limit} "
                    f"({split_name}: {n_chunks} chunks -> {n_groups} groups -> "
                    f"{groups_per_rank}/rank -> {effective_groups} effective/rank, "
                    f"~{int(effective_samples)} samples/rank, batch={batch_size}, *0.98)"
                )

        if self.resume_from_checkpoint:
            print(f"Trainer configured to resume from: {self.resume_from_checkpoint}")

        # Caller-supplied Trainer arguments win over the defaults above.
        default_trainer_kwargs.update(train_kwargs)
        trainer = L.Trainer(**default_trainer_kwargs)
        print(f"Starting training with precision: {precision}, strategy: {strategy}, devices: {devices}")
        trainer.fit(self.module, datamodule=datamodule, ckpt_path=self.resume_from_checkpoint)

    def save(self, path: str):
        """Write the module's ``state_dict`` to ``path``.

        ``.pt`` is appended unless the path already ends in ``.pt`` or
        ``.ckpt``. Missing parent directories are created.

        NOTE(review): a path ending in ``.ckpt`` still receives a bare
        ``state_dict``, whereas :meth:`load` treats ``.ckpt`` as a Lightning
        checkpoint — such a file is probably not loadable via ``load``; confirm.
        """
        path = os.fspath(path)
        if not path.endswith((".pt", ".ckpt")):
            path = path + ".pt"
        parent = os.path.dirname(path)
        if parent:
            # A bare filename has an empty dirname, and os.makedirs("") raises.
            os.makedirs(parent, exist_ok=True)
        # Save state_dict to match ReconMLPDecoder.save behaviour.
        torch.save(self.module.state_dict(), path)
        print(f"Model weights saved to {path}")

    def load(self, path: str, map_location=None) -> None:
        """Load weights into the existing module.

        - ``.pt``: state_dict, or a whole module if that's what was saved
        - ``.ckpt``: Lightning checkpoint

        Raises:
            TypeError: If a ``.pt`` file holds neither a dict nor an ``nn.Module``.
            ValueError: If the extension is neither ``.pt`` nor ``.ckpt``.
        """
        path = os.fspath(path)
        if path.endswith(".pt"):
            # SECURITY: weights_only=False unpickles arbitrary objects; only
            # load files from trusted sources.
            obj = torch.load(path, map_location=map_location, weights_only=False)
            if isinstance(obj, dict):
                missing, unexpected = self.module.load_state_dict(obj, strict=False)
                if missing or unexpected:
                    print(f"Warning: missing={missing}, unexpected={unexpected}")
            elif isinstance(obj, nn.Module):
                self.module = obj
            else:
                raise TypeError(f"Unsupported .pt contents: {type(obj)}")
        elif path.endswith(".ckpt"):
            self.module = ReconTransformerModule.load_from_checkpoint(path, map_location=map_location)
        else:
            raise ValueError(f"Unsupported checkpoint extension: {path}")
        print(f"Transformer decoder loaded from {path}")