# Source code for sc_reconstruction.utils.model_loader

"""
Utility for loading models based on checkpoint filenames
"""

import os
import re
from typing import Dict, Any, Optional, Tuple
from hydra.utils import instantiate
from omegaconf import DictConfig
import ast

def _parse_hidden_dims(token: str) -> list:
    """Parse the hidden-layer field of a checkpoint filename into a list of ints.

    Accepts either a Python list literal (e.g. ``'[1024, 512]'``) or a single
    integer (e.g. ``'1024'``, returned as ``[1024]``).

    Raises:
        ValueError or SyntaxError: if the token is neither form.
    """
    if token.startswith('[') and token.endswith(']'):
        # literal_eval only evaluates Python literals, so it is safe on
        # filename-derived text (unlike eval).
        return ast.literal_eval(token)
    return [int(token)]


def parse_model_params_from_filename(filename: str, mode: str = 'scVI') -> Dict[str, Any]:
    """Parse model architecture parameters from a checkpoint filename.

    A trailing ``.pt`` extension is stripped, then the name is split on
    ``'_'``. The expected layout depends on ``mode``:

    * ``mode`` contains ``'VQVAE'``:
      ``epochs_hidden_latent_numembeddings_vqweight[_...]_date``,
      e.g. ``400_[1024, 1024, 1024, 1024]_512_1024_1_20251001``.
    * ``mode`` contains ``'AE'`` (and not ``'VQVAE'``):
      ``epochs_hidden_latent_date``, e.g. ``400_[1024, 512]_64_20250101``.
    * otherwise (scVI-style): at least six fields,
      ``epochs_nhidden_nlatent_nlayers_<fifth>_date``, e.g.
      ``20_1024_300_3_1_20250331``. The fifth field (presumably the KL
      weight) is not parsed.

    For the VQVAE/AE layouts, ``hidden`` is either a single int or a Python
    list literal such as ``[1024, 512]``.

    Args:
        filename: Checkpoint filename; directory components are not stripped,
            so pass the bare filename.
        mode: Model name selecting the layout above; defaults to ``'scVI'``.

    Returns:
        Dict with ``epochs``, ``n_hidden``, ``n_latent`` and ``date``, plus
        ``num_embeddings`` and ``vq_weight`` (VQVAE) or ``n_layers``
        (scVI-style).

    Raises:
        ValueError: if the filename does not match the expected layout,
            including when it has too few fields.
    """
    # Remove file extension if present
    if filename.endswith('.pt'):
        filename = filename[:-3]

    parts = filename.split('_')

    # 'VQVAE' must be tested before 'AE', because 'AE' is a substring of it.
    if 'VQVAE' in mode:
        try:
            return {
                'epochs': int(parts[0]),
                'n_hidden': _parse_hidden_dims(parts[1]),
                'n_latent': int(parts[2]),
                'num_embeddings': int(parts[3]),
                'vq_weight': int(parts[4]),
                'date': parts[-1],
            }
        except (IndexError, ValueError, SyntaxError) as e:
            # IndexError: too few fields. Normalised to ValueError so that
            # callers only need to handle a single exception type.
            raise ValueError(
                f"Failed to parse VQVAE checkpoint filename: {filename}. Error: {str(e)}"
            ) from e

    if 'AE' in mode:
        # Not used since AE now uses the Lightning loader.
        try:
            return {
                'epochs': int(parts[0]),
                'n_hidden': _parse_hidden_dims(parts[1]),
                'n_latent': int(parts[2]),
                'date': parts[3],
            }
        except (IndexError, ValueError, SyntaxError) as e:
            raise ValueError(
                f"Failed to parse AE checkpoint filename: {filename}. Error: {str(e)}"
            ) from e

    # scVI-style: need at least 6 fields
    # (epochs, n_hidden, n_latent, n_layers, KL weight, date).
    if len(parts) < 6:
        raise ValueError(f"Invalid checkpoint filename format: {filename}")
    try:
        params = {
            'epochs': int(parts[0]),
            'n_hidden': int(parts[1]),
            'n_latent': int(parts[2]),
            'n_layers': int(parts[3]),
            'date': parts[-1],
        }
    except (IndexError, ValueError) as e:
        raise ValueError(f"Failed to parse checkpoint filename: {filename}. Error: {str(e)}") from e
    print('loading model with params', params)
    return params
def get_model_class_from_name(model_name: str) -> str:
    """Map a short model name to the dotted path of its implementing class.

    Several names share one implementation: every AE variant resolves to
    ``ReconAE`` and every VQVAE variant to ``ReconVQVAE``.

    Args:
        model_name: Short model identifier, e.g. 'scVI', 'PCA', 'olVQVAE'.

    Returns:
        Fully-qualified class path, usable as a Hydra ``_target_``.

    Raises:
        ValueError: if ``model_name`` is not a recognised identifier.
    """
    ae_cls = 'sc_reconstruction.models.reconae.ReconAE'
    vqvae_cls = 'sc_reconstruction.models.reconvqvae.ReconVQVAE'

    # Insertion order matters: it is echoed in the "Supported models" error.
    registry = {
        'scVI': 'sc_reconstruction.models.reconscvi.ReconSCVI',
        'nlscVI': 'sc_reconstruction.models.reconnlscvi.ReconNLSCVI',
        'mlscVI': 'sc_reconstruction.models.reconmlscvi.ReconMLSCVI',
        'PCA': 'sc_reconstruction.models.reconpca.PCA',
        'DRVI': 'sc_reconstruction.models.recondrvi.ReconDRVI',
        'AE': ae_cls,
        'olAE': ae_cls,
        'mlAE': ae_cls,
        'MLEAE': ae_cls,
        'VQVAE': vqvae_cls,
        'olVQVAE': vqvae_cls,
        'mlVQVAE': vqvae_cls,
    }

    target = registry.get(model_name)
    if target is None:
        known = list(registry)
        raise ValueError(f"Unknown model name: {model_name}. Supported models: {known}")
    return target
def inst_model_from_checkpoint(model_name: str, checkpoint_file: str, checkpoint_path: str,
                               additional_params: Optional[Dict[str, Any]] = None) -> Any:
    """Instantiate a model whose architecture is inferred from its checkpoint filename.

    Weights are NOT loaded here; this only builds the model object via Hydra's
    ``instantiate``.

    Args:
        model_name: Short model name (e.g. 'scVI', 'PCA'). It selects both the
            filename layout used for parsing and the model class.
        checkpoint_file: Checkpoint filename whose underscore-separated fields
            encode the architecture (see ``parse_model_params_from_filename``).
        checkpoint_path: Path to the checkpoint file; here it is only checked
            for existence.
        additional_params: Extra constructor kwargs. Applied last, so they
            override values parsed from the filename.

    Returns:
        The instantiated model.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
        ValueError: if ``model_name`` is not a known model.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    # Filename parsing is best-effort: on failure, fall back to the model's
    # constructor defaults (plus additional_params), but say so instead of
    # failing silently. IndexError is caught too, so that a truncated
    # filename never aborts instantiation.
    try:
        params = parse_model_params_from_filename(checkpoint_file, model_name)
    except (ValueError, IndexError) as e:
        print(f"Could not parse architecture from checkpoint filename {checkpoint_file!r}; "
              f"falling back to constructor defaults. Reason: {e}")
        params = {}

    model_config = {'_target_': get_model_class_from_name(model_name)}

    # Forward only architecture fields. Other parsed values ('epochs', 'date')
    # are not passed to the constructor.
    for key in ('n_hidden', 'n_latent', 'n_layers', 'num_embeddings', 'vq_weight'):
        if key in params:
            model_config[key] = params[key]

    if additional_params:
        print('loading model with additional params', additional_params)
        model_config.update(additional_params)

    return instantiate(model_config)
def create_model_from_cfg(cfg: DictConfig, running_device: str) -> Tuple[Any, str]:
    """Build a model from a config and load its checkpoint weights.

    If ``cfg.model.model_args`` exists and is not null, the architecture is
    instantiated directly from it. Otherwise it is inferred from the checkpoint
    filename ``cfg.model.load.model_name``.

    Args:
        cfg: Config providing ``model.meta.name``, ``model.load.model_name``
            and ``model.load.path``, plus optional ``model.load.additional_params``
            and ``model.model_args``.
        running_device: Passed as ``map_location`` when loading the checkpoint.

    Returns:
        Tuple of (model with checkpoint loaded, checkpoint path).
    """
    model_name = cfg.model.meta.name
    load_cfg = cfg.model.load
    checkpoint_file = load_cfg.model_name
    checkpoint_path = load_cfg.path
    # Optional key: an absent attribute means "no extra constructor kwargs".
    additional_params = getattr(load_cfg, 'additional_params', {})

    arch_cfg = getattr(cfg.model, 'model_args', None)
    if arch_cfg is None:
        model = inst_model_from_checkpoint(model_name, checkpoint_file, checkpoint_path, additional_params)
        print("Instantiated model architecture from checkpoint")
    else:
        model = instantiate(arch_cfg)
        print("Instantiated model architecture from config")

    model.load(checkpoint_path, map_location=running_device)
    print(f"Model loaded from {checkpoint_path} with name {model_name}")
    return model, checkpoint_path