Source code for sc_reconstruction.models.reconmlscvi

from __future__ import annotations

import logging
import os
import warnings
from typing import TYPE_CHECKING, Any, Dict, Literal

import numpy as np
import torch
from anndata import AnnData
from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager
from scvi.data._constants import ADATA_MINIFY_TYPE
from scvi.data._utils import _get_adata_minify_type
from scvi.data.fields import (
    CategoricalJointObsField,
    CategoricalObsField,
    LayerField,
    NumericalJointObsField,
    NumericalObsField,
)
from scvi.model import SCVI
from scvi.model._utils import _init_library_size
from scvi.model.base import (
    BaseMinifiedModeModelClass,
    EmbeddingMixin,
    UnsupervisedTrainingMixin,
)
from scvi.module import VAE
from scvi.module._constants import MODULE_KEYS
from scvi.utils import setup_anndata_dsp

from sc_reconstruction.models._base_model import BaseReconstructionModel





class ReconMLSCVI(SCVI, BaseReconstructionModel):
    """SCVI model that adheres to the ``BaseReconstructionModel`` interface.

    On top of ``SCVI`` this class provides ``prepare``, ``predict`` (random or
    deterministic reconstruction from a raw NumPy matrix) and a simplified
    ``save``/``load`` pair that serialises only ``self.module``.
    """

    def __init__(
        self,
        adata: AnnData | None = None,
        registry: dict | None = None,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
        dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
        gene_likelihood: Literal["zinb", "nb", "poisson", "normal"] = "zinb",
        use_observed_lib_size: bool = True,
        latent_distribution: Literal["normal", "ln"] = "normal",
        library_log_means: np.ndarray | None = None,
        library_log_vars: np.ndarray | None = None,
        **kwargs,
    ):
        """Build the model and, unless deferred to ``train``, its module.

        Args:
            adata: AnnData object handed to the base-class initialiser.
            registry: Registry dict handed to the base-class initialiser
                (used instead of ``adata`` for a custom datamodule).
            n_hidden, n_latent, n_layers, dropout_rate, dispersion,
            gene_likelihood, use_observed_lib_size, latent_distribution:
                Forwarded unchanged to the module constructor.
            library_log_means: Library-size prior means given to the module.
            library_log_vars: Library-size prior variances given to the module.
                If either of the two library arguments is ``None``, BOTH are
                replaced by the defaults ``[[0.0]]`` / ``[[1.0]]`` and a
                ``UserWarning`` is emitted.
            **kwargs: Extra keyword arguments forwarded to the module.
        """
        # HACK: renames the class object itself (not only this instance).
        # NOTE(review): presumably done so scvi treats this model as a plain
        # "SCVI" model (e.g. in registry / save-load name checks) - confirm.
        self.__class__.__name__ = "SCVI"
        # ``SCVI.__init__`` is bypassed on purpose: the module is constructed
        # below with custom library-size parameters.
        BaseMinifiedModeModelClass.__init__(self, adata, registry)

        self._module_kwargs = {
            "n_hidden": n_hidden,
            "n_latent": n_latent,
            "n_layers": n_layers,
            "dropout_rate": dropout_rate,
            "dispersion": dispersion,
            "gene_likelihood": gene_likelihood,
            "latent_distribution": latent_distribution,
            **kwargs,
        }
        self._model_summary_string = (
            "SCVI model with the following parameters: \n"
            f"n_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, "
            f"dropout_rate: {dropout_rate}, dispersion: {dispersion}, "
            f"gene_likelihood: {gene_likelihood}, latent_distribution: {latent_distribution}."
        )

        self._custom_library_log_means = library_log_means
        self._custom_library_log_vars = library_log_vars
        if self._custom_library_log_means is None or self._custom_library_log_vars is None:
            warnings.warn(
                "Library size parameters not provided. Will take default values [0,1]. Unreliabe if not for inference only.",
                UserWarning,
                stacklevel=settings.warnings_stacklevel,
            )
            # NOTE(review): a (1, 1) default presumably only matches a
            # single-batch setup - confirm before training with several batches.
            self._custom_library_log_means = np.array(0.0).reshape(1, 1)
            self._custom_library_log_vars = np.array(1.0).reshape(1, 1)

        if self._module_init_on_train:
            self.module = None
            warnings.warn(
                "Model was initialized without `adata`. The module will be initialized when "
                "calling `train`. This behavior is experimental and may change in the future.",
                UserWarning,
                stacklevel=settings.warnings_stacklevel,
            )
        else:
            n_cats_per_cov = self._resolve_n_cats_per_cov(adata)
            n_batch = self.summary_stats.n_batch
            use_size_factor_key = self.registry_["setup_args"][
                f"{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key"
            ]

            # The custom parameters are never None at this point (defaults were
            # filled in above), so they are always the ones handed to the
            # module. Rebinding the argument names is intentional: the values
            # are captured by ``locals()`` below.
            library_log_means = self._custom_library_log_means
            library_log_vars = self._custom_library_log_vars

            self.module = self._module_cls(
                n_input=self.summary_stats.n_vars,
                n_batch=n_batch,
                n_labels=self.summary_stats.n_labels,
                n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0),
                n_cats_per_cov=n_cats_per_cov,
                n_hidden=n_hidden,
                n_latent=n_latent,
                n_layers=n_layers,
                dropout_rate=dropout_rate,
                dispersion=dispersion,
                gene_likelihood=gene_likelihood,
                use_observed_lib_size=use_observed_lib_size,
                latent_distribution=latent_distribution,
                use_size_factor_key=use_size_factor_key,
                library_log_means=library_log_means,
                library_log_vars=library_log_vars,
                **kwargs,
            )
            self.module.minified_data_type = self.minified_data_type

        # Must stay in ``__init__``: it snapshots this frame's local variables.
        self.init_params_ = self._get_init_params(locals())

    def _resolve_n_cats_per_cov(self, adata: AnnData | None):
        """Return the per-covariate category counts, or ``None`` if there are none."""
        if adata is not None:
            if REGISTRY_KEYS.CAT_COVS_KEY not in self.adata_manager.data_registry:
                return None
            return self.adata_manager.get_state_registry(
                REGISTRY_KEYS.CAT_COVS_KEY
            ).n_cats_per_key

        # Custom datamodule: no AnnData manager, read the raw registry instead.
        state_registry = self.registry["field_registries"][
            f"{REGISTRY_KEYS.CAT_COVS_KEY}"
        ]["state_registry"]
        if len(state_registry) > 0:
            return tuple(state_registry["n_cats_per_key"])
        return None

    def prepare(self, adata: Any, **kwargs) -> None:
        """Register ``adata`` via ``setup_anndata`` and cache the reference.

        Args:
            adata: Object passed to ``setup_anndata`` and stored as ``self.adata``.
            **kwargs: Forwarded to ``setup_anndata``.
        """
        self.setup_anndata(adata, **kwargs)
        self.adata = adata

    def train(self, **train_kwargs) -> None:
        """Drop the unused ``save_path`` kwarg and delegate to the parent ``train``."""
        train_kwargs.pop('save_path', None)
        super().train(**train_kwargs)

    def _module_device(self) -> torch.device:
        """Return the device of the module's parameters (CPU if it has none)."""
        first_param = next(self.module.parameters(), None)
        return torch.device("cpu") if first_param is None else first_param.device

    def _prepare_batch(self, X: np.ndarray) -> dict:
        """Convert a raw NumPy array into the mini-batch dict for ``self.module.forward``.

        The data is cast to float32 and all tensors are placed on the module's
        device, so float64 input or a module living on an accelerator does not
        break the forward pass. Batch and label indices are all zero.

        Args:
            X: Array whose first axis indexes observations.

        Returns:
            Dict keyed by ``REGISTRY_KEYS.X_KEY``, ``BATCH_KEY`` and ``LABELS_KEY``.
        """
        device = self._module_device()
        # ``ascontiguousarray`` also copes with views that ``from_numpy`` rejects.
        counts = torch.from_numpy(np.ascontiguousarray(X, dtype=np.float32))
        n_obs = counts.shape[0]
        return {
            REGISTRY_KEYS.X_KEY: counts.to(device),
            REGISTRY_KEYS.BATCH_KEY: torch.zeros(
                (n_obs, 1), dtype=torch.int64, device=device
            ),
            REGISTRY_KEYS.LABELS_KEY: torch.zeros(
                (n_obs, 1), dtype=torch.int64, device=device
            ),
        }

    @staticmethod
    def _px_center(px) -> torch.Tensor:
        """Return ``px.loc`` when the distribution has one, else ``px.mean``.

        NOTE(review): the fallback exists because count likelihoods
        (e.g. "zinb", the default here) presumably expose no ``loc`` - confirm
        against the scvi distribution classes.
        """
        loc = getattr(px, "loc", None)
        return px.mean if loc is None else loc

    def _predict_batch(self, mini_batch: dict, inference_kwargs: dict | None) -> torch.Tensor:
        """Run one forward pass in eval mode without gradients.

        Args:
            mini_batch: Tensors dict as produced by ``_prepare_batch``.
            inference_kwargs: Passed through to ``self.module.forward``.

        Returns:
            The centre (``loc`` or mean) of the generative distribution ``px``.
        """
        self.module.eval()
        with torch.no_grad():
            _, generative_outputs = self.module.forward(
                tensors=mini_batch,
                inference_kwargs=inference_kwargs,
                compute_loss=False,
            )
        return self._px_center(generative_outputs["px"])

    def _predict_batch_deterministic(
        self, mini_batch: dict, inference_kwargs: dict | None
    ) -> torch.Tensor:
        """Reconstruct from the posterior means instead of sampled latents.

        Runs the forward pass once to obtain the inference outputs, then
        re-runs the generative step with ``z`` (and the library size, when a
        library posterior is available) replaced by the posterior ``loc``.

        Args:
            mini_batch: Tensors dict as produced by ``_prepare_batch``.
            inference_kwargs: Passed through to ``self.module.forward``.

        Returns:
            The centre (``loc`` or mean) of the deterministic ``px``.
        """
        self.module.eval()
        # Everything stays under ``no_grad`` so the result can be converted
        # with ``.numpy()`` by the callers.
        with torch.no_grad():
            inference_outputs, _ = self.module.forward(
                tensors=mini_batch,
                inference_kwargs=inference_kwargs,
                compute_loss=False,
            )
            generative_inputs = self.module._get_generative_input(
                mini_batch, inference_outputs
            )
            generative_inputs[MODULE_KEYS.Z_KEY] = inference_outputs["qz"].loc

            # The library posterior may be absent (presumably when the observed
            # library size is used - confirm); then the library value coming
            # from the inference step is kept as is.
            ql = inference_outputs.get("ql")
            if ql is not None:
                generative_inputs[MODULE_KEYS.LIBRARY_KEY] = ql.loc

            deterministic_outputs = self.module.generative(**generative_inputs)
        return self._px_center(deterministic_outputs["px"])

    def _predict_deterministic(
        self, X: np.ndarray, inference_kwargs: dict | None = None
    ) -> np.ndarray:
        """Deterministic prediction for scvi-style models.

        Args:
            X: Raw input data.
            inference_kwargs: Optional inference parameters; ``None`` is passed
                through to the module unchanged.

        Returns:
            The reconstructed output as a NumPy array.
        """
        mini_batch = self._prepare_batch(X)
        pred_tensor = self._predict_batch_deterministic(mini_batch, inference_kwargs)
        return pred_tensor.cpu().numpy()

    def _predict_random(
        self, X: np.ndarray, inference_kwargs: dict | None = None
    ) -> np.ndarray:
        """Random (sampled-latent) prediction for scvi-style models.

        Args:
            X: Raw input data.
            inference_kwargs: Optional inference parameters; ``None`` is passed
                through to the module unchanged.

        Returns:
            The predicted/reconstructed output as a NumPy array.
        """
        mini_batch = self._prepare_batch(X)
        pred_tensor = self._predict_batch(mini_batch, inference_kwargs)
        return pred_tensor.cpu().numpy()

    def predict(
        self,
        X: np.ndarray,
        inference_kwargs: dict | None = None,
        pred_type: str = "random",
    ) -> np.ndarray:
        """Reconstruct ``X`` with the trained module.

        Args:
            X: Raw input data.
            inference_kwargs: Optional inference parameters; ``None`` is passed
                through to the module unchanged.
            pred_type: ``"random"`` (latents sampled by the forward pass) or
                ``"deterministic"`` (posterior means).

        Returns:
            The reconstructed output as a NumPy array.

        Raises:
            ValueError: If ``pred_type`` is neither of the two supported values.
        """
        if pred_type == "random":
            return self._predict_random(X, inference_kwargs)
        if pred_type == "deterministic":
            return self._predict_deterministic(X, inference_kwargs)
        raise ValueError(f"Invalid prediction type: {pred_type}")

    def save(self, path: str, **kwargs) -> None:
        """Save the model module to ``path`` with ``torch.save``.

        Args:
            path: Destination path; ``.pt`` is appended unless the path already
                ends in ``.pt`` or ``.ckpt``. Path-like objects are accepted.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        path = os.fspath(path)
        if not path.endswith(('.pt', '.ckpt')):
            path = path + '.pt'

        # A bare filename has an empty dirname, which ``os.makedirs`` rejects.
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        torch.save(self.module, path)
        print(f"Model saved to {path}")

    def load(self, path: str, map_location=None) -> None:
        """Replace ``self.module`` with the module stored at ``path``.

        Args:
            path: File previously written by ``save``.
            map_location: Forwarded to ``torch.load``.
        """
        # SECURITY: ``weights_only=False`` unpickles arbitrary objects; only
        # load files from trusted sources.
        self.module = torch.load(path, map_location=map_location, weights_only=False)
        print(f"Model loaded from {path}")