Merge pull request #232 from andersen-lab/226-freyja-covariants-fixes
226 freyja covariants fixes
dylanpilz authored Jun 7, 2024
2 parents 7b4c78b + 48f32de commit f82a9bd
Showing 3 changed files with 76 additions and 32 deletions.
6 changes: 4 additions & 2 deletions freyja/_cli.py
@@ -845,9 +845,11 @@ def filter(query_mutations, input_bam, min_site, max_site, output):
'"count" or "freq" to sort patterns by count or frequency '
'(in descending order). Set to "site" to sort patterns by '
'start site (in ascending order).'), show_default=True)
@click.option('--threads', default=1, help='number of parallel processes to '
'use. Recommended for large BAM files.', show_default=True)
def covariants(input_bam, min_site, max_site, output,
ref_genome, annot, min_quality, min_count, spans_region,
sort_by):
sort_by, threads):
"""
Finds mutations co-occurring on the same read pair
in BAM_FILE between MIN_SITE and MAX_SITE
@@ -860,7 +862,7 @@ def covariants(input_bam, min_site, max_site, output,
from freyja.read_analysis_tools import covariants as _covariants
_covariants(input_bam, min_site, max_site, output,
ref_genome, annot, min_quality, min_count, spans_region,
sort_by)
sort_by, threads)


@cli.command()
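The new --threads option is forwarded unchanged from the CLI layer to the covariants implementation. As a minimal, self-contained sketch of the same pattern (a hypothetical stand-alone demo command, not freyja's actual CLI), the option simply arrives as an ordinary keyword argument:

# Minimal stand-alone sketch of threading a --threads option through click.
# covariants_demo is a made-up command used only for illustration.
import click

@click.command()
@click.argument('input_bam')
@click.option('--threads', default=1, show_default=True,
              help='number of parallel processes to use')
def covariants_demo(input_bam, threads):
    # click passes the option value in as `threads`, just as the real command
    # forwards it on to freyja.read_analysis_tools.covariants.
    click.echo(f'would scan {input_bam} with {threads} process(es)')

if __name__ == '__main__':
    covariants_demo()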
95 changes: 65 additions & 30 deletions freyja/read_analysis_tools.py
@@ -5,13 +5,15 @@
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.patches import Patch, Rectangle
import concurrent.futures as cf
from concurrent.futures import ProcessPoolExecutor
import pysam
from Bio.Seq import MutableSeq
from Bio import SeqIO

from freyja.read_analysis_utils import nt_position, get_colnames_and_sites, \
read_pair_generator, \
filter_covariants_output, parse_gff
read_pair_generator, \
filter_covariants_output, parse_gff, get_gene


def extract(query_mutations, input_bam, output, same_read):
@@ -84,15 +86,15 @@ def extract(query_mutations, input_bam, output, same_read):
if x is None:
break

if x.cigarstring is None:
continue

ref_pos = set(x.get_reference_positions())
start = x.reference_start
sites_in = list(ref_pos & set(snp_sites))

seq = x.query_alignment_sequence

if x.cigarstring is None:
# checks for a possible fail case
continue
cigar = re.findall(r'(\d+)([A-Z]{1})', x.cigarstring)

# Find insertions
@@ -315,15 +317,54 @@ def filter(query_mutations, input_bam, min_site, max_site, output):
return final_reads


def run_parallel(cores, regions, input_bam, ref_fasta, gff_file, min_quality,
min_count, spans_region):
with ProcessPoolExecutor(max_workers=cores) as p:
futures = {p.submit(process_covariants, input_bam,
region[0], region[1], ref_fasta, gff_file,
min_quality, min_count, spans_region)
for region in regions}
for future in cf.as_completed(futures):
yield future.result()


def covariants(input_bam, min_site, max_site, output,
ref_fasta, gff_file, min_quality, min_count, spans_region,
sort_by):
sort_by, cores):
# Split regions into smaller chunks for parallel processing
regions = [(min_site, max_site)]
if cores > 1:
region_size = (max_site - min_site) // cores
regions = [(min_site + i * region_size, min_site +
(i + 1) * region_size) for i in range(cores)]
regions[-1] = (regions[-1][0], max_site)

# Run parallel processing
results = run_parallel(cores, regions, input_bam, ref_fasta,
gff_file, min_quality, min_count, spans_region)

# Aggregate results
df = pd.concat(results)

df = df.sort_values('Max_count', ascending=False).drop_duplicates(
subset='Covariants', keep='first')

# Sort patterns
if sort_by.lower() == 'count':
df = df.sort_values('Count', ascending=False)
elif sort_by.lower() == 'freq':
df = df.sort_values('Freq', ascending=False)
elif sort_by.lower() == 'site':
df['sort_col'] = [nt_position(s.split(' ')[0]) for s in df.Covariants]
df = df.sort_values('sort_col').drop(labels='sort_col', axis=1)

df.to_csv(output, sep='\t', index=False)
print(f'covariants: Output saved to {output}')
return df


def get_gene(locus):
for gene in gene_positions:
start, end = gene_positions[gene]
if locus in range(start, end+1):
return gene, start
def process_covariants(input_bam, min_site, max_site, ref_fasta, gff_file,
min_quality, min_count, spans_region):

# Load reference genome
ref_genome = MutableSeq(next(SeqIO.parse(ref_fasta, 'fasta')).seq)
@@ -393,16 +434,16 @@ def get_gene(locus):

snps_found = []

# checks for a possible fail case, mapq = 0
if x.cigarstring is None:
continue

# Update coverage ranges
if coverage_start is None or start < coverage_start:
coverage_start = start
if coverage_end is None or end > coverage_end:
coverage_end = end

if x.cigarstring is None:
# checks for a possible fail case
continue

cigar = re.findall(r'(\d+)([A-Z]{1})', x.cigarstring)
if 'I' in x.cigarstring:
i = 0
@@ -470,16 +511,21 @@ def get_gene(locus):
last_del_site = start+i

# Find SNPs

softclip_offset = 0
if cigar[0][1] == 'S':
softclip_offset = int(cigar[0][0])
softclip_offset += int(cigar[0][0])
if (len(cigar) > 1 and cigar[0][1] == 'H' and cigar[1][1] == 'S'):
softclip_offset += int(cigar[1][0])

pairs = x.get_aligned_pairs(matches_only=True)

for tup in pairs:
read_site, ref_site = tup

read_site -= softclip_offset
ref_base = ref_genome[ref_site]

if seq[read_site] != 'N' and ref_base != seq[read_site]:
snps_found.append(
f'{ref_base.upper()}{ref_site+1}{seq[read_site]}'
@@ -493,7 +539,7 @@ def get_gene(locus):
if gff_file is not None:
for ins in insertions_found:
locus = ins[0]
gene_info = get_gene(locus)
gene_info = get_gene(locus, gene_positions)
ins_string = str(ins).replace(' ', '')
if gene_info is None or ins_string in nt_to_aa:
continue
@@ -512,7 +558,7 @@ def get_gene(locus):

for deletion in deletions_found:
locus = deletion[0]
gene_info = get_gene(locus)
gene_info = get_gene(locus, gene_positions)
deletion_string = str(deletion).replace(' ', '')
if gene_info is None or deletion_string in nt_to_aa:
continue
@@ -532,7 +578,7 @@ def get_gene(locus):

for snp in snps_found:
locus = int(snp[1:-1])
gene_info = get_gene(locus)
gene_info = get_gene(locus, gene_positions)
if gene_info is None or snp in nt_to_aa:
continue

@@ -672,17 +718,6 @@ def get_gene(locus):

df = df[df['Count'] >= min_count]

# Sort patterns
if sort_by.lower() == 'count':
df = df.sort_values('Count', ascending=False)
elif sort_by.lower() == 'freq':
df = df.sort_values('Freq', ascending=False)
elif sort_by.lower() == 'site':
df['sort_col'] = [nt_position(s.split(' ')[0]) for s in df.Covariants]
df = df.sort_values('sort_col').drop(labels='sort_col', axis=1)

df.to_csv(output, sep='\t', index=False)
print(f'covariants: Output saved to {output}')
return df


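The parallel path added above splits [min_site, max_site] into one contiguous chunk per process, runs process_covariants on each chunk, concatenates the per-chunk tables, and keeps the best-supported row whenever the same covariant pattern is found in more than one chunk. A stripped-down sketch of that flow, with a made-up fake_worker standing in for process_covariants (only the chunking and concat/deduplication steps mirror the real code):

# Stripped-down sketch of the splitting/aggregation logic added in this PR.
# fake_worker is a placeholder for process_covariants.
from concurrent.futures import ProcessPoolExecutor, as_completed

import pandas as pd


def split_regions(min_site, max_site, cores):
    # Divide [min_site, max_site] into `cores` contiguous chunks; the last
    # chunk absorbs any remainder so the full range is always covered.
    if cores <= 1:
        return [(min_site, max_site)]
    region_size = (max_site - min_site) // cores
    regions = [(min_site + i * region_size, min_site + (i + 1) * region_size)
               for i in range(cores)]
    regions[-1] = (regions[-1][0], max_site)
    return regions


def fake_worker(start, end):
    # Placeholder worker: returns one row per region.
    return pd.DataFrame({'Covariants': [f'({start}-{end})'],
                         'Count': [end - start], 'Max_count': [end - start]})


if __name__ == '__main__':
    regions = split_regions(1, 29903, 4)
    print(regions)  # [(1, 7476), (7476, 14951), (14951, 22426), (22426, 29903)]
    with ProcessPoolExecutor(max_workers=4) as pool:
        futures = [pool.submit(fake_worker, s, e) for s, e in regions]
        df = pd.concat(f.result() for f in as_completed(futures))
    # Keep the best-supported row when the same pattern shows up in two chunks.
    df = (df.sort_values('Max_count', ascending=False)
            .drop_duplicates(subset='Covariants', keep='first'))
    print(df)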
7 changes: 7 additions & 0 deletions freyja/read_analysis_utils.py
@@ -141,3 +141,10 @@ def translate_snps(snps, ref, gene_positions):
output[snp] = aa_mut

return output


def get_gene(locus, gene_positions):
for gene in gene_positions:
start, end = gene_positions[gene]
if locus in range(start, end+1):
return gene, start
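
get_gene now receives the gene_positions mapping explicitly instead of reading a module-level variable, so each worker process can use the annotation parsed from its own copy of the GFF file. A small usage sketch; the dictionary below is illustrative, not the annotation freyja actually parses:

# Illustrative call of the new get_gene signature.
from freyja.read_analysis_utils import get_gene

gene_positions = {'ORF1ab': (266, 21555), 'S': (21563, 25384)}  # example values

print(get_gene(23403, gene_positions))  # ('S', 21563)
print(get_gene(29903, gene_positions))  # None: locus falls outside every gene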
