# Source code for sc_reconstruction.adapters.state_decoder_adapter
"""Frozen MLP decoder adapter for State Transition evaluation.
Wraps a pre-trained MLPDecoder checkpoint so that it can be called as a
standard nn.Module in the inference / evaluation scripts. All parameters
are frozen (requires_grad=False) and the module is kept in eval mode.
"""
import torch
import torch.nn as nn
[docs]
class FrozenMLPDecoderAdapter(nn.Module):
    """Adapter that loads a frozen MLPDecoder from a Lightning checkpoint.

    The wrapped decoder has every parameter set to ``requires_grad=False``
    and is put in eval mode; :meth:`train` is overridden so the adapter
    stays in eval mode even when a parent module calls ``.train()``.

    Parameters
    ----------
    checkpoint_path : str
        Path to the ``.ckpt`` file produced by Lightning's ``ModelCheckpoint``.
    map_location : str or torch.device, optional
        Device to map weights onto (default: cpu).
    """

    def __init__(self, checkpoint_path: str, map_location="cpu"):
        """Load the decoder from ``checkpoint_path`` and freeze it."""
        super().__init__()
        # Local import: the decoder class is only needed once an adapter is
        # actually constructed.
        from sc_reconstruction.decoders.reconmlp import MLPDecoder

        # weights_only=False: PyTorch 2.6 changed the default to True, which blocks
        # omegaconf types stored in Lightning checkpoint hyperparameters.
        # These are our own trusted checkpoints.
        self.decoder = MLPDecoder.load_from_checkpoint(
            checkpoint_path, map_location=map_location, weights_only=False
        )
        for p in self.decoder.parameters():
            p.requires_grad = False
        self.decoder.eval()

        # Cache the decoder's declared dimensions for the read-only accessors.
        self._n_input = self.decoder.n_input
        self._n_output = self.decoder.n_output

    def train(self, mode: bool = True):
        """Ignore ``mode`` and keep the adapter (and decoder) in eval mode.

        Keeps eval mode even if ``.train()`` is called on a parent module.
        Returns ``self``, like ``nn.Module.train``.
        """
        return super().train(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode embeddings to gene expression.

        Accepts arbitrary leading batch dimensions, e.g. ``(B, S, D)``
        from the ST model's set-level output. The leading dimensions are
        flattened before the decoder call and restored afterwards, so the
        result has shape ``(*x.shape[:-1], out_dim)``.

        Inputs with a zero-sized leading dimension (e.g. an empty batch)
        are supported and yield an empty output of the matching shape.
        """
        shape = x.shape
        flat = x.reshape(-1, shape[-1])
        out = self.decoder(flat)
        # Use the explicit feature width rather than -1: with a zero-sized
        # leading dimension, reshape cannot infer -1 and raises RuntimeError.
        # NOTE(review): assumes the decoder returns a 2-D (N, n_output)
        # tensor -- confirm against MLPDecoder.forward.
        return out.reshape(*shape[:-1], out.shape[-1])

    @property
    def n_input(self) -> int:
        """Input dimension as reported by the wrapped decoder's ``n_input``."""
        return self._n_input

    @property
    def n_output(self) -> int:
        """Output dimension as reported by the wrapped decoder's ``n_output``."""
        return self._n_output

    def gene_dim(self) -> int:
        """Return the decoder output dimension (same value as ``n_output``).

        This is a regular method, not a property: call it as ``gene_dim()``.
        """
        return self._n_output