diff --git a/micall/core/denovo.py b/micall/core/denovo.py index 5b23c6388..eddfc13b8 100644 --- a/micall/core/denovo.py +++ b/micall/core/denovo.py @@ -89,6 +89,12 @@ def genotype(fasta, db=DEFAULT_DATABASE, blast_csv=None, group_refs=None): fraction of the query that aligned against the reference (matches and mismatches). """ + contig_nums = {} # {contig_name: contig_num} + with open(fasta) as f: + for line in f: + if line.startswith('>'): + contig_name = line[1:-1] + contig_nums[contig_name] = len(contig_nums) + 1 blast_columns = ['qaccver', 'saccver', 'pident', @@ -115,7 +121,7 @@ def genotype(fasta, db=DEFAULT_DATABASE, blast_csv=None, group_refs=None): blast_writer = None else: blast_writer = DictWriter(blast_csv, - ['contig_name', + ['contig_num', 'ref_name', 'score', 'match', @@ -152,9 +158,10 @@ def genotype(fasta, db=DEFAULT_DATABASE, blast_csv=None, group_refs=None): if int(match['send']) < int(match['sstart']): matched_fraction *= -1 pident = round(float(match['pident'])) - samples[match['qaccver']] = (match['saccver'], matched_fraction) + contig_name = match['qaccver'] + samples[contig_name] = (match['saccver'], matched_fraction) if blast_writer: - blast_writer.writerow(dict(contig_name=match['qaccver'], + blast_writer.writerow(dict(contig_num=contig_nums[contig_name], ref_name=match['saccver'], score=match['score'], match=matched_fraction, diff --git a/micall/core/plot_contigs.py b/micall/core/plot_contigs.py index ded5deed6..5f0565b62 100644 --- a/micall/core/plot_contigs.py +++ b/micall/core/plot_contigs.py @@ -5,7 +5,7 @@ from io import StringIO from itertools import groupby from math import log10, copysign -from operator import itemgetter +from operator import itemgetter, attrgetter from pathlib import Path import yaml @@ -80,31 +80,44 @@ def __init__(self, start, end, h=20, label=None): super().__init__(x=x, y=0, w=w, h=h) self.label = label + def __repr__(self): + if self.direction >= 0: + start = self.x + end = self.x + self.w + else: + end = self.x + start = self.x + self.w + return f'Arrow({start}, {end}, label={self.label!r})' + def draw(self, x=0, y=0, xscale=1.0): h = self.h a = self.x * xscale b = (self.x + self.w) * xscale x = x * xscale r = h/2 - font_size = h * 0.75 - arrow_size = 7 * xscale + font_size = h * 0.55 + arrow_size = 7 if self.direction >= 0: line_start = a arrow_end = b + arrow_start = max(arrow_end-arrow_size, line_start) else: line_start = b arrow_end = a - arrow_start = arrow_end - arrow_size*self.direction - centre = (a + b - self.direction*arrow_size)/2 - centre_start = centre - self.direction*r - centre_end = centre + self.direction*r + arrow_start = min(arrow_end+arrow_size, line_start) + centre = (a + b)/2 + arrow_y = h/2 + if abs(centre - arrow_start) < r: + arrow_y -= r group = draw.Group(transform="translate({} {})".format(x, y)) + group.append(draw.Line(line_start, arrow_y, + arrow_start, arrow_y, + stroke='black')) group.append(draw.Circle(centre, h/2, r, fill='ivory', stroke='black')) - group.append(draw.Line(line_start, h/2, centre_start, h/2, stroke='black')) - group.append(draw.Line(centre_end, h/2, arrow_start, h/2, stroke='black')) - group.append(draw.Lines(arrow_end, h/2, - arrow_start, (h + arrow_size)/2, - arrow_start, (h - arrow_size)/2, + group.append(draw.Lines(arrow_end, arrow_y, + arrow_start, arrow_y + arrow_size/2, + arrow_start, arrow_y - arrow_size/2, + arrow_end, arrow_y, fill='black')) group.append(draw.Text(self.label, font_size, @@ -117,55 +130,59 @@ def draw(self, x=0, y=0, xscale=1.0): class ArrowGroup(Element): def __init__(self, arrows: typing.Sequence[Arrow], gap=3): self.arrows = [] - coordinates = [] # [(start, end, index)] + self.y_coordinates = [] + x_coordinates = [] # [(start, end, index)] x1 = x2 = None for i, arrow in enumerate(arrows): - self.arrows.append([0, arrow]) + self.arrows.append(arrow) + self.y_coordinates.append(0) arrow_x2 = arrow.x + arrow.w - coordinates.append((arrow.x, arrow_x2, i)) + x_coordinates.append((arrow.x, arrow_x2, i)) if x1 is None: x1, x2 = arrow.x, arrow_x2 else: x1 = min(x1, arrow.x) x2 = max(x2, arrow_x2) h = 0 - while coordinates: + while x_coordinates: if h > 0: h += gap - row_end = coordinates[0][0]-1 + row_end = x_coordinates[0][0]-1 row_height = 0 - for key in coordinates[:]: + for key in x_coordinates[:]: x1, x2, i = key if x1 < row_end: continue - coordinates.remove(key) - row = self.arrows[i] - arrow_h = row[1].h - row[0] = h + arrow_h + x_coordinates.remove(key) + arrow = self.arrows[i] + arrow_h = arrow.h + self.y_coordinates[i] = h + arrow_h row_height = max(row_height, arrow_h) row_end = x2 h += row_height - for row in self.arrows: - row[0] -= h + self.y_coordinates = [y-h for y in self.y_coordinates] super().__init__(x1, 0, w=x2-x1, h=h) def draw(self, x=0, y=0, xscale=1.0): group = draw.Group(transform="translate({} {})".format(x, y)) - for i, (child_y, arrow) in enumerate(self.arrows): + for i, (child_y, arrow) in enumerate(zip(self.y_coordinates, + self.arrows)): group.append(arrow.draw(y=child_y, xscale=xscale)) return group -def plot_genome_coverage(genome_coverage_csv, genome_coverage_svg_path): - f = build_coverage_figure(genome_coverage_csv) +def plot_genome_coverage(genome_coverage_csv, + blast_csv, + genome_coverage_svg_path): + f = build_coverage_figure(genome_coverage_csv, blast_csv) f.show(w=970).saveSvg(genome_coverage_svg_path) -def build_coverage_figure(genome_coverage_csv): +def build_coverage_figure(genome_coverage_csv, blast_csv=None): min_position, max_position = 1, 500 coordinate_depths = Counter() contig_depths = Counter() - contig_groups = defaultdict(set) + contig_groups = defaultdict(set) # {coordinates_name: {contig_name}} reader = DictReader(genome_coverage_csv) for row in reader: query_nuc_pos = int(row['query_nuc_pos']) @@ -191,6 +208,14 @@ def build_coverage_figure(genome_coverage_csv): position_offset = -min_position + 1 max_position += position_offset + blast_rows = [] + if blast_csv is not None: + for blast_row in DictReader(blast_csv): + for field_name in ('start', 'end', 'ref_start', 'ref_end'): + blast_row[field_name] = int(blast_row[field_name]) + blast_rows.append(blast_row) + blast_rows.sort(key=itemgetter('start', 'ref_start')) + landmarks_path = (Path(__file__).parent.parent / "data" / "landmark_references.yaml") landmark_groups = yaml.safe_load(landmarks_path.read_text()) @@ -221,18 +246,66 @@ def build_coverage_figure(genome_coverage_csv): break else: add_partial_banner(f, position_offset, max_position) - for _, contig_name in sorted((-contig_depths[name], name) - for name in contig_groups[coordinates_name]): + sorted_contig_names = [ + contig_name + for _, contig_name in sorted( + (-contig_depths[name], name) + for name in contig_groups[coordinates_name])] + ref_arrows = [] + for contig_name in sorted_contig_names: + contig_num, contig_ref = contig_name.split('-', 1) + arrow_count = 0 + for blast_row in blast_rows: + if blast_row['contig_num'] != contig_num: + continue + if blast_row['ref_name'] != contig_ref: + continue + arrow_count += 1 + ref_start = int(blast_row['ref_start']) + ref_end = int(blast_row['ref_end']) + ref_arrows.append(Arrow(ref_start, + ref_end, + label=f'{contig_num}.{arrow_count}')) + if ref_arrows: + f.add(ArrowGroup(ref_arrows)) + for contig_name in sorted_contig_names: genome_coverage_csv.seek(0) reader = DictReader(genome_coverage_csv) - build_contig(reader, f, contig_name, max_position, position_offset) + build_contig(reader, + f, + contig_name, + max_position, + position_offset, + blast_rows) if not f.elements: f.add(Track(1, max_position, label='No contigs found.', color='none')) return f -def build_contig(reader, f, contig_name, max_position, position_offset): +def build_contig(reader, + f, + contig_name, + max_position, + position_offset, + blast_rows): + contig_num, contig_ref = contig_name.split('-', 1) + blast_ranges = [] # [[start, end, blast_num]] + blast_starts = {} # {start: blast_num} + blast_ends = {} # {end: blast_num} + for blast_row in blast_rows: + if blast_row['contig_num'] != contig_num: + continue + if blast_row['ref_name'] != contig_ref: + continue + blast_num = len(blast_ranges) + 1 + blast_ranges.append([None, None, blast_num]) + blast_starts[blast_row['start']] = blast_num + blast_ends[blast_row['end']] = blast_num + event_positions = set(blast_starts) + event_positions.update(blast_ends) + event_positions = sorted(event_positions, reverse=True) + insertion_size = 0 insertion_ranges = [] # [(start, end)] for contig_name2, contig_rows in groupby(reader, itemgetter('contig')): @@ -263,6 +336,23 @@ def build_contig(reader, f, contig_name, max_position, position_offset): if contig_row['coverage'] is not None: coverage[pos - start] = (contig_row['coverage'] - contig_row['dels']) + contig_pos = int(contig_row['query_nuc_pos']) + while event_positions and event_positions[-1] <= contig_pos: + event_pos = event_positions.pop() + blast_num = blast_starts.get(event_pos) + if blast_num is not None: + blast_ranges[blast_num-1][0] = pos + blast_num = blast_ends.get(event_pos) + if blast_num is not None: + blast_ranges[blast_num-1][1] = pos + + arrows = [] + for arrow_start, arrow_end, blast_num in blast_ranges: + arrows.append(Arrow(arrow_start, + arrow_end, + label=f'{contig_num}.{blast_num}')) + if arrows: + f.add(ArrowGroup(arrows)) subtracks = [] for has_coverage, group_positions in groupby( enumerate(coverage), @@ -330,7 +420,11 @@ def summarize_figure(figure: Figure): summary = StringIO() for padding, track in figure.elements: - spans = getattr(track, 'tracks', [track]) + spans = getattr(track, 'arrows', None) + if spans is None: + spans = getattr(track, 'tracks', [track]) + else: + spans.sort(key=attrgetter('x', 'w', 'label')) for i, span in enumerate(spans): if i: summary.write(', ') @@ -349,6 +443,13 @@ def summarize_figure(figure: Figure): if count > 1: summary.write(f'x{count}') continue + direction = getattr(span, 'direction', None) + if direction is not None and direction != '': + if direction >= 0: + summary.write(f'{span.x}--{span.label}->{span.x+span.w}') + else: + summary.write(f'{span.x}<-{span.label}--{span.x+span.w}') + continue span_text = getattr(span.label, 'text', span.label) or '' summary.write(span_text) color = getattr(span, 'color') @@ -371,11 +472,16 @@ def main(): parser.add_argument('genome_coverage_csv', help='CSV file with coverage counts for each contig', type=FileType()) + parser.add_argument('blast_csv', + help='CSV file with BLAST results for each contig', + type=FileType()) parser.add_argument('genome_coverage_svg', help='SVG file to plot coverage counts for each contig') args = parser.parse_args() - plot_genome_coverage(args.genome_coverage_csv, args.genome_coverage_svg) + plot_genome_coverage(args.genome_coverage_csv, + args.blast_csv, + args.genome_coverage_svg) print('Wrote', args.genome_coverage_svg) diff --git a/micall/drivers/sample.py b/micall/drivers/sample.py index 738e0bdca..b502b5152 100644 --- a/micall/drivers/sample.py +++ b/micall/drivers/sample.py @@ -221,8 +221,11 @@ def process(self, coverage_maps_prefix=self.name, excluded_projects=excluded_projects) - with open(self.genome_coverage_csv) as genome_coverage_csv: - plot_genome_coverage(genome_coverage_csv, self.genome_coverage_svg) + with open(self.genome_coverage_csv) as genome_coverage_csv, \ + open(self.blast_csv) as blast_csv: + plot_genome_coverage(genome_coverage_csv, + blast_csv, + self.genome_coverage_svg) logger.info('Running cascade_report on %s.', self) with open(self.g2p_summary_csv) as g2p_summary_csv, \ diff --git a/micall/tests/test_denovo.py b/micall/tests/test_denovo.py index b53b9cedd..dfba2715e 100644 --- a/micall/tests/test_denovo.py +++ b/micall/tests/test_denovo.py @@ -136,10 +136,10 @@ def test_write_blast(tmpdir, hcv_db): """ blast_csv = StringIO() expected_blast_csv = """\ -contig_name,ref_name,score,match,pident,start,end,ref_start,ref_end -bar,HCV-1g,37,0.67,100,19,55,8506,8542 -bar,HCV-1a,41,0.75,100,15,55,8518,8558 -foo,HCV-1a,41,1.0,100,1,41,8187,8227 +contig_num,ref_name,score,match,pident,start,end,ref_start,ref_end +2,HCV-1g,37,0.67,100,19,55,8506,8542 +2,HCV-1a,41,0.75,100,15,55,8518,8558 +1,HCV-1a,41,1.0,100,1,41,8187,8227 """ write_contig_refs(str(contigs_fasta), contigs_csv, blast_csv=blast_csv) diff --git a/micall/tests/test_plot_contigs.py b/micall/tests/test_plot_contigs.py index 71f8d976c..34a8acaab 100644 --- a/micall/tests/test_plot_contigs.py +++ b/micall/tests/test_plot_contigs.py @@ -4,8 +4,8 @@ from turtle import Turtle import pytest -from PIL import Image, ImageChops -from drawSvg import Drawing, Line, Lines, Circle, Text +from PIL import Image +from drawSvg import Drawing, Line, Lines, Circle, Text, Rectangle from genetracks import Figure, Track, Multitrack, Label, Coverage from micall.core.plot_contigs import summarize_figure, \ @@ -142,7 +142,7 @@ def test_summarize_labels(): summary = summarize_figure(figure) - assert expected_summary == summary + assert summary == expected_summary def test_summarize_label_objects(): @@ -156,7 +156,7 @@ def test_summarize_label_objects(): summary = summarize_figure(figure) - assert expected_summary == summary + assert summary == expected_summary def test_summarize_multitracks(): @@ -171,7 +171,7 @@ def test_summarize_multitracks(): summary = summarize_figure(figure) - assert expected_summary == summary + assert summary == expected_summary def test_summarize_multitracks_with_separate_label(): @@ -187,7 +187,7 @@ def test_summarize_multitracks_with_separate_label(): summary = summarize_figure(figure) - assert expected_summary == summary + assert summary == expected_summary def test_summarize_regions(): @@ -200,7 +200,7 @@ def test_summarize_regions(): summary = summarize_figure(figure) - assert expected_summary == summary + assert summary == expected_summary def test_summarize_coverage(): @@ -214,7 +214,7 @@ def test_summarize_coverage(): summary = summarize_figure(figure) - assert expected_summary == summary + assert summary == expected_summary def test_summarize_zero_coverage(): @@ -237,7 +237,7 @@ def test_summarize_smooth_coverage(): summary = summarize_figure(figure) - assert expected_summary == summary + assert summary == expected_summary def test_summarize_smooth_coverage_ten_percent(): @@ -251,7 +251,37 @@ def test_summarize_smooth_coverage_ten_percent(): summary = summarize_figure(figure) - assert expected_summary == summary + assert summary == expected_summary + + +def test_summarize_arrow(): + figure = Figure() + figure.add(Arrow(10, 50, label='Foo')) + figure.add(Arrow(60, 30, label='Bar')) + expected_summary = """\ +10--Foo->50 +30<-Bar--60 +""" + + summary = summarize_figure(figure) + + assert summary == expected_summary + + +def test_summarize_arrow_group(): + figure = Figure() + figure.add(ArrowGroup([Arrow(10, 50, label='Foo'), + Arrow(60, 30, label='Bar')])) + figure.add(ArrowGroup([Arrow(1, 50, label='Baz'), + Arrow(90, 100, label='Boom')])) + expected_summary = """\ +10--Foo->50, 30<-Bar--60 +1--Baz->50, 90--Boom->100 +""" + + summary = summarize_figure(figure) + + assert summary == expected_summary def test_plot_genome_coverage(): @@ -516,24 +546,53 @@ def test_plot_genome_coverage_insertion(): assert expected_figure == summarize_figure(figure) +def test_plot_genome_coverage_blast(): + genome_coverage_csv = StringIO("""\ +contig,coordinates,query_nuc_pos,refseq_nuc_pos,ins,dels,coverage +1-HCV-1a,HCV-1a,1,8001,0,0,5 +1-HCV-1a,HCV-1a,2,8002,0,0,5 +1-HCV-1a,HCV-1a,3,8003,0,0,7 +1-HCV-1a,HCV-1a,4,8004,0,0,5 +1-HCV-1a,HCV-1a,5,8005,0,0,5 +1-HCV-1a,HCV-1a,6,8006,0,0,5 +""") + blast_csv = StringIO("""\ +contig_num,ref_name,score,match,pident,start,end,ref_start,ref_end +1,HCV-1g,30,0.33,90,1,2,5001,5002 +1,HCV-1a,40,0.33,100,5,6,7006,7005 +1,HCV-1a,50,0.5,100,1,3,8001,8003 +""") + expected_figure = """\ +5'[1-341], C[342-914], E1[915-1490], E2[1491-2579], p7[2580-2768], \ +NS2[2769-3419], NS3[3420-5312], NS4b[5475-6257], NS4a[5313-5474], \ +NS5a[6258-7601], NS5b[7602-9374], 3'[9375-9646] +7005<-1.2--7006, 8001--1.1->8003 +8001--1.1->8003, 8005--1.2->8006 +Coverage 5x2, 7, 5x3 +[8001-8006], 1-HCV-1a - depth 7(1-9646) +""" + + figure = build_coverage_figure(genome_coverage_csv, blast_csv) + + assert summarize_figure(figure) == expected_figure + + # noinspection DuplicatedCode def test_arrow(svg_differ): - expected_svg = Drawing(175.0, 35.0, origin=(0, 0)) - expected_svg.append(Circle(168/2, 20, 10, stroke='black', fill='ivory')) - expected_svg.append(Text('X', - 15, - 168/2, 20, + f, expected_svg = start_drawing(200, 55) + expected_svg.append(Line(0, 20, 168, 20, stroke='black')) + expected_svg.append(Circle(175/2, 20, 10, stroke='black', fill='ivory')) + expected_svg.append(Text('1.2', + 11, + 175/2, 20, text_anchor='middle', - dy="0.3em")) - expected_svg.append(Line(0, 20, 74, 20, stroke='black')) - expected_svg.append(Line(94, 20, 168, 20, stroke='black')) + dy="0.35em")) expected_svg.append(Lines(175, 20, 168, 23.5, 168, 16.5, 175, 20, fill='black')) - f = Figure() - f.add(Arrow(0, 175, h=20, label='X')) + f.add(Arrow(0, 175, h=20, label='1.2')) svg = f.show() svg_differ.assert_equal(svg, expected_svg, 'test_arrow') @@ -541,27 +600,113 @@ def test_arrow(svg_differ): # noinspection DuplicatedCode def test_reverse_arrow(svg_differ): - expected_svg = Drawing(175.0, 35.0, origin=(0, 0)) - expected_svg.append(Circle(7+168/2, 20, 10, stroke='black', fill='ivory')) + f, expected_svg = start_drawing(200, 55) + expected_svg.append(Line(7, 20, 175, 20, stroke='black')) + expected_svg.append(Circle(175/2, 20, 10, stroke='black', fill='ivory')) expected_svg.append(Text('X', - 15, - 7+168/2, 20, + 11, + 175/2, 20, text_anchor='middle', - dy="0.3em")) - expected_svg.append(Line(7, 20, 81, 20, stroke='black')) - expected_svg.append(Line(101, 20, 175, 20, stroke='black')) + dy="0.35em")) expected_svg.append(Lines(0, 20, 7, 23.5, 7, 16.5, 0, 20, fill='black')) - f = Figure() f.add(Arrow(175, 0, h=20, label='X')) svg = f.show() svg_differ.assert_equal(svg, expected_svg, 'test_arrow') +# noinspection DuplicatedCode +def test_scaled_arrow(svg_differ): + expected_svg = Drawing(100, 35, origin=(0, 0)) + expected_svg.append(Line(0, 20, 93, 20, stroke='black')) + expected_svg.append(Circle(50, 20, 10, stroke='black', fill='ivory')) + expected_svg.append(Text('2.3', + 11, + 50, 20, + text_anchor='middle', + dy="0.35em")) + expected_svg.append(Lines(100, 20, + 93, 23.5, + 93, 16.5, + 100, 20, + fill='black')) + + f = Figure() + f.add(Arrow(0, 200, h=20, label='2.3')) + svg = f.show(w=100) + + svg_differ.assert_equal(svg, expected_svg, 'test_arrow') + + +# noinspection DuplicatedCode +def test_small_arrow(svg_differ): + f, expected_svg = start_drawing(200, 55) + expected_svg.append(Line(100, 10, 125, 10, stroke='black')) + expected_svg.append(Circle(116, 20, 10, stroke='black', fill='ivory')) + expected_svg.append(Text('2.3', + 11, + 116, 20, + text_anchor='middle', + dy="0.35em")) + expected_svg.append(Lines(132, 10, + 125, 13.5, + 125, 6.5, + 132, 10, + fill='black')) + + f.add(Arrow(100, 132, h=20, label='2.3')) + svg = f.show() + + svg_differ.assert_equal(svg, expected_svg, 'test_arrow') + + +# noinspection DuplicatedCode +def test_tiny_arrow(svg_differ): + f, expected_svg = start_drawing(200.0, 55.0) + expected_svg.append(Circle(102, 20, 10, stroke='black', fill='ivory')) + expected_svg.append(Text('2.3', + 11, + 102, 20, + text_anchor='middle', + dy="0.35em")) + expected_svg.append(Lines(104, 10, + 100, 13.5, + 100, 6.5, + 104, 10, + fill='black')) + + f.add(Arrow(100, 104, h=20, label='2.3')) + svg = f.show() + + svg_differ.assert_equal(svg, expected_svg, 'test_arrow') + + +def start_drawing(width, height): + expected_svg = Drawing(width, height, origin=(0, 0)) + expected_svg.append(Rectangle(0, height-15, + 200, 10, + stroke='lightgrey', + fill='lightgrey')) + expected_svg.append(Text('Header', + 10, + width/2, height-15, + font_family='monospace', + text_anchor='middle')) + f = Figure() + f.add(Track(0, width, label='Header')) + return f, expected_svg + + +def test_arrow_repr(svg_differ): + arrow = Arrow(175, 0, label='X') + + assert repr(arrow) == "Arrow(175, 0, label='X')" + + # noinspection DuplicatedCode def test_arrow_group(svg_differ): expected_figure = Figure()