# Source code for DOLPHIN.graph_generation.process_graph_final

from .func_step03_GNN_main import get_graph_input
import torch
from tqdm import tqdm
import math
import anndata
import os
import pandas as pd

def _load_trusted_chunk(path: str):
    """Load one intermediate chunk file, allowing arbitrary pickled objects.

    The chunks are combined into the final model input, which is documented as a
    list of PyTorch Geometric ``Data`` objects. Such objects cannot be read with
    ``weights_only=True``, which is the ``torch.load`` default since PyTorch 2.6.
    Full unpickling can execute arbitrary code, so only call this on files
    produced by this pipeline — never on files from an untrusted location.
    """
    import inspect

    # ``weights_only`` only exists on PyTorch >= 1.13; older versions always
    # perform full unpickling, so the plain call is equivalent there.
    if "weights_only" in inspect.signature(torch.load).parameters:
        return torch.load(path, weights_only=False)
    return torch.load(path)


def run_model_input(
    metadata_path: str,
    out_name: str,
    out_directory: str = "./",
    gnn_run_num: int = 100,
    celltypename: str = None,
) -> None:
    """
    Combines feature matrix and adjacency matrix and generates input for the DOLPHIN model.

    Cells listed in the metadata are processed in chunks of ``gnn_run_num`` by
    ``get_graph_input``. Each chunk is written to a temporary ``.pt`` file, and the
    chunks are then concatenated into a single output file.

    Parameters
    ----------
    metadata_path : str
        Path to the **tab-separated** metadata file with one row per cell.
    out_name : str
        Stem used for the output files. The final file is
        ``model_<out_name>.pt`` and the temporary chunks are
        ``model_<out_name>_<k>.pt``.
    out_directory : str
        Base output directory, default ``"./"``. Results are written to
        ``<out_directory>/data/`` and intermediate chunks to
        ``<out_directory>/data/temp/``; both are created if missing. The
        temporary chunk files are not deleted afterwards.
    gnn_run_num : int
        Number of cells per chunk passed to ``get_graph_input``. Must be >= 1.
    celltypename : str, optional
        Column name in metadata indicating cell types. When given and present in
        the metadata, the sorted unique non-null values of that column are mapped
        to integer labels ``0..k-1`` and passed on as ``mapper``. Default is None.

    Returns
    -------
    None
        Saves the final torch file as ``model_<out_name>.pt`` in
        ``<out_directory>/data/``. It contains a list of PyTorch Geometric ``Data``
        objects, one per cell; it is an empty list if the metadata has no rows.
        Each object includes:
        - x : Feature matrix of the cell (normalized exon counts, shaped `[num_features, 1]`)
        - edge_index : Graph connectivity (exon-exon connection indices)
        - edge_attr : Edge weights for the exon graph
        - y : Label for the cell (optional; set to numerical index if `celltypename` is not provided)
        - x_fea : Original feature vector for the cell
        - x_adj : Raw adjacency matrix for the cell
        - sample_name : The ID of the cell

    Raises
    ------
    ValueError
        If ``gnn_run_num`` is smaller than 1.
    """
    # Fail early with a clear message; 0 used to raise ZeroDivisionError mid-run
    # and negative values silently produced no chunks.
    if gnn_run_num < 1:
        raise ValueError(f"gnn_run_num must be a positive integer, got {gnn_run_num!r}")

    df_label = pd.read_csv(metadata_path, sep='\t')
    total_sample_size = len(df_label)

    mapper = None
    # NOTE(review): if `celltypename` is given but is not a metadata column, no
    # mapper is built and `celltypename` is still forwarded to get_graph_input —
    # confirm that get_graph_input handles this case.
    if celltypename and celltypename in df_label.columns:
        unique_celltypes = sorted(df_label[celltypename].dropna().unique())
        mapper = {celltype: idx for idx, celltype in enumerate(unique_celltypes)}

    final_out_dir = os.path.join(out_directory, "data")
    os.makedirs(final_out_dir, exist_ok=True)
    temp_out_dir = os.path.join(final_out_dir, "temp")
    os.makedirs(temp_out_dir, exist_ok=True)

    # One entry per chunk: the first cell index of each chunk.
    chunk_starts = range(0, total_sample_size, gnn_run_num)

    print("Start Construct Data Input for model input")
    with tqdm(total=total_sample_size) as pbar_gnn:
        for start in chunk_starts:
            # Presumably builds graphs for cells [start, start + gnn_run_num), writes
            # the corresponding temp chunk file and returns the (advanced) progress
            # bar — TODO confirm against get_graph_input.
            pbar_gnn = get_graph_input(
                pbar_gnn, start, gnn_run_num, temp_out_dir, final_out_dir,
                out_name, celltypename=celltypename, mapper=mapper,
            )

    # Combine all per-chunk geometric .pt files. Start from an empty list so that
    # an empty metadata file yields an empty output instead of UnboundLocalError.
    combined = []
    for chunk_idx in range(len(chunk_starts)):
        chunk_path = os.path.join(temp_out_dir, f"model_{out_name}_{chunk_idx}.pt")
        combined.extend(_load_trusted_chunk(chunk_path))
    torch.save(combined, os.path.join(final_out_dir, f"model_{out_name}.pt"))