Source code for sc_reconstruction.utils.data_tools
import dask.array as da
import zarr
from scipy.sparse import csr_matrix, vstack
import anndata as ad
import os
import warnings
from tqdm import tqdm
[docs]
def transform_to_csr(path, chunk_size=1_000_000):
    """Convert a zarr store of dense arrays into one CSR-backed ``.h5ad`` file.

    The store at *path* is probed in this order, stopping at the first layout
    that loads:

    1. *path* itself opens as a dask array via ``da.from_zarr(path)``;
    2. *path* holds an array at component ``'X'``;
    3. *path* opens as a zarr group.  For each subgroup ``g``, component
       ``f"{g}/X"`` is tried first and then ``g``.  Subgroups where both
       fail are skipped with a warning.

    Each loaded array is computed ``chunk_size`` rows at a time, and every
    block is converted to ``scipy.sparse.csr_matrix``.  All pieces are then
    stacked vertically.  The result is wrapped in ``anndata.AnnData`` and
    written next to the store, with the trailing ``.zarr`` suffix replaced by
    ``.h5ad``.

    Args:
        path: Directory containing the zarr store.
        chunk_size: Number of rows materialized per dask block.  Arrays must
            be 2-D, because they are rechunked with a two-element chunk tuple.

    Returns:
        The path of the written ``.h5ad`` file.  Returns ``None`` (after
        emitting a warning) if *path* is not a directory, the store cannot be
        opened, or no array could be loaded from it.
    """
    if not os.path.isdir(path):
        warnings.warn(f"Path doesn't exist: {path}")
        return None

    datasets = _load_zarr_arrays(path)
    if datasets is None:
        return None
    if not datasets:
        # Previously this fell through to ``csr_parts[0]`` and raised IndexError.
        warnings.warn(f"No loadable arrays found in {path}; nothing to convert.")
        return None

    print(f"Data loaded. Found {len(datasets)} subgroups. Transforming to CSR format...")
    csr_parts = [_dask_to_csr(data, chunk_size) for data in tqdm(datasets)]
    global_csr = vstack(csr_parts) if len(csr_parts) > 1 else csr_parts[0]

    out_path = _h5ad_output_path(path)
    ad.AnnData(X=global_csr).write_h5ad(out_path)
    return out_path


def _load_zarr_arrays(path):
    """Return the list of dask arrays found at *path*, or ``None`` if the store can't be opened."""
    # The broad excepts are deliberate.  Each attempt probes a different
    # store layout, and dask/zarr raise many different exception types when
    # a layout doesn't match.
    try:
        return [da.from_zarr(path)]
    except Exception:
        pass
    try:
        return [da.from_zarr(path, 'X')]
    except Exception:
        pass
    try:
        root = zarr.open_group(path, mode='r')
    except Exception as e2:
        warnings.warn(f"Failing loading: {path} Error: {e2}")
        return None

    datasets = []
    try:
        for g in root.group_keys():
            try:
                arr = da.from_zarr(path, f"{g}/X")
            except Exception:
                try:
                    arr = da.from_zarr(path, g)
                except Exception:
                    warnings.warn(f"Skipping subgroup '{g}' in {path}: no array or 'X'.")
                    continue
            datasets.append(arr)
    except Exception as err:
        # Formerly swallowed silently.  Keep whatever was gathered, but surface it.
        warnings.warn(f"Stopped scanning subgroups of {path}: {err}")
    return datasets


def _dask_to_csr(data, chunk_size):
    """Compute 2-D dask array *data* ``chunk_size`` rows at a time and stack the pieces as CSR."""
    data = data.rechunk((chunk_size, -1))
    blocks = [
        csr_matrix(data.blocks[i, :].compute())
        for i in range(data.numblocks[0])
    ]
    return vstack(blocks)


def _h5ad_output_path(path):
    """Return the ``.h5ad`` path that sits next to the zarr store at *path*.

    Only a trailing ``.zarr`` suffix is replaced.  ``str.replace`` would also
    rewrite ``.zarr`` inside parent directory names.  Trailing separators are
    ignored.  A store without the suffix gets ``.h5ad`` appended, so the
    output never collides with the input directory.
    """
    path = os.fspath(path)
    separators = os.sep + (os.altsep or "")
    stripped = path.rstrip(separators) or path
    root, ext = os.path.splitext(stripped)
    return (root if ext == ".zarr" else stripped) + ".h5ad"
[docs]
def get_zarr_attr(meta_root, target_key):
    """Read an attribute from the root attributes of a zarr store.

    Args:
        meta_root: Location passed to ``zarr.open`` in read-only mode.
        target_key: Name of the attribute to return.  ``None`` returns the
            concatenation ``train_combinations + val_combinations +
            test_combinations``.

    Returns:
        The stored attribute value, or the concatenation described above.

    Raises:
        KeyError: If *target_key* is not among the store's attributes.  Also
            raised if one of the three split attributes is missing when
            *target_key* is ``None``.
    """
    attrs = zarr.open(meta_root, mode='r').attrs

    if target_key is not None:
        if target_key not in attrs:
            raise KeyError(f"Key '{target_key}' not found in Zarr attributes.")
        return attrs[target_key]

    print('return all attributes: train, val, test combinations')
    train, val, test = (
        attrs[name]
        for name in ('train_combinations', 'val_combinations', 'test_combinations')
    )
    return train + val + test
[docs]
def split_zarr_dataloader(zarr_root, target_combinations, target_key='test_combinations', X_key='X', idx_key=None):
    '''
    Lazily yield arrays from a zarr store, one group at a time (mainly for inference).

    The store is opened read-only on the first ``next()`` call, because this
    is a generator.

    Args:
        zarr_root: Location passed to ``zarr.open`` in read-only mode.
        target_combinations: Iterable of group names to read.  If ``None``,
            the names are read from the store attribute named *target_key*.
        target_key: Store attribute consulted when *target_combinations* is
            ``None``.
        X_key: Name of the array read from each group.
        idx_key: Optional name of a second array read from each group.

    Yields:
        ``(name, X)``, or ``(name, X, idx)`` when *idx_key* is given.  Each
        array is read in full via ``[:]``.
    '''
    zarr_store = zarr.open(zarr_root, mode='r')
    if target_combinations is None:
        # Previously target_key was accepted but never used.
        target_combinations = zarr_store.attrs[target_key]
    for target_combination in target_combinations:
        group = zarr_store[target_combination]  # look the group up once per item
        if idx_key is not None:
            yield target_combination, group[X_key][:], group[idx_key][:]
        else:
            yield target_combination, group[X_key][:]
[docs]
def filter_genes(X, gene_names, selected_genes):
    """Return the columns of *X* named by *selected_genes*, in that order.

    Args:
        X: Matrix supporting ``X[:, list_of_ints]`` (for example a NumPy
            array or SciPy sparse matrix) whose columns follow *gene_names*.
        gene_names: Iterable of hashable column names, such as a list,
            NumPy array, or pandas Index.
        selected_genes: Names to keep.  Output column ``j`` corresponds to
            ``selected_genes[j]``.

    Returns:
        ``X[:, gene_idx]``, where ``gene_idx`` holds the column position of
        each selected gene.  If a name appears more than once in
        *gene_names*, its first position is used.

    Raises:
        ValueError: If a selected gene does not appear in *gene_names*.
    """
    # Build the name -> position map once (O(n)) rather than calling
    # list.index for every gene (O(n*m)).  setdefault keeps the FIRST
    # position of a duplicated name, matching list.index semantics.
    position = {}
    for i, name in enumerate(gene_names):
        position.setdefault(name, i)
    try:
        gene_idx = [position[gene] for gene in selected_genes]
    except KeyError as err:
        # Keep raising ValueError, as list.index did, for existing callers.
        raise ValueError(f"{err.args[0]!r} is not in gene_names") from None
    return X[:, gene_idx]