Source code for sc_reconstruction.models.reconae

from __future__ import annotations

import lightning as L
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch
import warnings

'''
Autoencoder model for data reconstruction:
- Autoencoder class: Main model class inheriting from LightningModule.
- ReconAE class: Wrapper for reconstruction tasks inheriting from BaseReconstructionModel.
'''



from scvi.distributions import (
            NegativeBinomial
        )
        
class BaseEncoder(nn.Module):
    """Fully connected encoder mapping ``input_dim`` features to ``n_latent``.

    For every width in ``n_hidden`` the network stacks
    ``Linear -> BatchNorm1d -> ReLU -> Dropout(0.1)``; a final ``Linear``
    layer projects onto the latent dimension. With an empty ``n_hidden`` the
    encoder is a single linear layer.
    """

    def __init__(self, input_dim: int, n_hidden: list, n_latent: int):
        super().__init__()
        # Consecutive entries of ``widths`` give (in_features, out_features)
        # for each hidden group.
        widths = [input_dim, *n_hidden]
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
                nn.Dropout(0.1),
            ]
        # Output projection: no normalisation, activation or dropout.
        stack.append(nn.Linear(widths[-1], n_latent))
        self.encoder = nn.Sequential(*stack)

    def forward(self, x):
        """Return the latent code produced by the layer stack for ``x``."""
        return self.encoder(x)
        
class BaseDecoder(nn.Module):
    """Fully connected decoder mapping ``n_latent`` codes to ``output_dim``.

    For every width in ``n_hidden`` the network stacks
    ``Linear -> BatchNorm1d -> ReLU -> Dropout(0)``, followed by a final
    ``Linear`` layer. The raw output is then post-processed in one of two
    ways, chosen per call in :meth:`forward`.
    """

    def __init__(self,
                 n_latent: int,
                 n_hidden: list,
                 output_dim: int,
                 output_activation: nn.Module):
        super().__init__()
        # Consecutive entries of ``widths`` give (in_features, out_features)
        # for each hidden group.
        widths = [n_latent, *n_hidden]
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
                nn.Dropout(0),
            ]
        stack.append(nn.Linear(widths[-1], output_dim))
        self.decoder = nn.Sequential(*stack)
        self.output_activation = output_activation
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, library_size=None):
        """Decode ``x``.

        Without ``library_size`` the raw output is passed through
        ``output_activation``. With ``library_size`` the activation is
        bypassed: the raw output is softmax-normalised over the last
        dimension and multiplied by ``library_size``.
        """
        raw = self.decoder(x)
        if library_size is None:
            return self.output_activation(raw)
        return self.softmax(raw) * library_size

class Autoencoder(L.LightningModule):
    """Autoencoder built from a ``BaseEncoder`` and a ``BaseDecoder``.

    ``distribution`` selects the reconstruction loss used by
    :meth:`compute_loss`: ``'normal'`` (MSE), ``'huber'``, ``'l1'``,
    ``'normal_mle_fixed'`` (Normal likelihood, scale 0.1),
    ``'normal_mle_gene'`` (Normal likelihood, one learned scale per feature)
    or ``'nb_gene'`` (negative binomial, one learned dispersion per feature).

    ``library_size_mode`` selects how the decoder output is scaled:

    - ``"none"``: the decoder applies its output activation.
    - ``"observed"``: softmax output multiplied by the row sums of the input.
    - ``"modeled"``: softmax output multiplied by ``exp`` of the output of a
      second, one-hidden-layer encoder.
    """

    def __init__(
        self,
        input_dim: int,
        n_hidden: list,
        n_latent: int,
        distribution: str = 'normal',
        learning_rate: float = 0.001,
        reduce_lr_on_plateau: bool = False,
        lr_factor: float = 0.6,
        lr_patience: int = 5,
        lr_threshold: float = 1e-3,
        lr_min: float = 0.0,
        library_size_mode: str = "none", # "none", "observed", or "modeled"
        decoder_output_activation: str | nn.Module | None = None,
        **trainer_kwargs
    ):
        super().__init__()

        if library_size_mode not in ("none", "observed", "modeled"):
            raise ValueError("library_size_mode must be 'none', 'observed', or 'modeled'")

        # Kept inside __init__ itself so Lightning records this constructor's arguments.
        self.save_hyperparameters()

        self.input_dim = input_dim
        self.n_hidden = n_hidden
        self.n_latent = n_latent
        self.distribution = distribution
        self.learning_rate = learning_rate
        self.reduce_lr_on_plateau = reduce_lr_on_plateau
        self.lr_factor = lr_factor
        self.lr_threshold = lr_threshold
        self.lr_min = lr_min
        self.lr_patience = lr_patience
        self.library_size_mode = library_size_mode

        # Resolve the decoder output activation: a ready-made module is used
        # as is, a missing one defaults to softplus for 'nb*' distributions
        # (with a warning), anything else is looked up by name.
        if isinstance(decoder_output_activation, nn.Module):
            activation = decoder_output_activation
        elif decoder_output_activation is None and 'nb' in self.distribution:
            warnings.warn("For 'nb' distribution with no specified decoder output activation, using 'softplus' activation.")
            activation = self._get_activation_fn('softplus')
        else:
            activation = self._get_activation_fn(decoder_output_activation)
        self.output_activation = activation

        self.encoder = BaseEncoder(self.input_dim, self.n_hidden, self.n_latent)

        if self.library_size_mode == "modeled":
            # Separate small encoder predicting a single log library size.
            self.l_encoder = BaseEncoder(self.input_dim, [self.n_hidden[0]], 1)

        # Pick the forward implementation once instead of branching per call.
        self.forward_fn = {
            "none": self._forward_none,
            "observed": self._forward_observed,
            "modeled": self._forward_modeled,
        }[library_size_mode]

        # The decoder mirrors the encoder's hidden widths.
        self.decoder = BaseDecoder(
            self.n_latent,
            self.n_hidden[::-1],
            self.input_dim,
            self.output_activation,
        )

        if self.distribution in ('normal_mle_gene', 'nb_gene'):
            # One learnable value per input feature, exponentiated in
            # compute_loss; initialised from N(0, 0.1).
            self.px_r = nn.Parameter(
                torch.zeros(self.input_dim).normal_(mean=0.0, std=0.1)
            )

    def _get_activation_fn(self, name: str | None) -> nn.Module:
        """Return a new activation module for ``name`` (case-insensitive).

        ``None``, ``'linear'`` and ``'identity'`` all give ``nn.Identity``.
        Raises ``ValueError`` for any other unrecognised name.
        """
        if name is None:
            return nn.Identity()
        key = name.lower()
        builders = {
            'linear': nn.Identity,
            'identity': nn.Identity,
            'softmax': lambda: nn.Softmax(dim=-1),
            'relu': nn.ReLU,
            'sigmoid': nn.Sigmoid,
            'softplus': nn.Softplus,
        }
        if key not in builders:
            raise ValueError(f"Unsupported activation: {key}")
        return builders[key]()

    def forward(self, x):
        """Reconstruct ``x`` using the forward path chosen at construction."""
        return self.forward_fn(x)

    def _forward_none(self, x):
        """Encode then decode, with the decoder's output activation applied."""
        return self.decoder(self.encoder(x))

    def _forward_observed(self, x):
        """Encode then decode, scaling by the per-row sum of ``x``."""
        latent = self.encoder(x)
        row_totals = x.sum(dim=1).unsqueeze(1)
        return self.decoder(latent, row_totals)

    def _forward_modeled(self, x):
        """Encode then decode, scaling by ``exp`` of the ``l_encoder`` output."""
        latent = self.encoder(x)
        predicted_size = torch.exp(self.l_encoder(x))
        return self.decoder(latent, predicted_size)

    def encode(self, x):
        """Return the encoder output for ``x``."""
        return self.encoder(x)

    def _shared_step(self, batch, stage='train'):
        """Reconstruct ``batch['X']``, log the loss under ``stage``, return it."""
        inputs = batch['X']
        outputs = self(inputs)
        loss = self.compute_loss(inputs, outputs)
        self.log_metrics({'loss': loss}, stage)
        return loss

    def training_step(self, batch, batch_idx):
        """Log the first parameter group's learning rate, then run the step."""
        current_lr = self.optimizers().param_groups[0]['lr']
        self.log('lr', current_lr, on_step=True, on_epoch=True, prog_bar=True)
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Run the shared step with metrics logged under ``val/``."""
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Run the shared step with metrics logged under ``test/``."""
        return self._shared_step(batch, 'test')

    def compute_loss(self, x, reconstruction):
        """Return the scalar reconstruction loss selected by ``distribution``.

        The point-wise losses (``'normal'``, ``'huber'``, ``'l1'``) use the
        default mean reduction. The likelihood losses sum the negative
        log-probability over dimension 1 and average over the batch.
        Raises ``ValueError`` for an unsupported ``distribution``.
        """
        kind = self.distribution

        if kind == 'normal':
            return F.mse_loss(reconstruction, x)
        if kind == 'huber':
            return F.huber_loss(reconstruction, x)
        if kind == 'l1':
            return F.l1_loss(reconstruction, x)

        if kind == 'normal_mle_fixed':
            fixed_scale = torch.tensor(1e-1, device=x.device)
            likelihood = torch.distributions.Normal(loc=reconstruction, scale=fixed_scale)
        elif kind == 'normal_mle_gene':
            # Clamp the log-scale before exponentiating; the leading axis
            # broadcasts the per-feature scale over the batch.
            scale = torch.exp(torch.clamp(self.px_r, min=-10.0, max=10.0)).unsqueeze(0)
            likelihood = torch.distributions.Normal(loc=reconstruction, scale=scale)
        elif kind == 'nb_gene':
            # Same clamping/broadcasting, here for the NB dispersion.
            dispersion = torch.exp(torch.clamp(self.px_r, min=-10.0, max=10.0)).unsqueeze(0)
            likelihood = NegativeBinomial(mu=reconstruction, theta=dispersion)
        else:
            raise ValueError(f"Unsupported distribution: {self.distribution}")

        return -likelihood.log_prob(x).sum(dim=1).mean()

    def configure_optimizers(self):
        """Return Adam, optionally paired with a ``ReduceLROnPlateau`` scheduler."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        if not self.reduce_lr_on_plateau:
            return optimizer

        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=self.lr_patience,
            factor=self.lr_factor,
            threshold=self.lr_threshold,
            min_lr=self.lr_min,
            threshold_mode="abs"
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                'scheduler': plateau,
                # NOTE(review): presumably the epoch-level aggregate of the
                # 'val/loss' metric logged by log_metrics - confirm against
                # Lightning's metric naming.
                'monitor': 'val/loss_epoch',
                'interval': 'epoch',
                'frequency': 1,
                'name': 'lr_scheduler'
            },
        }

    def log_metrics(self, loss_dict, stage='train'):
        """Log every entry of ``loss_dict`` as ``'<stage>/<key>'``."""
        for metric_name, metric_value in loss_dict.items():
            self.log(f'{stage}/{metric_name}', metric_value, on_step=True, on_epoch=True, prog_bar=True)


        

from sc_reconstruction.models._base_model import BaseReconstructionModel
from anndata import AnnData
import os

class ReconAE(BaseReconstructionModel):
    """Wrapper around :class:`Autoencoder` for reconstruction tasks.

    Builds the underlying Lightning module from the constructor arguments
    (kept in ``self.model_params``) and exposes training, inference and
    persistence helpers operating on NumPy arrays.
    """

    def __init__(
        self,
        input_dim: int,
        n_hidden: list,
        n_latent: int,
        distribution: str = 'normal',  # 'normal' for mse, 'nb_gene' for negative binomial with gene-wise dispersion
        library_size_mode: str = "none",  # "none", "observed", or "modeled"
        learning_rate: float = 0.001,
        reduce_lr_on_plateau: bool = False,
        lr_factor: float = 0.6,
        lr_patience: int = 5,
        lr_threshold: float = 1e-3,
        lr_min: float = 0.0,
        decoder_output_activation: str | nn.Module | None = None,
    ):
        """Store the configuration and instantiate the ``Autoencoder``."""
        self.model_params = {
            'input_dim': input_dim,
            'n_hidden': n_hidden,
            'n_latent': n_latent,
            'distribution': distribution,
            'library_size_mode': library_size_mode,
            'learning_rate': learning_rate,
            'reduce_lr_on_plateau': reduce_lr_on_plateau,
            'lr_factor': lr_factor,
            'lr_patience': lr_patience,
            'lr_threshold': lr_threshold,
            'lr_min': lr_min,
            'decoder_output_activation': decoder_output_activation,
        }
        self.module = Autoencoder(**self.model_params)

    def prepare(self, adata: AnnData | None = None, **kwargs):
        """Remember ``adata`` on the instance when one is given."""
        if adata is not None:
            self.adata = adata

    def train(self, datamodule: L.LightningDataModule | None = None, **train_kwargs):
        """Fit the module with a new ``L.Trainer`` built from ``train_kwargs``."""
        trainer = L.Trainer(**train_kwargs)
        trainer.fit(self.module, datamodule=datamodule)

    def _as_model_tensor(self, X: np.ndarray) -> torch.Tensor:
        """Convert ``X`` to a tensor on the module's device and in its dtype.

        Matching the parameter dtype lets float64 (NumPy's default) or
        integer arrays be used directly instead of failing with a dtype
        mismatch inside the first linear layer.
        """
        reference = next(self.module.parameters())
        return torch.from_numpy(X).to(device=reference.device, dtype=reference.dtype)

    def get_latent_representation(self, X: np.ndarray) -> np.ndarray:
        """Return the encoder output for ``X`` as a NumPy array."""
        self.module.eval()
        with torch.no_grad():
            z = self.module.encode(self._as_model_tensor(X))
        return z.cpu().numpy()

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return the reconstruction of ``X`` as a NumPy array."""
        # Processes the whole array in a single forward pass (not optimal
        # for large inputs).
        self.module.eval()
        with torch.no_grad():
            reconstruction = self.module(self._as_model_tensor(X))
        return reconstruction.cpu().numpy()

    def predict_relu(self, X: np.ndarray) -> np.ndarray:
        """Return the reconstruction of ``X`` with negative values set to 0."""
        # Processes the whole array in a single forward pass (not optimal
        # for large inputs).
        self.module.eval()
        with torch.no_grad():
            reconstruction = torch.relu(self.module(self._as_model_tensor(X)))
        return reconstruction.cpu().numpy()

    def save(self, path: str):
        """Pickle the whole module to ``path`` with ``torch.save``.

        ``'.pt'`` is appended unless ``path`` already ends in ``'.pt'`` or
        ``'.ckpt'``. Missing parent directories are created.
        """
        if not path.endswith(('.pt', '.ckpt')):
            path = path + '.pt'
        parent = os.path.dirname(path)
        # A bare filename has an empty dirname, and os.makedirs('') raises.
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save(self.module, path)
        print(f"Model checkpoint saved to {path}")

    def load(self, path: str, map_location=None) -> None:
        """Load the model from ``path`` into ``self.module``.

        Accepts both a module pickled by :meth:`save` and a Lightning
        checkpoint; the latter is delegated to
        ``Autoencoder.load_from_checkpoint``. ``path`` is used verbatim, so
        include the ``'.pt'`` suffix that :meth:`save` may have appended.
        """
        # SECURITY: full unpickling executes arbitrary code from the file.
        # It is required to restore a module written by save(); only load
        # files from trusted sources.
        loaded = torch.load(path, map_location=map_location, weights_only=False)
        if isinstance(loaded, Autoencoder):
            self.module = loaded
        else:
            self.module = Autoencoder.load_from_checkpoint(path, map_location=map_location)
        print(f"Model loaded from {path}")