from __future__ import annotations
import lightning as L
import torch.nn as nn
import torch
import numpy as np
import warnings
from collections.abc import Iterable
import os
class BaseDecoder(nn.Module):
    """Fully connected decoder: hidden blocks followed by a linear output head.

    Each hidden block is ``Linear -> BatchNorm1d -> ReLU -> Dropout``; the
    head is a single ``Linear`` layer producing ``output_dim`` features.

    Args:
        n_latent: Number of input features fed to the first layer.
        n_hidden: Widths of the hidden layers, in order. If empty, the
            decoder consists of the linear head only.
        output_dim: Number of output features.
        output_activation: Module applied to the head output when
            ``forward`` is called without ``library_size``.
        dropout_rate: Dropout probability used in every hidden block.
            The default ``0.0`` makes the dropout layers a no-op.
    """

    def __init__(self,
                 n_latent: int,
                 n_hidden: list,
                 output_dim: int,
                 output_activation: nn.Module,
                 dropout_rate: float = 0.0):
        super().__init__()
        layers = []
        prev_dim = n_latent
        for hidden_dim in n_hidden:
            # The Dropout layer is always inserted (even with p == 0) so the
            # positions of the Linear/BatchNorm layers inside the Sequential,
            # and therefore their state-dict keys, do not depend on the rate.
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ])
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, output_dim))
        self.decoder = nn.Sequential(*layers)
        self.output_activation = output_activation
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, library_size=None):
        """Decode ``x``.

        Args:
            x: Input passed through the layer stack.
            library_size: Optional scaling factor. When given, the head
                output is normalised with a softmax over the last dimension
                and multiplied by ``library_size`` (which must broadcast
                against the head output); ``output_activation`` is skipped.

        Returns:
            ``softmax(head) * library_size`` if ``library_size`` is not
            ``None``, otherwise ``output_activation(head)``.
        """
        reconstruction = self.decoder(x)
        if library_size is not None:
            reconstruction = self.softmax(reconstruction) * library_size
        else:
            reconstruction = self.output_activation(reconstruction)
        return reconstruction
class MLPDecoder(L.LightningModule):
    """Lightning module that trains a ``BaseDecoder`` to map ``z`` to ``x``.

    Batches are read as mappings with keys ``'x'`` (target), ``'z'`` (decoder
    input) and, optionally, ``'batch_onehot'`` (extra input concatenated to
    ``z``). The reconstruction loss is selected by ``distribution``.
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_hidden: list,
        n_cat_list: Iterable[int] = None,
        distribution: str = 'normal',
        learning_rate: float = 0.001,
        reduce_lr_on_plateau: bool = False,
        lr_factor: float = 0.6,
        lr_patience: int = 5,
        lr_threshold: float = 1e-3,
        lr_min: float = 0.0,
        library_size_mode: str = 'none',
        decoder_output_activation: str = None,
        **kwargs,
    ):
        """Build the decoder network and store the training configuration.

        Args:
            n_input: Width of ``z``.
            n_output: Width of the reconstruction.
            n_hidden: Hidden layer widths passed to ``BaseDecoder``.
            n_cat_list: Optional widths of the extra inputs; their sum is
                added to ``n_input`` to size the first layer.
            distribution: Loss selector, see ``compute_loss``.
            learning_rate: Adam learning rate.
            reduce_lr_on_plateau: Whether to attach a ``ReduceLROnPlateau``
                scheduler in ``configure_optimizers``.
            lr_factor: Scheduler ``factor``.
            lr_patience: Scheduler ``patience``.
            lr_threshold: Scheduler ``threshold`` (absolute mode).
            lr_min: Scheduler ``min_lr``.
            library_size_mode: Stored; any value other than ``'none'`` only
                triggers a warning in ``forward``.
            decoder_output_activation: Name resolved by
                ``_get_activation_fn``.
            **kwargs: Accepted but not used by this constructor body.
        """
        super().__init__()
        self.save_hyperparameters()

        # Architecture description.
        self.n_input = n_input
        self.n_output = n_output
        self.n_hidden = n_hidden
        self.n_cat_list = n_cat_list
        self.library_size_mode = library_size_mode
        self.decoder_output_activation = decoder_output_activation

        # Loss and optimisation settings.
        self.distribution = distribution
        self.learning_rate = learning_rate
        self.reduce_lr_on_plateau = reduce_lr_on_plateau
        self.lr_factor = lr_factor
        self.lr_patience = lr_patience
        self.lr_threshold = lr_threshold
        self.lr_min = lr_min

        print(f"Initializing MLPDecoder with n_input={n_input}, n_output={n_output}, n_hidden={n_hidden}")

        # The first layer sees z plus every extra (categorical) input.
        extra_width = 0 if n_cat_list is None else sum(n_cat_list)
        head_activation = self._get_activation_fn(decoder_output_activation)
        self.decoder = BaseDecoder(
            n_latent=n_input + extra_width,
            n_hidden=n_hidden,
            output_dim=n_output,
            output_activation=head_activation,
        )

        # Learnable per-output log-scale, only needed by 'normal_mle_gene'.
        if self.distribution == 'normal_mle_gene':
            log_scale_init = torch.zeros(n_output).normal_(mean=0.0, std=0.1)
            self.px_r = nn.Parameter(log_scale_init)

    def _get_activation_fn(self, name: str | None) -> nn.Module:
        """Return a new activation module for ``name`` (case-insensitive).

        ``None``, ``'linear'`` and ``'identity'`` all give ``nn.Identity``.

        Raises:
            ValueError: If ``name`` is not one of the supported names.
        """
        if name is None:
            return nn.Identity()

        name = name.lower()
        builders = {
            'linear': nn.Identity,
            'identity': nn.Identity,
            'softmax': lambda: nn.Softmax(dim=-1),
            'relu': nn.ReLU,
            'sigmoid': nn.Sigmoid,
            'tanh': nn.Tanh,
            'softplus': nn.Softplus,
        }
        build = builders.get(name)
        if build is None:
            raise ValueError(f"Unsupported activation: {name}")
        return build()

    def forward(self, z: torch.Tensor, *cat_list: torch.Tensor) -> torch.Tensor:
        """Decode ``z``, concatenating any ``cat_list`` tensors along dim 1 first."""
        decoder_input = torch.cat((z, *cat_list), dim=1) if cat_list else z
        reconstruction = self.decoder(decoder_input)
        if self.library_size_mode != 'none':
            warnings.warn("library_size_mode is set but not implemented for decoding only. Ignoring library size mode.")
        return reconstruction

    def _shared_step(self, batch, stage='train'):
        """Run one forward pass on ``batch``, log the loss under ``stage`` and return it."""
        x = batch['x']
        z = batch['z']
        # A missing or None 'batch_onehot' entry means unconditional decoding.
        onehot = batch['batch_onehot'] if 'batch_onehot' in batch else None
        conditions = () if onehot is None else (onehot,)
        reconstruction = self.forward(z, *conditions)
        loss = self.compute_loss(x, reconstruction)
        self.log_metrics({'loss': loss}, stage)
        return loss

    def training_step(self, batch, batch_idx):
        """Log the first param group's learning rate, then compute the training loss."""
        current_lr = self.optimizers().param_groups[0]['lr']
        self.log('lr', current_lr, on_step=True, on_epoch=True, prog_bar=True)
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss."""
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss."""
        return self._shared_step(batch, 'test')

    def compute_loss(self, x, reconstruction):
        """Return the reconstruction loss selected by ``self.distribution``.

        Supported values:
            ``'normal'``: mean squared error.
            ``'huber'``: Huber loss.
            ``'l1'``: mean absolute error.
            ``'normal_mle_fixed'``: Gaussian negative log-likelihood with a
                fixed scale of 0.1.
            ``'normal_mle_gene'``: Gaussian negative log-likelihood with the
                scale ``exp(clamp(px_r, -10, 10))``.

        For both likelihood variants the log-probabilities are summed over
        dim 1 and averaged over the remaining entries.

        NOTE(review): the earlier docstring said this matches
        ``Autoencoder.compute_loss`` in reconae.py; that file is not visible
        here, so confirm before relying on it.

        Raises:
            ValueError: If ``self.distribution`` is not supported.
        """
        kind = self.distribution
        if kind == 'normal':
            return nn.MSELoss()(reconstruction, x)
        if kind == 'huber':
            return nn.HuberLoss()(reconstruction, x)
        if kind == 'l1':
            return nn.L1Loss()(reconstruction, x)

        if kind == 'normal_mle_fixed':
            scale = torch.tensor(1e-1, device=x.device)
        elif kind == 'normal_mle_gene':
            # Clamp the log-scale before exponentiating to keep sigma finite;
            # the leading axis lets it broadcast over the batch.
            scale = torch.exp(torch.clamp(self.px_r, min=-10.0, max=10.0)).unsqueeze(0)
        else:
            raise ValueError(f"Unsupported distribution: {self.distribution}")

        gaussian = torch.distributions.Normal(loc=reconstruction, scale=scale)
        return -gaussian.log_prob(x).sum(dim=1).mean()

    def configure_optimizers(self):
        """Return an Adam optimizer, optionally paired with ``ReduceLROnPlateau``.

        Without ``reduce_lr_on_plateau`` the bare optimizer is returned.
        Otherwise a dict with the optimizer and a scheduler configuration
        monitoring ``'val/loss_epoch'`` once per epoch is returned
        (presumably the epoch aggregate of the ``'val/loss'`` metric logged
        by ``log_metrics`` — confirm against Lightning's naming rules).
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        if not self.reduce_lr_on_plateau:
            return optimizer

        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=self.lr_patience,
            factor=self.lr_factor,
            threshold=self.lr_threshold,
            min_lr=self.lr_min,
            threshold_mode="abs",
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                'scheduler': plateau,
                'monitor': 'val/loss_epoch',
                'interval': 'epoch',
                'frequency': 1,
                'name': 'lr_scheduler',
            },
        }

    def log_metrics(self, loss_dict, stage='train'):
        """Log every entry of ``loss_dict`` as ``'<stage>/<key>'`` per step and per epoch."""
        for metric_name, metric_value in loss_dict.items():
            self.log(f'{stage}/{metric_name}', metric_value, on_step=True, on_epoch=True, prog_bar=True)
from sc_reconstruction.decoders._base_decoder import BaseReconstructionDecoder
# (stray "[docs]" link text left over from the rendered documentation page removed)
class ReconMLPDecoder(BaseReconstructionDecoder):
    """NumPy-facing wrapper that builds, trains, applies and persists an ``MLPDecoder``.

    The wrapped Lightning module is available as ``self.module`` and the
    keyword arguments used to build it as ``self.model_params``.
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_cat_list: Iterable[int] = None,
        n_layers: int = 1,
        n_hidden: int = 128,
        distribution: str = 'normal',
        learning_rate: float = 0.001,
        reduce_lr_on_plateau: bool = False,
        lr_factor: float = 0.6,
        lr_patience: int = 5,
        lr_threshold: float = 1e-3,
        lr_min: float = 0.0,
        library_size_mode: str = 'none',
        decoder_output_activation: str = None,
        **kwargs
    ):
        """Create the underlying ``MLPDecoder``.

        Args:
            n_input: Width of the latent input.
            n_output: Width of the reconstruction.
            n_cat_list: Optional widths of extra inputs, forwarded as is.
            n_layers: Number of hidden layers.
            n_hidden: Width shared by every hidden layer.
            distribution: Loss selector forwarded to ``MLPDecoder``.
            learning_rate: Forwarded to ``MLPDecoder``.
            reduce_lr_on_plateau: Forwarded to ``MLPDecoder``.
            lr_factor: Forwarded to ``MLPDecoder``.
            lr_patience: Forwarded to ``MLPDecoder``.
            lr_threshold: Forwarded to ``MLPDecoder``.
            lr_min: Forwarded to ``MLPDecoder``.
            library_size_mode: Forwarded to ``MLPDecoder``.
            decoder_output_activation: Forwarded to ``MLPDecoder``.
            **kwargs: Accepted but not used.
        """
        # NOTE(review): the base-class __init__ is not invoked here; its
        # signature is not visible from this file, so this is left as is —
        # confirm BaseReconstructionDecoder needs no initialisation.

        # MLPDecoder expects an explicit list of hidden widths.
        hidden_dims = [n_hidden] * n_layers
        self.model_params = {
            "n_input": n_input,
            "n_output": n_output,
            "n_hidden": hidden_dims,
            "n_cat_list": n_cat_list,
            "distribution": distribution,
            "learning_rate": learning_rate,
            "reduce_lr_on_plateau": reduce_lr_on_plateau,
            "lr_factor": lr_factor,
            "lr_patience": lr_patience,
            "lr_threshold": lr_threshold,
            "lr_min": lr_min,
            "library_size_mode": library_size_mode,
            "decoder_output_activation": decoder_output_activation,
        }
        print(f"ReconMLPDecoder params: n_input={n_input}, n_output={n_output}, n_hidden={hidden_dims}")
        self.module = MLPDecoder(**self.model_params)

    def train(self, datamodule: L.LightningDataModule = None, **train_kwargs) -> None:
        """Fit the wrapped module with a new ``L.Trainer(**train_kwargs)``."""
        trainer = L.Trainer(**train_kwargs)
        trainer.fit(self.module, datamodule=datamodule)

    def decode(self, z: np.ndarray, *cat_list: np.ndarray) -> np.ndarray:
        """Decode latent representations to data space.

        Inputs are converted to float32 tensors on the module's device and
        the module is switched to eval mode (it is not switched back).

        Args:
            z: Latent array.
            *cat_list: Optional extra arrays passed to the module after ``z``.

        Returns:
            The reconstruction as a NumPy array on the CPU.
        """
        device = next(self.module.parameters()).device
        self.module.eval()
        with torch.no_grad():
            z_tensor = torch.from_numpy(z).float().to(device)
            cat_tensors = [torch.from_numpy(cat).float().to(device) for cat in cat_list]
            # Unpacking an empty list is the same as calling module(z_tensor).
            reconstruction = self.module(z_tensor, *cat_tensors)
        return reconstruction.cpu().numpy()

    def save(self, path: str):
        """Write the module's ``state_dict`` to ``path`` with ``torch.save``.

        ``'.pt'`` is appended unless ``path`` already ends in ``'.pt'`` or
        ``'.ckpt'``. Missing parent directories are created.
        """
        if not path.endswith(('.pt', '.ckpt')):
            path = path + '.pt'
        # os.path.dirname() is '' for a bare filename, and os.makedirs('')
        # raises, so only create the directory when there is one.
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(self.module.state_dict(), path)
        print(f"Model weights saved to {path}")

    def load(self, path: str, map_location=None) -> None:
        """Load the model weights.

        Mirrors ``save``: ``'.pt'`` is appended when ``path`` has neither a
        ``'.pt'`` nor a ``'.ckpt'`` suffix.

        * ``'.ckpt'``: ``self.module`` is replaced by
          ``MLPDecoder.load_from_checkpoint(path)``.
        * ``'.pt'``: the file is read with ``torch.load``. A state dict (what
          ``save`` writes) is loaded into the existing ``self.module``; a
          pickled ``nn.Module`` replaces ``self.module``.

        NOTE(review): ``save`` writes a bare state dict even for a
        ``'.ckpt'`` path; such a file is presumably not readable by
        ``load_from_checkpoint`` — confirm and prefer ``'.pt'`` for files
        produced by ``save``.

        Args:
            path: File to read.
            map_location: Forwarded to the underlying loader.
        """
        if not path.endswith(('.pt', '.ckpt')):
            path = path + '.pt'

        if path.endswith('.ckpt'):
            self.module = MLPDecoder.load_from_checkpoint(path, map_location=map_location)
        else:
            # SECURITY: weights_only=False unpickles arbitrary objects; it is
            # kept so files containing a whole pickled module still load.
            # Only load files from trusted sources.
            loaded = torch.load(path, map_location=map_location, weights_only=False)
            if isinstance(loaded, nn.Module):
                self.module = loaded
            else:
                # save() stores a state dict; assigning it to self.module
                # would replace the network with a plain mapping.
                self.module.load_state_dict(loaded)
        print(f"MLP decoder loaded from {path}")