from torch.utils.data import DataLoader, Dataset
from lightning import LightningDataModule
import dask.array as da
import os
import warnings
import numpy as np
from scipy.sparse import csr_matrix, vstack
from sc_reconstruction.dataloaders.datasets import *
import zarr
from typing import Optional, Dict, Any
def transform_to_csr(path, chunk_size=1_000_000):
    """Convert a dense Zarr matrix into a CSR-backed ``.h5ad`` file.

    The Zarr array at ``path`` is opened lazily with Dask and re-chunked so
    that each block holds ``chunk_size`` rows and all columns. Blocks are
    materialised one at a time and converted to ``scipy.sparse.csr_matrix``.
    They are then stacked into a single CSR matrix, stored as ``X`` of an
    ``AnnData`` object, and written to disk.

    Parameters
    ----------
    path : str
        Path of the Zarr array to convert.
    chunk_size : int, optional
        Number of rows materialised (densely) per block.

    Returns
    -------
    str
        Path of the written ``.h5ad`` file. This is ``path`` with a trailing
        ``.zarr`` suffix replaced by ``.h5ad``, or ``path + '.h5ad'`` when
        there is no such suffix.
    """
    # Local import keeps this function self-contained; the file's own
    # anndata import sits further down the module.
    import anndata as ad

    # Derive the output from the suffix only. str.replace() would also rewrite
    # '.zarr' inside parent directory names, and would return the input store
    # itself when the path contains no '.zarr' at all.
    stem = path.rstrip(os.sep)
    if stem.endswith(".zarr"):
        stem = stem[: -len(".zarr")]
    out_path = stem + ".h5ad"

    train_data = da.from_zarr(path)
    # chunk_size rows per block; -1 keeps the full column dimension in one block.
    train_data = train_data.rechunk((chunk_size, -1))

    # Densify and sparsify one row-block at a time to bound peak memory.
    blocks = [
        csr_matrix(train_data.blocks[i, :].compute())
        for i in range(train_data.numblocks[0])
    ]
    global_csr = vstack(blocks, format="csr")

    adata = ad.AnnData(X=global_csr)
    adata.write_h5ad(out_path)
    return out_path
class IterDaskDataModule(LightningDataModule):
    '''
    Iterable datamodule that pairs with ``DaskSCVIIterableDataset``; can be
    used for standard scVI training.

    Expression matrices are opened lazily from Zarr with Dask and re-chunked
    along rows to ``chunk_size`` rows per block. Each split is then wrapped in
    a ``DaskSCVIIterableDataset`` that receives ``batch_size=minibatch_size``.
    Batching is therefore done by the dataset, and every ``DataLoader`` here
    uses ``batch_size=None``.

    Parameters
    ----------
    train_path : str
        Zarr store with the training matrix. Required: ``n_vars`` is read from
        it at construction time.
    val_path, test_path : str, optional
        Zarr stores for validation / test; skipped when ``None``.
    chunk_size : int
        Rows per Dask block.
    minibatch_size : int
        Forwarded to the dataset as ``batch_size``.
    num_workers : int
        Number of ``DataLoader`` worker processes.
    max_workers_percentage : float
        Stored as an attribute; not used by this class.
    chunks_per_worker : int
        Forwarded to the dataset.
    random_seed : int
        Forwarded to the dataset as ``seed``.

    Raises
    ------
    ValueError
        If the training matrix cannot be loaded (so ``n_vars`` is unknown).
    '''

    def __init__(
            self,
            train_path,
            val_path=None,
            test_path=None,
            chunk_size=100_000,
            minibatch_size=4096,
            num_workers=10,
            max_workers_percentage=0.8,
            chunks_per_worker=5,
            random_seed=42):
        super().__init__()
        self.train_path = train_path
        self.val_path = val_path
        self.test_path = test_path
        self.chunk_size = chunk_size
        self.minibatch_size = minibatch_size
        self.max_workers_percentage = max_workers_percentage
        self.num_workers = num_workers
        self.n_vars = None
        self.chunks_per_worker = chunks_per_worker
        # no batch effect
        self.n_batch = 1
        self.random_seed = random_seed
        # Filled by prepare_data(). They are initialised here so that setup()
        # can detect a process on which prepare_data() did not run (under
        # DDP it runs on a single process only).
        self.train_data = None
        self.val_data = None
        self.test_data = None

        if not os.path.isdir(self.train_path):
            warnings.warn(f"Train directory not found: {self.train_path}")
        # val/test are optional. os.path.isdir(None) can raise TypeError, so
        # only check paths that were actually provided.
        if self.val_path is not None and not os.path.isdir(self.val_path):
            warnings.warn(f"Validation directory not found: {self.val_path}")
        if self.test_path is not None and not os.path.isdir(self.test_path):
            warnings.warn(f"Test directory not found: {self.test_path}")

        # The feature count is needed before setup() runs, so open the train
        # matrix lazily here just to read its shape.
        self.X = self._load_zarr_data(self.train_path, description="Prep data")
        if self.X is None:
            raise ValueError(
                f"Could not load training data from {self.train_path}; "
                "n_vars cannot be determined"
            )
        self.n_vars = self.X.shape[1]

    @staticmethod
    def _load_zarr_data(path, description="Non-specified data"):
        '''
        Open the Zarr array at ``path`` lazily as a Dask array.

        The store root is tried first, then its ``'X'`` component. Returns
        ``None`` (after emitting a warning) when ``path`` is not a directory
        or both attempts fail.
        '''
        if not os.path.isdir(path):
            warnings.warn(f"{description} doesn't exist: {path}")
            return None
        try:
            data = da.from_zarr(path)
            print(f"{description} loading directly from: {path}")
            return data
        except Exception:
            # The root is not readable as an array; try the 'X' component.
            try:
                data = da.from_zarr(path, 'X')
                print(f"{description} loading from 'X': {path}")
                return data
            except Exception as e2:
                warnings.warn(f"Failing loading {description}: {path}. Error: {e2}")
                return None

    def prepare_data(self):
        '''
        Load data from zarr files using Dask (lazily; nothing is computed).
        '''
        self.train_data = self._load_zarr_data(self.train_path, description="Train data")
        self.val_data = self._load_zarr_data(self.val_path, description="Val data") if self.val_path is not None else None
        self.test_data = self._load_zarr_data(self.test_path, description="Test data") if self.test_path is not None else None

    def setup(self, stage=None):
        '''
        Rechunk every loaded split and build one iterable dataset per split.

        Only the training dataset is chunk-shuffled. Splits that were not
        loaded are skipped.
        '''
        # prepare_data() runs on a single process under DDP; load here on
        # processes where it did not run.
        if self.train_data is None:
            self.prepare_data()

        # chunk_size rows per block; -1 keeps all columns in one block so each
        # chunk holds complete rows.
        if self.train_data is not None:
            self.train_data = self.train_data.rechunk((self.chunk_size, -1))
        if self.val_data is not None:
            self.val_data = self.val_data.rechunk((self.chunk_size, -1))
        if self.test_data is not None:
            self.test_data = self.test_data.rechunk((self.chunk_size, -1))

        for label, data in (("Train", self.train_data),
                            ("Val", self.val_data),
                            ("Test", self.test_data)):
            if data is not None:
                print(f"{label} data: shape={data.shape}, chunks={data.numblocks}")

        if self.train_data is not None:
            self.train_dataset = self._make_dataset(self.train_data, chunk_shuffle=True)
        if self.val_data is not None:
            self.val_dataset = self._make_dataset(self.val_data, chunk_shuffle=False)
        if self.test_data is not None:
            self.test_dataset = self._make_dataset(self.test_data, chunk_shuffle=False)

    def _make_dataset(self, data, chunk_shuffle):
        '''Wrap ``data`` in a DaskSCVIIterableDataset using this module's settings.'''
        return DaskSCVIIterableDataset(
            data,
            chunk_shuffle=chunk_shuffle,
            batch_size=self.minibatch_size,
            seed=self.random_seed,
            chunks_per_worker=self.chunks_per_worker,
        )

    def _make_loader(self, dataset):
        '''Build a DataLoader over ``dataset`` with automatic batching disabled.'''
        # batch_size=None: the dataset itself was given batch_size=minibatch_size.
        return DataLoader(
            dataset,
            batch_size=None,
            num_workers=self.num_workers,
            pin_memory=True
        )

    def shuffle_train_data(self):
        '''Shuffle the chunk order of the training dataset.'''
        self.train_dataset.shuffle_chunks()
        print("Train dataset chunks shuffled.")

    def train_dataloader(self):
        return self._make_loader(self.train_dataset)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset)
from lightning.pytorch.callbacks import Callback
class DatasetEpochCallback(Callback):
    """Forward the trainer's current epoch to the datamodule at each train-epoch start.

    If the datamodule has ``set_epoch``, it is called. Otherwise, if
    ``datamodule.train_dataset`` has ``set_epoch``, that is called instead.
    Nothing happens when the trainer has no datamodule or neither hook exists.
    """

    def on_train_epoch_start(self, trainer, pl_module):
        dm = trainer.datamodule
        if dm is None:
            return
        epoch = trainer.current_epoch
        if hasattr(dm, "set_epoch"):
            dm.set_epoch(epoch)
            return
        # Fallback for datamodules that expose only a dataset-level hook.
        train_dataset = getattr(dm, "train_dataset", None)
        if train_dataset is not None and hasattr(train_dataset, "set_epoch"):
            train_dataset.set_epoch(epoch)
import h5py
import anndata as ad
class DaskPCADataModule:
    '''
    Minimal data holder for PCA with Dask.

    Only the training data is used. It is resolved to an ``.h5ad`` file and
    its ``X`` is exposed lazily through ``train_dataloader()``. If a ``.zarr``
    store is given and no matching ``.h5ad`` exists, the store is first
    converted with ``transform_to_csr``. ``val_path`` and ``test_path`` are
    stored but never read.
    '''

    def __init__(self, train_path, val_path=None, test_path=None, chunk_size=1_000_000):
        self.train_path = train_path
        self.val_path = val_path
        self.test_path = test_path
        self.chunk_size = chunk_size

    @staticmethod
    def _h5ad_path(path):
        '''Return ``path`` with a trailing ``.zarr`` suffix swapped for ``.h5ad``.'''
        # Only the suffix is replaced, so '.zarr' in parent directories is kept.
        stem = path.rstrip(os.sep)
        if stem.endswith(".zarr"):
            stem = stem[: -len(".zarr")]
        return stem + ".h5ad"

    def prepare_data(self):
        '''
        Resolve ``train_path`` to an ``.h5ad`` file and open its ``X`` lazily.

        After this call ``self.train_path`` points at the ``.h5ad`` file and
        ``self.adata`` holds an ``AnnData`` whose ``X`` is read lazily with
        ``chunk_size``-row chunks.

        Raises
        ------
        FileNotFoundError
            If ``train_path`` is neither an existing file nor a directory.
        '''
        if not os.path.isfile(self.train_path) and not os.path.isdir(self.train_path):
            raise FileNotFoundError(f"Train file not found: {self.train_path}, at least one of zarr or h5ad for training data is required")
        if not self.train_path.endswith('.h5ad'):
            h5ad_path = self._h5ad_path(self.train_path)
            if os.path.isfile(h5ad_path):
                # Reuse a previous conversion.
                self.train_path = h5ad_path
            else:
                print(f"Provided train data is not in .h5ad format, transforming {self.train_path} to .h5ad")
                self.train_path = transform_to_csr(self.train_path, self.chunk_size)
        # NOTE(review): X is read lazily and the h5py file is closed when this
        # block exits. Confirm that anndata's lazy reader reopens the file by
        # name when chunks are computed.
        with h5py.File(self.train_path, "r") as f:
            self.adata = ad.AnnData(
                X=ad.experimental.read_elem_lazy(
                    f["X"],
                    chunks=(self.chunk_size, -1)
                )
            )

    def train_dataloader(self):
        '''Return the lazily-backed AnnData built by ``prepare_data()``.'''
        return self.adata
# Datamodule for decode-only training
class DecodeOnlyDataModule(LightningDataModule):
    """In-memory datamodule for decode-only training.

    Each split is a Zarr group. Its ``expression_key`` and ``latent_key``
    arrays, plus ``batch_key`` when configured and present, are read fully
    into memory and wrapped in a ``DecodeOnlyDataset``. Minibatches are drawn
    by a regular ``DataLoader``, which shuffles only the training split.
    ``max_workers_percentage`` is stored but not used by this class.
    """

    def __init__(
            self,
            train_path,
            val_path=None,
            test_path=None,
            expression_key='X',
            latent_key='scANVI',
            batch_key=None,
            create_pseudo_batch: bool = False,
            minibatch_size=128,
            num_workers=10,
            max_workers_percentage=0.8,
            random_seed=42):
        super().__init__()
        self.train_path = train_path
        self.val_path = val_path
        self.test_path = test_path
        self.expression_key = expression_key
        self.latent_key = latent_key
        self.batch_key = batch_key
        self.create_pseudo_batch = create_pseudo_batch
        self.minibatch_size = minibatch_size
        self.num_workers = num_workers
        self.max_workers_percentage = max_workers_percentage
        self.random_seed = random_seed

    def prepare_data(self):
        # Nothing to prepare: stores are opened in setup() on every process.
        pass

    def _create_data_dict(self, zarr_store: zarr.Group) -> Dict[str, Any]:
        """Read the configured arrays of ``zarr_store`` fully into memory.

        The batch array is included only when ``batch_key`` is set and the
        store contains it.
        """
        data_dict = {
            self.expression_key: zarr_store[self.expression_key][:],
            self.latent_key: zarr_store[self.latent_key][:],
        }
        if self.batch_key and self.batch_key in zarr_store:
            data_dict[self.batch_key] = zarr_store[self.batch_key][:]
        return data_dict

    def _create_dataset(self, zarr_store: zarr.Group) -> DecodeOnlyDataset:
        """Create dataset from Zarr store"""
        data_dict = self._create_data_dict(zarr_store)
        return DecodeOnlyDataset(
            data_dict=data_dict,
            expression_key=self.expression_key,
            latent_key=self.latent_key,
            batch_key=self.batch_key,
            pseudo_batch=self.create_pseudo_batch
        )

    @staticmethod
    def _open_store(path):
        """Open ``path`` read-only with zarr; return None if it is not an existing directory."""
        if path is not None and os.path.isdir(path):
            return zarr.open(path, mode='r')
        return None

    def setup(self, stage=None):
        """Open the split stores and build the datasets needed for ``stage``.

        Train and validation datasets are built for ``stage`` "fit" or None;
        the test dataset for "test" or None. Splits whose store is missing
        are skipped.
        """
        self.train_zarr = self._open_store(self.train_path)
        self.val_zarr = self._open_store(self.val_path)
        self.test_zarr = self._open_store(self.test_path)
        if stage == "fit" or stage is None:
            if self.train_zarr is not None:
                self.train_dataset = self._create_dataset(self.train_zarr)
                print(f" - Train samples: {len(self.train_dataset)}")
            if self.val_zarr is not None:
                self.val_dataset = self._create_dataset(self.val_zarr)
                print(f" - Validation samples: {len(self.val_dataset)}")
        if stage == "test" or stage is None:
            if self.test_zarr is not None:
                self.test_dataset = self._create_dataset(self.test_zarr)
                print(f" - Test samples: {len(self.test_dataset)}")
        # train_dataset exists only after a "fit"/None setup with a train
        # store; a test-only setup must not fail here.
        train_dataset = getattr(self, "train_dataset", None)
        if train_dataset is not None and train_dataset.batch is not None:
            batch_type = "real" if train_dataset.has_real_batch else "pseudo"
            print(f" - Batch type: {batch_type}")

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.minibatch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.minibatch_size,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.minibatch_size,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True
        )
class DaskDecodeOnlyDataModule(LightningDataModule):
    """Out-of-core datamodule for decode-only training backed by Dask.

    For each split, the expression (``expression_key``), latent
    (``latent_key``) and optional batch (``batch_key``) arrays are opened
    lazily from a Zarr store and re-chunked to ``chunk_size`` rows per block.
    They are then wrapped in a ``DaskDecodeOnlyIterableDataset`` that receives
    ``batch_size=minibatch_size``, so the ``DataLoader`` uses
    ``batch_size=None``. Only the training split is chunk-shuffled.
    """

    def __init__(
        self,
        train_path,
        val_path=None,
        test_path=None,
        expression_key='X',
        latent_key='scANVI',
        batch_key=None,
        chunk_size=100_000,
        minibatch_size=4096,
        num_workers=10,
        chunks_per_worker=5,
        random_seed=42
    ):
        super().__init__()
        self.train_path = train_path
        self.val_path = val_path
        self.test_path = test_path
        self.expression_key = expression_key
        self.latent_key = latent_key
        self.batch_key = batch_key
        self.chunk_size = chunk_size
        self.minibatch_size = minibatch_size
        self.num_workers = num_workers
        self.chunks_per_worker = chunks_per_worker
        self.random_seed = random_seed
        # Data attributes (lazy Dask arrays, filled by prepare_data()).
        self.train_data = None
        self.val_data = None
        self.test_data = None
        self.train_latent = None
        self.val_latent = None
        self.test_latent = None
        self.train_batch = None
        self.val_batch = None
        self.test_batch = None
        self.n_vars = None

    def _load_zarr_arrays(self, path, description="Data"):
        """Load expression, latent, and optional batch arrays from Zarr.

        Returns ``(X_data, latent_data, batch_data)`` as lazy Dask arrays.
        ``batch_data`` is None when ``batch_key`` is unset or absent. The
        result is ``(None, None, None)``, with a warning, when the directory
        is missing or any array fails to load.
        """
        if not os.path.isdir(path):
            warnings.warn(f"{description} directory not found: {path}")
            return None, None, None
        try:
            store = zarr.open(path, mode='r')
            if self.expression_key in store:
                X_data = da.from_zarr(path, component=self.expression_key)
            else:
                raise KeyError(f"Expression key '{self.expression_key}' not found")
            if self.latent_key in store:
                latent_data = da.from_zarr(path, component=self.latent_key)
            else:
                raise KeyError(f"Latent key '{self.latent_key}' not found")
            batch_data = None
            if self.batch_key and self.batch_key in store:
                batch_data = da.from_zarr(path, component=self.batch_key)
            print(f"{description} loaded from {path}")
            print(f" X shape: {X_data.shape}, latent shape: {latent_data.shape}")
            if batch_data is not None:
                print(f" batch shape: {batch_data.shape}")
            return X_data, latent_data, batch_data
        except Exception as e:
            # Best-effort: a split that fails to load is skipped with a warning.
            warnings.warn(f"Failed loading {description} from {path}: {e}")
            return None, None, None

    def prepare_data(self):
        """Load all data arrays"""
        print("Loading data arrays...")
        print("Train data path:", self.train_path)
        print("Validation data path:", self.val_path)
        print("Test data path:", self.test_path)
        self.train_data, self.train_latent, self.train_batch = self._load_zarr_arrays(
            self.train_path, "Train data"
        )
        if self.val_path:
            self.val_data, self.val_latent, self.val_batch = self._load_zarr_arrays(
                self.val_path, "Validation data"
            )
        if self.test_path:
            self.test_data, self.test_latent, self.test_batch = self._load_zarr_arrays(
                self.test_path, "Test data"
            )

    def _rechunk_rows(self, arr):
        """Rechunk ``arr`` to ``chunk_size`` rows per block, keeping other axes whole.

        ``None`` is returned unchanged. This works for 1-D arrays (e.g. a flat
        batch vector) as well as 2-D ones; indexing ``shape[1]`` would fail
        on 1-D input.
        """
        if arr is None:
            return None
        # -1 = one chunk spanning the full extent of that axis.
        return arr.rechunk((self.chunk_size,) + (-1,) * (arr.ndim - 1))

    def _make_dataset(self, X_data, latent_data, batch_data, chunk_shuffle):
        """Wrap one split's arrays in a DaskDecodeOnlyIterableDataset with this module's settings."""
        return DaskDecodeOnlyIterableDataset(
            X_data=X_data,
            latent_data=latent_data,
            batch_data=batch_data,
            chunk_shuffle=chunk_shuffle,
            batch_size=self.minibatch_size,
            seed=self.random_seed,
            chunks_per_worker=self.chunks_per_worker
        )

    def setup(self, stage=None):
        """Setup datasets with proper chunking"""
        # In DDP, prepare_data() only runs on local rank 0.
        # Re-load here so every rank has the data references.
        if self.train_data is None:
            self.prepare_data()
        if self.train_data is not None:
            self.n_vars = self.train_data.shape[1]
            self.train_data = self._rechunk_rows(self.train_data)
            self.train_latent = self._rechunk_rows(self.train_latent)
            self.train_batch = self._rechunk_rows(self.train_batch)
            self.train_dataset = self._make_dataset(
                self.train_data, self.train_latent, self.train_batch, chunk_shuffle=True
            )
        if self.val_data is not None:
            self.val_data = self._rechunk_rows(self.val_data)
            self.val_latent = self._rechunk_rows(self.val_latent)
            self.val_batch = self._rechunk_rows(self.val_batch)
            self.val_dataset = self._make_dataset(
                self.val_data, self.val_latent, self.val_batch, chunk_shuffle=False
            )
        if self.test_data is not None:
            self.test_data = self._rechunk_rows(self.test_data)
            self.test_latent = self._rechunk_rows(self.test_latent)
            self.test_batch = self._rechunk_rows(self.test_batch)
            self.test_dataset = self._make_dataset(
                self.test_data, self.test_latent, self.test_batch, chunk_shuffle=False
            )
        # Print dataset info
        print("\nDataset Info:")
        if self.train_data is not None:
            print(f"Train: X{self.train_data.shape} -> latent{self.train_latent.shape}, chunks={self.train_data.numblocks}")
        if self.val_data is not None:
            print(f"Val: X{self.val_data.shape} -> latent{self.val_latent.shape}, chunks={self.val_data.numblocks}")
        if self.test_data is not None:
            print(f"Test: X{self.test_data.shape} -> latent{self.test_latent.shape}, chunks={self.test_data.numblocks}")

    def set_epoch(self, epoch: int):
        """Set epoch for reproducible shuffling"""
        if hasattr(self, 'train_dataset'):
            self.train_dataset.set_epoch(epoch)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=None,
            num_workers=self.num_workers,
            pin_memory=True
        )

    def val_dataloader(self):
        if hasattr(self, 'val_dataset'):
            return DataLoader(
                self.val_dataset,
                batch_size=None,
                num_workers=self.num_workers,
                pin_memory=True
            )
        return None

    def test_dataloader(self):
        if hasattr(self, 'test_dataset'):
            return DataLoader(
                self.test_dataset,
                batch_size=None,
                num_workers=self.num_workers,
                pin_memory=True
            )
        return None
class SubsetDaskDecodeOnlyDataModule(DaskDecodeOnlyDataModule):
    """``DaskDecodeOnlyDataModule`` that keeps only selected expression features.

    Each split's expression matrix is column-indexed with
    ``target_feature_indices`` right after loading. Latent and batch arrays
    are left untouched. All other keyword arguments are forwarded to the
    parent class.
    """

    def __init__(
        self,
        target_feature_indices: np.ndarray,
        **kwargs
    ):
        super().__init__(
            **kwargs
        )
        self.target_feature_indices = target_feature_indices
        # Column count of the first expression matrix loaded before
        # subsetting; used only for logging in prepare_data().
        self.original_n_vars = None

    def _load_zarr_arrays(self, path, description="Data"):
        """Load expression, latent, and optional batch arrays from Zarr, then subset X's columns."""
        # Call parent method to load data
        X_data, latent_data, batch_data = super()._load_zarr_arrays(path, description)
        # If data was successfully loaded, subset the features
        if X_data is not None:
            print(f" Original X shape: {X_data.shape}")
            if self.original_n_vars is None:
                self.original_n_vars = X_data.shape[1]
            X_data = X_data[:, self.target_feature_indices]
            print(f" Subset X shape: {X_data.shape}")
        return X_data, latent_data, batch_data

    def prepare_data(self):
        """Load all data arrays with feature subsetting"""
        # Reset so the pre-subset count is captured from the first split the
        # parent loads (presumably the train split; confirm in the parent).
        self.original_n_vars = None
        super().prepare_data()
        if self.train_data is not None:
            self.n_vars = len(self.target_feature_indices)
            print("\nFeature subsetting applied:")
            # train_data is already subset here, so its shape[1] would equal
            # n_vars; report the count recorded before subsetting instead.
            print(f" Using {self.n_vars} features from original {self.original_n_vars}")