# Source code for sc_reconstruction.models.reconscvi

from __future__ import annotations

from typing import Any, Dict
import torch
import numpy as np
from anndata import AnnData

from scvi.model import SCVI
from scvi.model.base import UnsupervisedTrainingMixin
from scvi import REGISTRY_KEYS

import os
from sc_reconstruction.models._base_model import BaseReconstructionModel
from scvi.module._constants import MODULE_KEYS


class ReconSCVI(SCVI, BaseReconstructionModel):
    """A structured SCVI model for reconstruction that adheres to the base class interface.

    Wraps :class:`scvi.model.SCVI` and adds array-in/array-out helpers
    (:meth:`predict`, :meth:`get_latent_representation`) plus whole-module
    persistence via :meth:`save` / :meth:`load`.
    """

    def __init__(
        self,
        adata: AnnData | None = None,
        registry: Dict | None = None,
        **kwargs,
    ):
        # NOTE(review): this renames the class object itself (ReconSCVI.__name__
        # becomes "SCVI" for every instance, not just this one), presumably so
        # scvi-tools code keyed on the class name treats this model as a plain
        # SCVI. It runs before ``super().__init__``; confirm it is still needed.
        self.__class__.__name__ = "SCVI"
        super().__init__(adata=adata, registry=registry, **kwargs)

    def prepare(self, adata: Any, **kwargs) -> None:
        """Register adata with the SCVI data manager and cache the reference.

        Args:
            adata: AnnData object to register.
            kwargs: Forwarded to ``setup_anndata``.
        """
        self.setup_anndata(adata, **kwargs)
        self.adata = adata

    def train(self, **train_kwargs) -> None:
        """Drop the unused ``save_path`` kwarg and delegate to ``SCVI.train``."""
        train_kwargs.pop("save_path", None)
        super().train(**train_kwargs)

    def _input_dtype(self) -> torch.dtype:
        """Return the dtype of the module's first floating parameter (float32 if none)."""
        for param in self.module.parameters():
            if param.is_floating_point():
                return param.dtype
        return torch.float32

    def _prepare_batch(self, X: np.ndarray) -> dict:
        """Convert a raw matrix into the mini-batch dict consumed by ``self.module.forward``.

        Args:
            X: Observation-by-feature matrix; presumably (n_cells, n_genes)
                matching the registered AnnData -- confirm against callers.
                Objects exposing ``toarray`` (e.g. scipy sparse matrices) are
                densified first.

        Returns:
            Dict holding the expression tensor, cast to the module's parameter
            dtype, plus all-zero batch and label indices of shape ``(n_obs, 1)``.
        """
        if hasattr(X, "toarray"):
            X = X.toarray()
        # Cast to the module's dtype: float64 NumPy input would otherwise reach
        # float32 layers and fail with a dtype mismatch.
        x_tensor = torch.as_tensor(np.asarray(X), dtype=self._input_dtype())
        n_obs = x_tensor.shape[0]
        # Every observation is assigned batch 0 and label 0.
        return {
            REGISTRY_KEYS.X_KEY: x_tensor,
            REGISTRY_KEYS.BATCH_KEY: torch.zeros((n_obs, 1), dtype=torch.int64),
            REGISTRY_KEYS.LABELS_KEY: torch.zeros((n_obs, 1), dtype=torch.int64),
        }

    def _forward(self, mini_batch: dict, inference_kwargs: dict | None):
        """Run ``self.module.forward`` in eval mode without tracking gradients.

        Returns:
            The ``(inference_outputs, generative_outputs)`` pair from the module.
        """
        self.module.eval()
        with torch.no_grad():
            inference_outputs, generative_outputs = self.module.forward(
                tensors=mini_batch,
                inference_kwargs=inference_kwargs,
                compute_loss=False,
            )
        return inference_outputs, generative_outputs

    def _predict_batch(self, mini_batch: dict, inference_kwargs: dict | None) -> torch.Tensor:
        """Run a forward pass (latent sampled by the module) and return ``px.loc``."""
        _, generative_outputs = self._forward(mini_batch, inference_kwargs)
        # NOTE(review): ``.loc`` is assumed to exist on the configured ``px``
        # distribution (true for Normal-family likelihoods) -- confirm.
        return generative_outputs["px"].loc

    def _predict_batch_deterministic(
        self, mini_batch: dict, inference_kwargs: dict | None
    ) -> torch.Tensor:
        """Decode from posterior locations instead of samples and return ``px.loc``.

        The latent ``z`` is replaced by ``qz.loc``. When the module produced a
        library posterior ``ql``, the library is replaced by ``ql.loc``.
        """
        self.module.eval()
        # The second generative pass must also run without grad; otherwise its
        # output requires grad and the caller's ``.numpy()`` fails.
        with torch.no_grad():
            inference_outputs, _ = self.module.forward(
                tensors=mini_batch,
                inference_kwargs=inference_kwargs,
                compute_loss=False,
            )
            generative_inputs = self.module._get_generative_input(
                mini_batch, inference_outputs
            )
            generative_inputs[MODULE_KEYS.Z_KEY] = inference_outputs["qz"].loc
            ql = inference_outputs.get("ql")
            # ``ql`` is presumably None when the module uses the observed library
            # size (scvi's default); only a learned library posterior is replaced
            # by its location. Reading ``ql.loc`` unconditionally crashed on None.
            if ql is not None:
                generative_inputs[MODULE_KEYS.LIBRARY_KEY] = ql.loc
            generative_outputs = self.module.generative(**generative_inputs)
        return generative_outputs["px"].loc

    def _predict_deterministic(
        self, X: np.ndarray, inference_kwargs: dict | None = None
    ) -> np.ndarray:
        """Deterministic prediction for scvi-style models (decodes posterior locations)."""
        mini_batch = self._prepare_batch(X)
        pred_tensor = self._predict_batch_deterministic(mini_batch, inference_kwargs)
        return pred_tensor.cpu().numpy()

    def _predict_random(
        self, X: np.ndarray, inference_kwargs: dict | None = None
    ) -> np.ndarray:
        """Random prediction for scvi-style models (latent sampled by the module).

        Args:
            X: Raw input matrix (see ``_prepare_batch``).
            inference_kwargs: Extra keyword arguments forwarded to the module's
                inference step through ``self.module.forward``. ``None`` forwards
                nothing.

        Returns:
            The predicted/reconstructed output as a NumPy array.
        """
        mini_batch = self._prepare_batch(X)
        pred_tensor = self._predict_batch(mini_batch, inference_kwargs)
        return pred_tensor.cpu().numpy()

    def predict(
        self, X: np.ndarray, inference_kwargs: dict | None = None, pred_type="random"
    ) -> np.ndarray:
        """Reconstruct ``X`` with the trained module.

        Args:
            X: Raw input matrix (see ``_prepare_batch``).
            inference_kwargs: Extra keyword arguments forwarded to the module's
                inference step through ``self.module.forward``. ``None`` forwards
                nothing.
            pred_type: ``"random"`` (latent sampled by the module) or
                ``"deterministic"`` (decode from posterior locations).

        Returns:
            The reconstructed output (``px.loc``) as a NumPy array.

        Raises:
            ValueError: If ``pred_type`` is not one of the two supported values.
        """
        if pred_type == "random":
            return self._predict_random(X, inference_kwargs)
        if pred_type == "deterministic":
            return self._predict_deterministic(X, inference_kwargs)
        raise ValueError(f"Invalid prediction type: {pred_type}")

    def get_latent_representation(self, X: np.ndarray, **inference_kwargs) -> np.ndarray:
        """Return the location (``qz.loc``) of the latent posterior for each row of ``X``.

        Args:
            X: Raw input matrix (see ``_prepare_batch``).
            inference_kwargs: Forwarded to the module's inference step.
        """
        inference_outputs, _ = self._forward(self._prepare_batch(X), inference_kwargs)
        return inference_outputs["qz"].loc.cpu().numpy()

    def save(self, path: str, **kwargs) -> None:
        """Pickle the whole ``self.module`` to ``path`` with ``torch.save``.

        Args:
            path: Destination file. ``.pt`` is appended unless the path already
                ends in ``.pt`` or ``.ckpt``. Missing parent directories are
                created.
            kwargs: Accepted for call-compatibility but currently ignored.
        """
        path = os.fspath(path)
        if not path.endswith((".pt", ".ckpt")):
            path = path + ".pt"
        parent_dir = os.path.dirname(path)
        # A bare filename has no directory part, and os.makedirs("") raises.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        torch.save(self.module, path)
        print(f"Model saved to {path}")

    def load(self, path: str, map_location=None) -> None:
        """Replace ``self.module`` with a module previously written by ``torch.save``.

        Args:
            path: File to load. Unlike ``save``, no extension is appended.
            map_location: Forwarded to ``torch.load`` to remap storage devices.
        """
        # SECURITY: ``weights_only=False`` unpickles arbitrary objects and can
        # execute code. Only load files from trusted sources.
        self.module = torch.load(path, map_location=map_location, weights_only=False)
        print(f"Model loaded from {path}")