Source code for sc_reconstruction.models.reconpca

from typing import Any, Dict
import numpy as np
import zarr
import os
import anndata as ad
from sc_reconstruction.dataloaders.datamodules import DaskPCADataModule
from sc_reconstruction.models._base_model import BaseReconstructionModel

def set_mem():
    """Switch the current process to RMM managed memory for CuPy allocations.

    Re-initialises RMM with ``managed_memory=True`` and then installs the RMM
    allocator as CuPy's CUDA allocator. The imports are local so that merely
    importing this module does not require ``rmm`` or ``cupy``.
    """
    import rmm
    import cupy
    from rmm.allocators.cupy import rmm_cupy_allocator as rmm_allocator

    # RMM must be re-initialised before CuPy is pointed at its allocator.
    rmm.reinitialize(managed_memory=True)
    cupy.cuda.set_allocator(rmm_allocator)






class ReconPCA(BaseReconstructionModel):
    """PCA reconstruction model with a scalable GPU implementation.

    Adapted from `rapids-singlecell <https://github.com/scverse/rapids_singlecell>`_
    so the fit scales to 100M-cell datasets via a ``dask_cuda.LocalCUDACluster``
    and chunked SVD over the input zarr store.

    State held on the instance:
        n_components: number of principal components requested / loaded.
        mean: per-feature mean set by ``train`` (when ``save_mean`` is true)
            or by ``load``; ``None`` otherwise.
        pc: principal-component loadings set by ``train`` or ``load``.
        cluster, client: Dask CUDA cluster and its client, created by
            ``train`` and released by ``_teardown_cluster``.
    """

    def __init__(self, n_components: int = 300):
        """Create an unfitted model.

        Parameters:
            n_components: int - Number of principal components to compute.
        """
        super().__init__()
        self.n_components = n_components
        self.mean = None
        self.pc = None
        self.cluster = None
        self.client = None

    def _setup_cluster(
        self,
        gpu_ids: str = "0",
        temp_dir: str = "/lustre/groups/ml01/workspace/xiaotong.fu/reconstruction/temp",
    ):
        """Start a ``LocalCUDACluster`` plus client and configure worker memory.

        Parameters:
            gpu_ids: str - Passed to the cluster as ``CUDA_VISIBLE_DEVICES``.
            temp_dir: str - Passed to the cluster as ``local_directory``.
        """
        from dask_cuda import LocalCUDACluster
        from dask.distributed import Client

        # A second train() call used to overwrite the handles and leak the
        # previous cluster; release it before starting a new one.
        if self.client is not None or self.cluster is not None:
            self._teardown_cluster()

        self.cluster = LocalCUDACluster(
            local_directory=temp_dir,
            CUDA_VISIBLE_DEVICES=gpu_ids,
        )
        self.client = Client(self.cluster)
        # Run the allocator setup on every worker of the cluster.
        self.client.run(set_mem)

    def prepare(self, data_path: str, batch_size: int = 1_000_000, **kwargs) -> ad.AnnData:
        """Prepare the data for training.

        * not used in the current version

        Parameters:
            data_path: str - The path to the training data, should be a h5ad
                file with X as a csr matrix.
            batch_size: int - The batch size; used as the row chunk size of
                the lazily loaded ``X``.

        Returns:
            ad.AnnData: AnnData whose ``X`` is read lazily as a dask array.
        """
        # h5py is not imported at module level; import it here so this method
        # does not fail with a NameError.
        import h5py

        with h5py.File(data_path, "r") as f:
            adata = ad.AnnData(
                X=ad.experimental.read_elem_as_dask(
                    f["X"], chunks=(batch_size, -1)
                )
            )
        return adata

    def train(
        self,
        datamodule: DaskPCADataModule,
        gpu_ids: str = "0",
        save_mean: bool = True,
        save_path: str = None,
        **kwargs,
    ) -> None:
        """Train the PCA model.

        Parameters:
            datamodule: DaskPCADataModule - Its ``train_dataloader()`` must
                return an AnnData whose ``X`` is a dask array of csr matrices.
            gpu_ids: str - The GPU IDs to use.
            save_mean: bool - Whether to compute and keep the feature mean.
            save_path: str - If given, directory passed to ``save`` after the
                fit.
        """
        import rapids_singlecell as rsc

        self._setup_cluster(gpu_ids)
        print("set up cluster")

        datamodule.prepare_data()
        adata = datamodule.train_dataloader()
        print("adata:", adata.X)

        # have to be done before saving mean
        # TODO: how does this affect the save mean?
        rsc.get.anndata_to_GPU(adata)

        if save_mean:
            print("saving mean")
            # Column sums per block, then summed across blocks and divided by
            # the number of rows to obtain the per-feature mean.
            block_sums = adata.X.map_blocks(
                lambda b: b.sum(axis=0), dtype=adata.X.dtype
            )
            total_sum = block_sums.sum(axis=0).compute()
            self.mean = total_sum / adata.X.shape[0]
            # .get() presumably copies the device array back to host memory.
            self.mean = np.ascontiguousarray(self.mean.get())
        else:
            print("not saving mean")

        print("running PCA")
        rsc.pp.pca(
            adata,
            n_comps=self.n_components,
            key_added="X_pca",
            mask_var=None,
        )
        self.pc = np.ascontiguousarray(adata.varm["X_pca"])
        # ndarray.device only exists on NumPy >= 2.0; fall back to "cpu" so a
        # finished fit is not lost to an AttributeError in a log statement.
        print("pc:", self.pc.shape, "device:", getattr(self.pc, "device", "cpu"))

        if save_path:
            self.save(save_path)

    def predict(self, X: np.ndarray, **inference_kwargs) -> np.ndarray:
        """Reconstruct a batch by projecting onto the PCs and back.

        Takes a raw NumPy array (batch of data), prepares it, runs the
        reconstruction, and returns a NumPy array.

        Parameters:
            X: np.ndarray - Raw input data.

        Returns:
            np.ndarray: The predicted/reconstructed output.

        Raises:
            RuntimeError: If the mean or the principal components are not
                available (model not trained/loaded, or trained with
                ``save_mean=False``).
        """
        if self.pc is None or self.mean is None:
            raise RuntimeError(
                "ReconPCA.predict requires both `mean` and `pc`; train the "
                "model with save_mean=True or call load() first."
            )
        latent = (X - self.mean) @ self.pc
        norm_recon = np.dot(latent, self.pc.T)
        pred = norm_recon + self.mean
        return pred

    def save(self, dir_path: str) -> None:
        """Write the mean and PCs (whichever are set) as zarr stores.

        Parameters:
            dir_path: str - Target directory; created if missing. The mean is
                written to ``mean.zarr`` and the loadings to
                ``pc_<n_components>.zarr``.
        """
        print("saving mean and pc")
        os.makedirs(dir_path, exist_ok=True)
        if self.mean is not None:
            zarr.save(f"{dir_path}/mean.zarr", self.mean)
            print("saved mean")
        if self.pc is not None:
            zarr.save(f"{dir_path}/pc_{self.n_components}.zarr", self.pc)
            print("saved pc")
        print(f"Model saved to {dir_path}")

    def load(self, path: list[str], map_location=None) -> None:
        """Load the model.

        Parameters:
            path: list[str] - The path to the model file.
                path[0]: str - The path to the mean zarr file.
                path[1]: str - The path to the pc zarr file.
            map_location: Unused by this method.

        Raises:
            FileNotFoundError: If no principal components could be loaded
                from ``path[1]``.
        """
        self.mean = zarr.load(path[0])
        self.pc = zarr.load(path[1])
        if self.pc is None:
            # Fail with a clear message instead of an AttributeError on
            # ``None.shape`` below.
            raise FileNotFoundError(
                f"No principal components found at {path[1]}"
            )
        self.n_components = self.pc.shape[1]
        print(f"Model loaded from {path}")

    def _teardown_cluster(self):
        """Close the Dask client and cluster if present and forget them."""
        # getattr: this also runs from __del__, possibly on an instance whose
        # __init__ did not finish and never set these attributes.
        client = getattr(self, "client", None)
        cluster = getattr(self, "cluster", None)
        # Drop the references first so a repeated call is a no-op.
        self.client = None
        self.cluster = None
        try:
            if client:
                client.close()
        finally:
            # Close the cluster even if closing the client raised.
            if cluster:
                cluster.close()
        print("Cluster resources released")

    def __del__(self):
        """Release cluster resources when the model is garbage-collected."""
        self._teardown_cluster()