import pandas as pd
import os
from tqdm import tqdm
import shutil
import subprocess
def _lookup_read_count(read_counts, sample):
    """Return 'num_seqs' from the first row of *read_counts* whose 'sample' is *sample*.

    Raises IndexError when the sample has no row in the table.
    """
    return read_counts[read_counts["sample"] == sample].iloc[0]["num_seqs"]


def _format_fraction(fraction):
    """Format a subsampling fraction for ``samtools view -s``.

    samtools reads the ``-s`` value as SEED.FRACTION, so the value is written in
    fixed-point notation; ``str()`` would emit exponent notation (e.g. ``1e-06``)
    for small fractions.
    """
    return f"{float(fraction):.10f}"


def _run_to_file(command, out_path):
    """Run *command* (an argument list, no shell) and write its stdout to *out_path*."""
    with open(out_path, "wb") as out_f:
        subprocess.run(command, stdout=out_f, check=True)


def _write_frequent_junctions(
    neighbors, junction_file_path, junction_file_extension, N_neighbor, bed_path
):
    """Majority-vote the junctions of *neighbors* and write the kept ones to *bed_path*."""
    keys = ["chr", "first_base", "last_base"]
    df_merge = None
    for name in neighbors:
        # Columns 0-2 are the junction coordinates; column 7 only serves as a
        # per-cell presence marker (NaN after the outer merge means "absent").
        df_junc = pd.read_csv(
            os.path.join(junction_file_path, name + junction_file_extension),
            sep="\t",
            usecols=[0, 1, 2, 7],
            names=keys + ["multi_map" + name],
        )
        if df_merge is None:
            df_merge = df_junc
        else:
            df_merge = pd.merge(df_merge, df_junc, how="outer", on=keys)

    # Number of cells in which each junction was observed.
    # NOTE(review): this assumes the neighbor file lists exactly N_neighbor rows
    # per cell; with a different row count the tally is shifted -- confirm.
    df_merge["nont_na"] = N_neighbor - df_merge.drop(columns=keys).isna().sum(axis=1)
    df_keep_junct = df_merge[df_merge["nont_na"] >= int(N_neighbor / 2)]
    df_keep_junct[keys].to_csv(bed_path, sep="\t", index=False, header=False)


def _normalize_bam(source_bam, work_dir, name, n_seq, target_size):
    """Write ``<work_dir>/<name>.norm.bam`` scaled to *target_size* reads; return its path."""
    norm_bam = os.path.join(work_dir, name + ".norm.bam")

    if n_seq == target_size:
        shutil.copyfile(source_bam, norm_bam)

    elif n_seq < target_size:
        # Upsampling: merge floor(target/n) whole copies of the file plus a
        # random sample that covers the remaining fraction.
        whole_copies = int(target_size // n_seq)
        extra_fraction = (target_size - n_seq * whole_copies) / n_seq
        merge_inputs = [source_bam] * whole_copies
        sample_bam = None
        if extra_fraction > 0:
            # An exact multiple needs no extra reads, so sampling is skipped.
            sample_bam = os.path.join(work_dir, name + ".sample.bam")
            _run_to_file(
                ["samtools", "view", "-b", "-s", _format_fraction(extra_fraction), source_bam],
                sample_bam,
            )
            merge_inputs.append(sample_bam)
        # -f: this is a scratch file, so leftovers of an aborted run are replaced.
        subprocess.run(["samtools", "merge", "-f", norm_bam, *merge_inputs], check=True)
        if sample_bam is not None:
            os.remove(sample_bam)

    elif n_seq > target_size:
        # Downsampling: keep a random target/n fraction of the reads.
        _run_to_file(
            ["samtools", "view", "-b", "-s", _format_fraction(target_size / n_seq), source_bam],
            norm_bam,
        )

    return norm_bam


def _extract_junction_reads(norm_bam, junction_bam):
    """Write the header plus the reads whose CIGAR contains 'N' to *junction_bam*."""
    with open(junction_bam, "wb") as out_f:
        view = subprocess.Popen(
            ["samtools", "view", "-h", norm_bam], stdout=subprocess.PIPE
        )
        awk = subprocess.Popen(
            ["awk", "$0 ~ /^@/ || $6 ~ /N/"], stdin=view.stdout, stdout=subprocess.PIPE
        )
        # Drop the parent's copies of the pipe ends so an early exit downstream
        # propagates to the upstream process.
        view.stdout.close()
        to_bam = subprocess.Popen(
            ["samtools", "view", "-b", "-"], stdin=awk.stdout, stdout=out_f
        )
        awk.stdout.close()
        stages = [("samtools view -h", view), ("awk", awk), ("samtools view -b", to_bam)]
        exit_codes = [(label, proc.wait()) for label, proc in stages]
    for label, code in exit_codes:
        if code != 0:
            raise subprocess.CalledProcessError(code, label)


def _filter_junction_reads(norm_bam, work_dir, name, bed_path):
    """Keep only the junction reads of *norm_bam* that fall in *bed_path*; return the result path."""
    junction_bam = os.path.join(work_dir, name + ".junction.norm.bam")
    filtered_bam = os.path.join(work_dir, name + ".mj.junction.norm.bam")
    _extract_junction_reads(norm_bam, junction_bam)
    subprocess.run(["samtools", "index", junction_bam], check=True)
    # -b so the ".bam" file really is BAM rather than SAM text.
    _run_to_file(
        ["samtools", "view", "-b", "-h", "-L", bed_path, junction_bam], filtered_bam
    )
    return filtered_bam


def run_reads_aggregation(
    metadata_path: str,
    bam_file_path: str,
    bam_file_extension: str,
    junction_file_path: str,
    junction_file_extension: str,
    neighbor_file: str,
    read_count_path: str,
    N_neighbor: int = 10,
    out_directory: str = "./"
):
    """
    Aggregate single-cell BAM files by incorporating reads from neighbor cells,
    applying junction majority voting and read count normalization.

    This function performs the following operations:

    1. Reads the metadata, neighbor, and read count files.
    2. For each target cell:

       - Identifies frequent junctions by majority voting across its neighbors
         (a junction is kept when its tally reaches ``int(N_neighbor / 2)``).
       - Normalizes each neighbor BAM file to the target read count
         (via up/downsampling with samtools).
       - Extracts junction reads (CIGAR containing 'N') and keeps only those
         overlapping the frequent junctions.
       - Merges the filtered neighbor reads with the unfiltered reads of the
         target cell.
    3. Outputs a final merged BAM file per target cell.

    External commands are run without a shell and every non-zero exit status
    raises, so a failed samtools step cannot go unnoticed. ``samtools`` and
    ``awk`` must be available on ``PATH``.

    Parameters
    ----------
    metadata_path : str
        Path to the tab-separated metadata file. Cell names are read from its
        'CB' column.
    bam_file_path : str
        Directory containing the BAM files generated by STAR.
    bam_file_extension : str
        Suffix of BAM files after the sample name.
        Example: for "SRR18379095.std.Aligned.sortedByCoord.out.bam",
        use ".std.Aligned.sortedByCoord.out.bam".
    junction_file_path : str
        Directory containing junction files (STAR SJ.out.tab format).
    junction_file_extension : str
        Suffix of junction files after the sample name.
        Example: for "SRR18379095.std.SJ.out.tab", use ".std.SJ.out.tab".
    neighbor_file : str
        CSV file specifying neighbor relationships. Must include 'main_name'
        and 'neighbor' columns.
    read_count_path : str
        Path to a CSV file with two columns: 'sample' and 'num_seqs',
        representing read counts for each cell.
    N_neighbor : int, optional
        Number of neighbors per target cell. Default is 10.
    out_directory : str, optional
        Output directory to save results. Default is the current directory.

    Returns
    -------
    None
        For each cell in the metadata the final aggregated BAM is saved as
        ``<out_directory>/cell_aggregation/<cell_id>.aggr.final.bam``.

    Raises
    ------
    ValueError
        If a cell has no rows in the neighbor file.
    IndexError
        If a cell or neighbor is missing from the read count file.
    subprocess.CalledProcessError
        If a samtools/awk step exits with a non-zero status (for example when
        the final output file already exists, which ``samtools merge`` refuses
        to overwrite).
    """
    pd_aggr = pd.read_csv(neighbor_file)
    pd_single_size = pd.read_csv(read_count_path)
    sample_list = list(pd.read_csv(metadata_path, sep="\t")["CB"])

    final_out_dir = os.path.join(out_directory, "cell_aggregation")
    os.makedirs(final_out_dir, exist_ok=True)
    temp_out_dir = os.path.join(final_out_dir, "temp")
    os.makedirs(temp_out_dir, exist_ok=True)

    for target in tqdm(sample_list):
        target_size = _lookup_read_count(pd_single_size, target)
        neighbors = list(pd_aggr[pd_aggr["main_name"] == target]["neighbor"])
        if not neighbors:
            # Without this guard the junction table of the previous cell would
            # be reused silently (or be undefined for the first cell).
            raise ValueError(f"No neighbors listed for cell {target!r} in {neighbor_file}")

        work_dir = os.path.join(temp_out_dir, target)
        os.makedirs(work_dir, exist_ok=True)

        # Majority voting: find the frequent junctions among the neighbors.
        bed_path = os.path.join(work_dir, "keep_junction.bed")
        _write_frequent_junctions(
            neighbors, junction_file_path, junction_file_extension, N_neighbor, bed_path
        )

        # Each cell is normalized once. The target itself is always included,
        # because the final merge needs its unfiltered reads.
        cells = list(dict.fromkeys(neighbors))
        if target not in cells:
            cells.append(target)

        filtered_bams = []
        for name in cells:
            norm_bam = _normalize_bam(
                os.path.join(bam_file_path, name + bam_file_extension),
                work_dir,
                name,
                _lookup_read_count(pd_single_size, name),
                target_size,
            )
            # The target keeps all of its reads; neighbors contribute only the
            # junction reads that passed the vote.
            if name != target:
                filtered_bams.append(
                    _filter_junction_reads(norm_bam, work_dir, name, bed_path)
                )

        # Final merge: filtered neighbor reads plus the complete target BAM.
        # The inputs are listed explicitly (sorted for a deterministic order)
        # instead of relying on a shell glob.
        subprocess.run(
            [
                "samtools",
                "merge",
                os.path.join(final_out_dir, target + ".aggr.final.bam"),
                *sorted(filtered_bams),
                os.path.join(work_dir, target + ".norm.bam"),
            ],
            check=True,
        )
        shutil.rmtree(work_dir)