Skip to content

Commit

Permalink
Add blast display to contig_coverage maps, as part of #442.
Browse files Browse the repository at this point in the history
Add handling for small arrows.
  • Loading branch information
donkirkby committed Oct 29, 2019
1 parent 6bcfe89 commit 6aa5f5d
Show file tree
Hide file tree
Showing 5 changed files with 333 additions and 72 deletions.
13 changes: 10 additions & 3 deletions micall/core/denovo.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ def genotype(fasta, db=DEFAULT_DATABASE, blast_csv=None, group_refs=None):
fraction of the query that aligned against the reference (matches and
mismatches).
"""
contig_nums = {} # {contig_name: contig_num}
with open(fasta) as f:
for line in f:
if line.startswith('>'):
contig_name = line[1:-1]
contig_nums[contig_name] = len(contig_nums) + 1
blast_columns = ['qaccver',
'saccver',
'pident',
Expand All @@ -115,7 +121,7 @@ def genotype(fasta, db=DEFAULT_DATABASE, blast_csv=None, group_refs=None):
blast_writer = None
else:
blast_writer = DictWriter(blast_csv,
['contig_name',
['contig_num',
'ref_name',
'score',
'match',
Expand Down Expand Up @@ -152,9 +158,10 @@ def genotype(fasta, db=DEFAULT_DATABASE, blast_csv=None, group_refs=None):
if int(match['send']) < int(match['sstart']):
matched_fraction *= -1
pident = round(float(match['pident']))
samples[match['qaccver']] = (match['saccver'], matched_fraction)
contig_name = match['qaccver']
samples[contig_name] = (match['saccver'], matched_fraction)
if blast_writer:
blast_writer.writerow(dict(contig_name=match['qaccver'],
blast_writer.writerow(dict(contig_num=contig_nums[contig_name],
ref_name=match['saccver'],
score=match['score'],
match=matched_fraction,
Expand Down
176 changes: 141 additions & 35 deletions micall/core/plot_contigs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from io import StringIO
from itertools import groupby
from math import log10, copysign
from operator import itemgetter
from operator import itemgetter, attrgetter
from pathlib import Path

import yaml
Expand Down Expand Up @@ -80,31 +80,44 @@ def __init__(self, start, end, h=20, label=None):
super().__init__(x=x, y=0, w=w, h=h)
self.label = label

def __repr__(self):
if self.direction >= 0:
start = self.x
end = self.x + self.w
else:
end = self.x
start = self.x + self.w
return f'Arrow({start}, {end}, label={self.label!r})'

def draw(self, x=0, y=0, xscale=1.0):
h = self.h
a = self.x * xscale
b = (self.x + self.w) * xscale
x = x * xscale
r = h/2
font_size = h * 0.75
arrow_size = 7 * xscale
font_size = h * 0.55
arrow_size = 7
if self.direction >= 0:
line_start = a
arrow_end = b
arrow_start = max(arrow_end-arrow_size, line_start)
else:
line_start = b
arrow_end = a
arrow_start = arrow_end - arrow_size*self.direction
centre = (a + b - self.direction*arrow_size)/2
centre_start = centre - self.direction*r
centre_end = centre + self.direction*r
arrow_start = min(arrow_end+arrow_size, line_start)
centre = (a + b)/2
arrow_y = h/2
if abs(centre - arrow_start) < r:
arrow_y -= r
group = draw.Group(transform="translate({} {})".format(x, y))
group.append(draw.Line(line_start, arrow_y,
arrow_start, arrow_y,
stroke='black'))
group.append(draw.Circle(centre, h/2, r, fill='ivory', stroke='black'))
group.append(draw.Line(line_start, h/2, centre_start, h/2, stroke='black'))
group.append(draw.Line(centre_end, h/2, arrow_start, h/2, stroke='black'))
group.append(draw.Lines(arrow_end, h/2,
arrow_start, (h + arrow_size)/2,
arrow_start, (h - arrow_size)/2,
group.append(draw.Lines(arrow_end, arrow_y,
arrow_start, arrow_y + arrow_size/2,
arrow_start, arrow_y - arrow_size/2,
arrow_end, arrow_y,
fill='black'))
group.append(draw.Text(self.label,
font_size,
Expand All @@ -117,55 +130,59 @@ def draw(self, x=0, y=0, xscale=1.0):
class ArrowGroup(Element):
def __init__(self, arrows: typing.Sequence[Arrow], gap=3):
self.arrows = []
coordinates = [] # [(start, end, index)]
self.y_coordinates = []
x_coordinates = [] # [(start, end, index)]
x1 = x2 = None
for i, arrow in enumerate(arrows):
self.arrows.append([0, arrow])
self.arrows.append(arrow)
self.y_coordinates.append(0)
arrow_x2 = arrow.x + arrow.w
coordinates.append((arrow.x, arrow_x2, i))
x_coordinates.append((arrow.x, arrow_x2, i))
if x1 is None:
x1, x2 = arrow.x, arrow_x2
else:
x1 = min(x1, arrow.x)
x2 = max(x2, arrow_x2)
h = 0
while coordinates:
while x_coordinates:
if h > 0:
h += gap
row_end = coordinates[0][0]-1
row_end = x_coordinates[0][0]-1
row_height = 0
for key in coordinates[:]:
for key in x_coordinates[:]:
x1, x2, i = key
if x1 < row_end:
continue
coordinates.remove(key)
row = self.arrows[i]
arrow_h = row[1].h
row[0] = h + arrow_h
x_coordinates.remove(key)
arrow = self.arrows[i]
arrow_h = arrow.h
self.y_coordinates[i] = h + arrow_h
row_height = max(row_height, arrow_h)
row_end = x2
h += row_height
for row in self.arrows:
row[0] -= h
self.y_coordinates = [y-h for y in self.y_coordinates]
super().__init__(x1, 0, w=x2-x1, h=h)

def draw(self, x=0, y=0, xscale=1.0):
group = draw.Group(transform="translate({} {})".format(x, y))
for i, (child_y, arrow) in enumerate(self.arrows):
for i, (child_y, arrow) in enumerate(zip(self.y_coordinates,
self.arrows)):
group.append(arrow.draw(y=child_y, xscale=xscale))
return group


def plot_genome_coverage(genome_coverage_csv, genome_coverage_svg_path):
f = build_coverage_figure(genome_coverage_csv)
def plot_genome_coverage(genome_coverage_csv,
blast_csv,
genome_coverage_svg_path):
f = build_coverage_figure(genome_coverage_csv, blast_csv)
f.show(w=970).saveSvg(genome_coverage_svg_path)


def build_coverage_figure(genome_coverage_csv):
def build_coverage_figure(genome_coverage_csv, blast_csv=None):
min_position, max_position = 1, 500
coordinate_depths = Counter()
contig_depths = Counter()
contig_groups = defaultdict(set)
contig_groups = defaultdict(set) # {coordinates_name: {contig_name}}
reader = DictReader(genome_coverage_csv)
for row in reader:
query_nuc_pos = int(row['query_nuc_pos'])
Expand All @@ -191,6 +208,14 @@ def build_coverage_figure(genome_coverage_csv):
position_offset = -min_position + 1
max_position += position_offset

blast_rows = []
if blast_csv is not None:
for blast_row in DictReader(blast_csv):
for field_name in ('start', 'end', 'ref_start', 'ref_end'):
blast_row[field_name] = int(blast_row[field_name])
blast_rows.append(blast_row)
blast_rows.sort(key=itemgetter('start', 'ref_start'))

landmarks_path = (Path(__file__).parent.parent / "data" /
"landmark_references.yaml")
landmark_groups = yaml.safe_load(landmarks_path.read_text())
Expand Down Expand Up @@ -221,18 +246,66 @@ def build_coverage_figure(genome_coverage_csv):
break
else:
add_partial_banner(f, position_offset, max_position)
for _, contig_name in sorted((-contig_depths[name], name)
for name in contig_groups[coordinates_name]):
sorted_contig_names = [
contig_name
for _, contig_name in sorted(
(-contig_depths[name], name)
for name in contig_groups[coordinates_name])]
ref_arrows = []
for contig_name in sorted_contig_names:
contig_num, contig_ref = contig_name.split('-', 1)
arrow_count = 0
for blast_row in blast_rows:
if blast_row['contig_num'] != contig_num:
continue
if blast_row['ref_name'] != contig_ref:
continue
arrow_count += 1
ref_start = int(blast_row['ref_start'])
ref_end = int(blast_row['ref_end'])
ref_arrows.append(Arrow(ref_start,
ref_end,
label=f'{contig_num}.{arrow_count}'))
if ref_arrows:
f.add(ArrowGroup(ref_arrows))
for contig_name in sorted_contig_names:
genome_coverage_csv.seek(0)
reader = DictReader(genome_coverage_csv)
build_contig(reader, f, contig_name, max_position, position_offset)
build_contig(reader,
f,
contig_name,
max_position,
position_offset,
blast_rows)

if not f.elements:
f.add(Track(1, max_position, label='No contigs found.', color='none'))
return f


def build_contig(reader, f, contig_name, max_position, position_offset):
def build_contig(reader,
f,
contig_name,
max_position,
position_offset,
blast_rows):
contig_num, contig_ref = contig_name.split('-', 1)
blast_ranges = [] # [[start, end, blast_num]]
blast_starts = {} # {start: blast_num}
blast_ends = {} # {end: blast_num}
for blast_row in blast_rows:
if blast_row['contig_num'] != contig_num:
continue
if blast_row['ref_name'] != contig_ref:
continue
blast_num = len(blast_ranges) + 1
blast_ranges.append([None, None, blast_num])
blast_starts[blast_row['start']] = blast_num
blast_ends[blast_row['end']] = blast_num
event_positions = set(blast_starts)
event_positions.update(blast_ends)
event_positions = sorted(event_positions, reverse=True)

insertion_size = 0
insertion_ranges = [] # [(start, end)]
for contig_name2, contig_rows in groupby(reader, itemgetter('contig')):
Expand Down Expand Up @@ -263,6 +336,23 @@ def build_contig(reader, f, contig_name, max_position, position_offset):
if contig_row['coverage'] is not None:
coverage[pos - start] = (contig_row['coverage'] -
contig_row['dels'])
contig_pos = int(contig_row['query_nuc_pos'])
while event_positions and event_positions[-1] <= contig_pos:
event_pos = event_positions.pop()
blast_num = blast_starts.get(event_pos)
if blast_num is not None:
blast_ranges[blast_num-1][0] = pos
blast_num = blast_ends.get(event_pos)
if blast_num is not None:
blast_ranges[blast_num-1][1] = pos

arrows = []
for arrow_start, arrow_end, blast_num in blast_ranges:
arrows.append(Arrow(arrow_start,
arrow_end,
label=f'{contig_num}.{blast_num}'))
if arrows:
f.add(ArrowGroup(arrows))
subtracks = []
for has_coverage, group_positions in groupby(
enumerate(coverage),
Expand Down Expand Up @@ -330,7 +420,11 @@ def summarize_figure(figure: Figure):

summary = StringIO()
for padding, track in figure.elements:
spans = getattr(track, 'tracks', [track])
spans = getattr(track, 'arrows', None)
if spans is None:
spans = getattr(track, 'tracks', [track])
else:
spans.sort(key=attrgetter('x', 'w', 'label'))
for i, span in enumerate(spans):
if i:
summary.write(', ')
Expand All @@ -349,6 +443,13 @@ def summarize_figure(figure: Figure):
if count > 1:
summary.write(f'x{count}')
continue
direction = getattr(span, 'direction', None)
if direction is not None and direction != '':
if direction >= 0:
summary.write(f'{span.x}--{span.label}->{span.x+span.w}')
else:
summary.write(f'{span.x}<-{span.label}--{span.x+span.w}')
continue
span_text = getattr(span.label, 'text', span.label) or ''
summary.write(span_text)
color = getattr(span, 'color')
Expand All @@ -371,11 +472,16 @@ def main():
parser.add_argument('genome_coverage_csv',
help='CSV file with coverage counts for each contig',
type=FileType())
parser.add_argument('blast_csv',
help='CSV file with BLAST results for each contig',
type=FileType())
parser.add_argument('genome_coverage_svg',
help='SVG file to plot coverage counts for each contig')
args = parser.parse_args()

plot_genome_coverage(args.genome_coverage_csv, args.genome_coverage_svg)
plot_genome_coverage(args.genome_coverage_csv,
args.blast_csv,
args.genome_coverage_svg)
print('Wrote', args.genome_coverage_svg)


Expand Down
7 changes: 5 additions & 2 deletions micall/drivers/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,11 @@ def process(self,
coverage_maps_prefix=self.name,
excluded_projects=excluded_projects)

with open(self.genome_coverage_csv) as genome_coverage_csv:
plot_genome_coverage(genome_coverage_csv, self.genome_coverage_svg)
with open(self.genome_coverage_csv) as genome_coverage_csv, \
open(self.blast_csv) as blast_csv:
plot_genome_coverage(genome_coverage_csv,
blast_csv,
self.genome_coverage_svg)

logger.info('Running cascade_report on %s.', self)
with open(self.g2p_summary_csv) as g2p_summary_csv, \
Expand Down
8 changes: 4 additions & 4 deletions micall/tests/test_denovo.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,10 @@ def test_write_blast(tmpdir, hcv_db):
"""
blast_csv = StringIO()
expected_blast_csv = """\
contig_name,ref_name,score,match,pident,start,end,ref_start,ref_end
bar,HCV-1g,37,0.67,100,19,55,8506,8542
bar,HCV-1a,41,0.75,100,15,55,8518,8558
foo,HCV-1a,41,1.0,100,1,41,8187,8227
contig_num,ref_name,score,match,pident,start,end,ref_start,ref_end
2,HCV-1g,37,0.67,100,19,55,8506,8542
2,HCV-1a,41,0.75,100,15,55,8518,8558
1,HCV-1a,41,1.0,100,1,41,8187,8227
"""

write_contig_refs(str(contigs_fasta), contigs_csv, blast_csv=blast_csv)
Expand Down
Loading

0 comments on commit 6aa5f5d

Please sign in to comment.