Source code for sc_reconstruction.models.reconknn
import h5py
import cupy as cp
import numpy as np
import cuml as cm
import anndata as ad
from tqdm import tqdm
[docs]
class ReconKNN:
    """Reconstruct query rows as the mean of their k nearest reference rows.

    The reference set is a seeded random subsample (10% by default) of the
    cells in an ``.h5ad`` file, densified into a host NumPy array. Neighbor
    search runs on the GPU with cuML, one contiguous reference chunk at a
    time. The per-chunk candidates are then merged into a global
    k-nearest-neighbor list.
    """

    def __init__(self, n_neighbors: int = 5, metric: str = 'euclidean',
                 data_path: str = None, batch_size: int = 1_000_000,
                 subsample_fraction: float = 0.1, seed: int = 42):
        """Load the ``.h5ad`` file at ``data_path`` and keep a random subsample.

        Args:
            n_neighbors: Number of nearest reference rows averaged per query.
            metric: Distance metric passed to cuML ``NearestNeighbors``.
            data_path: Path to the ``.h5ad`` file. Required despite the
                ``None`` default.
            batch_size: Stored on the instance; not used by this class
                (kept for interface compatibility).
            subsample_fraction: Fraction of rows drawn without replacement to
                form the reference set. Must lie in ``(0, 1]``.
            seed: Seed for the subsampling RNG, so the reference is reproducible.

        Raises:
            ValueError: If ``data_path`` is None or ``subsample_fraction`` is
                outside ``(0, 1]``.
        """
        if data_path is None:
            raise ValueError("data_path is required")
        if not 0.0 < subsample_fraction <= 1.0:
            raise ValueError(
                f"subsample_fraction must be in (0, 1], got {subsample_fraction}"
            )
        self.n_neighbors = n_neighbors
        self.metric = metric
        self.batch_size = batch_size

        adata = ad.read_h5ad(data_path)
        n_obs = len(adata)
        # Keep at least one row so tiny files do not produce an empty reference.
        n_sample = min(n_obs, max(1, int(n_obs * subsample_fraction)))
        rng = np.random.default_rng(seed=seed)
        # Sampling from the population size draws the same indices as sampling
        # from range(n_obs), without materializing the range.
        indices = rng.choice(n_obs, size=n_sample, replace=False)
        matrix = adata[indices].X
        # Despite its name, ``self.adata`` holds a dense NumPy array of the
        # sampled rows; the on-disk matrix may be sparse or already dense.
        self.adata = matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)

    def _kneighbors(self, x, y):
        """Return ``(distances, indices)`` of the nearest rows of ``x`` for each row of ``y``.

        Indices are local to ``x``. The neighbor count is clamped to ``len(x)``,
        because a chunk with fewer rows than ``n_neighbors`` cannot supply a
        full neighbor list.

        Raises:
            ValueError: If ``x`` has no rows.
        """
        if len(x) == 0:
            raise ValueError("cannot search an empty reference chunk")
        reference_gpu = cp.asarray(x)
        nn = cm.neighbors.NearestNeighbors(
            n_neighbors=min(self.n_neighbors, len(x)),
            output_type="numpy",
            metric=self.metric,
        ).fit(reference_gpu)
        distances, indices = nn.kneighbors(cp.asarray(y))
        # Drop GPU references promptly so their blocks can be reused for the next chunk.
        del reference_gpu, nn
        return distances, indices

    def process_subbatch(self, x, y):
        """Find the nearest rows of ``x`` for every row of ``y``.

        Args:
            x: Reference rows to search (host array).
            y: Query rows (host or CuPy array).

        Returns:
            ``(distances, neighbors)``, where ``distances`` has one row per
            query and ``neighbors = x[indices]`` holds the matching reference
            rows. Up to ``min(self.n_neighbors, len(x))`` neighbors are
            returned per query.
        """
        distances, indices = self._kneighbors(x, y)
        return distances, x[indices]

    def predict(self, X: np.ndarray, n_chunks: int = 100, **inference_kwargs) -> np.ndarray:
        """Reconstruct each query row as the mean of its nearest reference rows.

        The reference matrix is searched in contiguous chunks to bound GPU
        memory. Each chunk's candidates are merged so that the final
        neighbors are the global ``n_neighbors`` nearest reference rows.

        Args:
            X: Query matrix, one row per sample, with the reference's columns.
            n_chunks: Maximum number of reference chunks. It is reduced
                automatically so that no chunk is empty and, when the
                reference has at least ``n_neighbors`` rows, every chunk holds
                at least ``n_neighbors`` rows.
            **inference_kwargs: Accepted for interface compatibility; ignored.

        Returns:
            For each query row, the mean of its nearest reference rows.

        Raises:
            ValueError: If the reference set is empty or ``n_chunks < 1``.
        """
        reference = self.adata
        n_ref = reference.shape[0]
        if n_ref == 0:
            raise ValueError("reference set is empty")
        if n_chunks < 1:
            raise ValueError(f"n_chunks must be >= 1, got {n_chunks}")
        # With at most n_ref // n_neighbors chunks, np.array_split yields chunks
        # of at least n_neighbors rows each (or a single chunk if n_ref is smaller).
        n_chunks = max(1, min(n_chunks, n_ref // max(1, self.n_neighbors)))

        queries_gpu = cp.asarray(X)  # upload the queries once, not once per chunk
        chunk_distances, chunk_indices = [], []
        offset = 0
        for chunk in np.array_split(reference, n_chunks):
            distances, indices = self._kneighbors(chunk, queries_gpu)
            chunk_distances.append(distances)
            # Shift chunk-local indices to row positions in the full reference.
            chunk_indices.append(indices + offset)
            offset += chunk.shape[0]
        del queries_gpu

        distances = np.concatenate(chunk_distances, axis=1)
        indices = np.concatenate(chunk_indices, axis=1)
        k = min(self.n_neighbors, distances.shape[1])
        order = np.argsort(distances, axis=1)[:, :k]
        nearest = np.take_along_axis(indices, order, axis=1)
        # Gather only the k winning rows per query rather than every chunk's candidates.
        return reference[nearest].mean(axis=1)