sc_reconstruction.decoders.ReconTransformerDecoder

class sc_reconstruction.decoders.ReconTransformerDecoder(n_input, n_output, n_cat_list=None, dim_model=128, num_head=8, dim_hid=256, nlayers=6, dropout=0.1, distribution='normal', learning_rate=0.0001, reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=5, lr_threshold=0.001, lr_min=0.0, library_size_mode='none', weight_decay=0.0, use_adamw=False, **kwargs)

Bases: BaseReconstructionDecoder

Recon-style Transformer decoder wrapper that mirrors the interface of ReconMLPDecoder:

  • self.module is a LightningModule.

  • train(datamodule, **trainer_kwargs): fits the module on a datamodule.

  • decode(z_numpy, *cat_numpy) -> numpy: maps latent codes back to expression values.

  • save / load: round-trips the module through a checkpoint.

__init__(n_input, n_output, n_cat_list=None, dim_model=128, num_head=8, dim_hid=256, nlayers=6, dropout=0.1, distribution='normal', learning_rate=0.0001, reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=5, lr_threshold=0.001, lr_min=0.0, library_size_mode='none', weight_decay=0.0, use_adamw=False, **kwargs)
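The lr_* arguments read like the settings of a reduce-on-plateau learning-rate schedule, presumably active only when reduce_lr_on_plateau=True. The sketch below assumes they map onto PyTorch's ReduceLROnPlateau (factor, patience, threshold, min_lr) in 'min' mode with a relative threshold. simulate_reduce_on_plateau is a hypothetical helper that replays that rule on a list of validation losses, with cooldown and PyTorch's small-epsilon check omitted.

```python
import math
from typing import List, Sequence


def simulate_reduce_on_plateau(
    val_losses: Sequence[float],
    learning_rate: float = 1e-4,
    lr_factor: float = 0.6,
    lr_patience: int = 5,
    lr_threshold: float = 1e-3,
    lr_min: float = 0.0,
) -> List[float]:
    """Return the learning rate in effect after each epoch.

    Hypothetical helper; assumes PyTorch ReduceLROnPlateau semantics in
    'min' mode with a relative threshold, minus cooldown and the eps check.
    """
    best = math.inf
    num_bad_epochs = 0
    lr = learning_rate
    history = []
    for loss in val_losses:
        # Count as an improvement only if loss beats best by more than lr_threshold (relative).
        if loss < best * (1.0 - lr_threshold):
            best = loss
            num_bad_epochs = 0
        else:
            num_bad_epochs += 1
        # Once the plateau exceeds lr_patience epochs, shrink the LR, floored at lr_min.
        if num_bad_epochs > lr_patience:
            lr = max(lr * lr_factor, lr_min)
            num_bad_epochs = 0
        history.append(lr)
    return history


# A flat loss curve: epoch 1 sets the best value, epochs 2-7 are non-improving,
# so the 6th bad epoch (> lr_patience=5) triggers the first reduction.
schedule = simulate_reduce_on_plateau([1.0] * 7)
print(schedule)
```

With the defaults above, each reduction multiplies the learning rate by 0.6, and lr_min=0.0 means there is effectively no floor.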

Methods

__init__(n_input, n_output[, n_cat_list, ...])

decode(z, *cat_list[, decode_batch_size])

Decode the latent representation back to expression space.
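One plausible reading of decode(z, *cat_list[, decode_batch_size]) is: split the latent rows into chunks of decode_batch_size, append the categorical covariates to each row, and run the network one chunk at a time. The sketch below assumes scvi-tools-style covariate handling, where each covariate with n_cat levels (taken from n_cat_list) is one-hot encoded and concatenated to the latent vector. batched_decode, one_hot, toy_decoder and the default batch size of 128 are hypothetical, and decoder_fn stands in for the trained network.

```python
from typing import Callable, List, Sequence


def one_hot(index: int, n_cat: int) -> List[float]:
    """Encode one categorical value as a length-n_cat one-hot vector."""
    if not 0 <= index < n_cat:
        raise ValueError(f"category {index} out of range for n_cat={n_cat}")
    vec = [0.0] * n_cat
    vec[index] = 1.0
    return vec


def batched_decode(
    decoder_fn: Callable[[List[List[float]]], List[List[float]]],
    z: Sequence[Sequence[float]],
    *cat_list: Sequence[int],
    n_cat_list: Sequence[int],
    decode_batch_size: int = 128,  # hypothetical default
) -> List[List[float]]:
    """Decode latent rows chunk by chunk, appending one-hot covariates to each row."""
    if len(cat_list) != len(n_cat_list):
        raise ValueError("expected one categorical array per entry of n_cat_list")
    outputs: List[List[float]] = []
    for start in range(0, len(z), decode_batch_size):
        stop = min(start + decode_batch_size, len(z))
        batch = []
        for i in range(start, stop):
            row = list(z[i])
            for cats, n_cat in zip(cat_list, n_cat_list):
                row.extend(one_hot(cats[i], n_cat))
            batch.append(row)
        # The decoder sees at most decode_batch_size rows per call.
        outputs.extend(decoder_fn(batch))
    return outputs


# Toy "decoder" that sums each input row and records the chunk sizes it receives.
chunk_sizes = []


def toy_decoder(batch):
    chunk_sizes.append(len(batch))
    return [[sum(row)] for row in batch]


z = [[0.5, 1.0], [2.0, 0.0], [1.0, 1.0]]
batch_labels = [0, 2, 1]  # one categorical covariate with 3 levels
out = batched_decode(toy_decoder, z, batch_labels, n_cat_list=[3], decode_batch_size=2)
print(out)  # [[2.5], [3.0], [3.0]]
print(chunk_sizes)  # [2, 1]
```

Chunking bounds peak memory by decode_batch_size rather than by the number of cells, which is the usual motivation for a batch-size argument on an inference method.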

load(path[, map_location])

Load weights into the existing module.

save(path)

Save the trained decoder.

train([datamodule])

Train the decoder on a datamodule.