Skip to content

Commit

Permalink
use gene merger for annotation-based runs too
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewprzh committed May 5, 2024
1 parent d0d3bb5 commit 573532a
Showing 1 changed file with 45 additions and 36 deletions.
81 changes: 45 additions & 36 deletions src/graph_based_model_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,6 @@ def select_reference_gene(self, transcript_introns, transcript_range, transcript
if transcript_strand == '.' or self.gene_info.gene_strands[gene_id] == transcript_strand:
return gene_id

overlap_dict = {}
gene_regions = self.gene_info.get_gene_regions()
for gene_id in gene_regions.keys():
gene_coverage = read_coverage_fraction([transcript_range], [gene_regions[gene_id]])
if gene_coverage > 0.0 and \
(transcript_strand == '.' or self.gene_info.gene_strands[gene_id] == transcript_strand):
overlap_dict[gene_id] = gene_coverage

if overlap_dict:
return get_top_count(overlap_dict)
return None

def process(self, read_assignment_storage):
Expand Down Expand Up @@ -152,9 +142,8 @@ def process(self, read_assignment_storage):
if any(value < 0 for value in self.read_assignment_counts.values()):
logger.warning("Negative values in read assignment counts")

if not self.gene_info.all_isoforms_exons:
transcript_joiner = TranscriptToGeneJoiner(self.transcript_model_storage)
self.transcript_model_storage = transcript_joiner.join_transcripts()
transcript_joiner = TranscriptToGeneJoiner(self.transcript_model_storage, self.gene_info)
self.transcript_model_storage = transcript_joiner.join_transcripts()

if self.params.sqanti_output:
self.compare_models_with_known()
Expand Down Expand Up @@ -620,9 +609,7 @@ def generate_monoexon_from_clustered(self, clustered_reads, forward=True):
strand = '+' if forward else '-'
coordinates = (five_prime_pos, three_prime_pos) if forward else (three_prime_pos, five_prime_pos)
new_transcript_id = self.transcript_prefix + str(self.get_transcript_id())
transcript_gene = self.select_reference_gene([], coordinates, strand)
if not transcript_gene:
transcript_gene = "novel_gene_" + self.gene_info.chr_id + "_" + str(self.get_transcript_id())
transcript_gene = "novel_gene_" + self.gene_info.chr_id + "_" + str(self.get_transcript_id())
transcript_type = TranscriptModelType.novel_not_in_catalog
id_suffix = self.nnic_transcript_suffix

Expand Down Expand Up @@ -921,49 +908,66 @@ def thread_starts(self, intron, start, trusted=False):


class TranscriptToGeneJoiner:
def __init__(self, transcipt_model_storage):
def __init__(self, transcipt_model_storage, gene_info):
self.gene_info = gene_info
self.transcipt_model_storage = transcipt_model_storage
self.gene_introns = {}
self.gene_introns = defaultdict(set)
self.gene_strands = {}
self.gene_exon_regions = {}
self.gene_to_transcripts = {}
self.gene_regions = {}
self.gene_to_transcripts = defaultdict(set)

for gene_id in self.gene_info.gene_strands.keys():
self.gene_strands[gene_id] = self.gene_info.gene_strands[gene_id]
self.gene_regions[gene_id] = self.gene_info.get_gene_regions()[gene_id]
for transcript_id in self.gene_info.gene_id_map.keys():
gene_id = self.gene_info.gene_id_map[transcript_id]
self.gene_introns[gene_id].update(self.gene_info.all_isoforms_introns[transcript_id])
self.gene_to_transcripts[gene_id].add(transcript_id)

for t in self.transcipt_model_storage:
self.gene_exon_regions[t.gene_id] = t.exon_blocks
self.gene_introns[t.gene_id] = set(junctions_from_blocks(t.exon_blocks))
self.gene_strands[t.gene_id] = t.strand
self.gene_to_transcripts[t.gene_id] = {t.transcript_id}
if t.transcript_type == TranscriptModelType.known:
continue

if t.gene_id not in self.gene_regions:
self.gene_regions[t.gene_id] = (t.exon_blocks[0][0], t.exon_blocks[-1][1])
self.gene_strands[t.gene_id] = t.strand
else:
self.gene_regions[t.gene_id] = (min(self.gene_regions[t.gene_id][0], t.exon_blocks[0][0]),
max(self.gene_regions[t.gene_id][1], t.exon_blocks[-1][1]))
assert self.gene_strands[t.gene_id] == t.strand
self.gene_introns[t.gene_id].update(junctions_from_blocks(t.exon_blocks))
self.gene_to_transcripts[t.gene_id].add(t.transcript_id)
self.scores = {}

def count_score(self, gene1, gene2):
logger.debug("Counting score %s %s" % (gene2, gene1))
if self.gene_strands[gene1] != self.gene_strands[gene2]:
return 0.0
intronic_overlap = len(self.gene_introns[gene1].intersection(self.gene_introns[gene2])) / \
max(1, len(self.gene_introns[gene1].union(self.gene_introns[gene2])))
exonic_ranges1 = self.gene_exon_regions[gene1]
exonic_ranges2 = self.gene_exon_regions[gene2]
position_overlap = jaccard_similarity(exonic_ranges1, exonic_ranges2)
intronic_overlap = (len(self.gene_introns[gene1].intersection(self.gene_introns[gene2])) /
max(1, len(self.gene_introns[gene1].union(self.gene_introns[gene2]))))
gene_range1 = self.gene_regions[gene1]
gene_range2 = self.gene_regions[gene2]
position_overlap = jaccard_similarity([gene_range1], [gene_range2])
return position_overlap + intronic_overlap

def count_scores(self):
for g1_id in self.gene_to_transcripts.keys():
for g2_id in self.gene_to_transcripts.keys():
if g1_id == g2_id:
if g1_id == g2_id or (g1_id in self.gene_info.gene_strands and g2_id in self.gene_info.gene_strands):
continue
gene_pair = tuple(sorted([g1_id, g2_id]))
if gene_pair not in self.scores:
self.scores[gene_pair] = self.count_score(g1_id, g2_id)

def merge_genes(self, gene1, gene2):
logger.debug("Merging %s into %s" % (gene2, gene1))
exonic_ranges1 = self.gene_exon_regions[gene1]
exonic_ranges2 = self.gene_exon_regions[gene2]
self.gene_exon_regions[gene1] = merge_ranges(exonic_ranges1, exonic_ranges2)
self.gene_regions[gene1] = (min(self.gene_regions[gene1][0], self.gene_regions[gene2][0]),
max(self.gene_regions[gene1][1], self.gene_regions[gene2][1]))
self.gene_introns[gene1].update(self.gene_introns[gene2])
self.gene_to_transcripts[gene1].update(self.gene_to_transcripts[gene2])
if self.gene_strands[gene2] != self.gene_strands[gene1]:
logger.error("Merging genes with different strands: %s, %s" % (gene1, gene2))
del self.gene_exon_regions[gene2]
del self.gene_regions[gene2]
del self.gene_introns[gene2]
del self.gene_to_transcripts[gene2]
del self.gene_strands[gene2]
Expand All @@ -981,11 +985,16 @@ def merge_genes(self, gene1, gene2):

def join_transcripts(self):
self.count_scores()
while len(self.gene_to_transcripts) > 1:
while len(self.scores) > 1:
best_gene_pair = max(self.scores, key=self.scores.get)
if self.scores[best_gene_pair] < 0.1:
break
self.merge_genes(*best_gene_pair)
if best_gene_pair[0] in self.gene_info.gene_strands:
assert best_gene_pair[1] not in self.gene_info.gene_strands
self.merge_genes(best_gene_pair[0], best_gene_pair[1])
else:
assert best_gene_pair[0] not in self.gene_info.gene_strands
self.merge_genes(best_gene_pair[1], best_gene_pair[0])

transcript_to_new_gene_id = {}
for gene_id, t_list in self.gene_to_transcripts.items():
Expand Down

0 comments on commit 573532a

Please sign in to comment.