Skip to content

Commit

Permalink
Add ArrowGroup to genome coverage diagrams, as part of #442.
Browse files Browse the repository at this point in the history
Also stop merging contigs into the full-genome reference.
  • Loading branch information
donkirkby committed Oct 29, 2019
1 parent 77754f3 commit 6bcfe89
Show file tree
Hide file tree
Showing 5 changed files with 350 additions and 95 deletions.
82 changes: 68 additions & 14 deletions micall/core/plot_contigs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import typing
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, FileType
from collections import Counter, defaultdict
from csv import DictReader
from io import StringIO
from itertools import groupby
from math import log10
from math import log10, copysign
from operator import itemgetter
from pathlib import Path

Expand Down Expand Up @@ -69,36 +70,89 @@ def get_color(self, coverage):


class Arrow(Element):
def __init__(self, start, end, h=30, label=None):
super().__init__(x=start, y=0, w=end-start, h=h)
def __init__(self, start, end, h=20, label=None):
x = start
w = end-start
self.direction = copysign(1, w)
if w < 0:
x = end
w = -w
super().__init__(x=x, y=0, w=w, h=h)
self.label = label

def draw(self, x=0, y=0, xscale=1.0):
h = self.h
a = self.x * xscale
b = (self.x + self.w) * xscale
x = x * xscale
direction = 1
r = 10 * xscale
r = h/2
font_size = h * 0.75
arrow_size = 7 * xscale
arrow_end = b
arrow_start = arrow_end - arrow_size*direction
centre = (a + b - direction*arrow_size)/2
centre_start = centre - direction*r
centre_end = centre + direction*r
if self.direction >= 0:
line_start = a
arrow_end = b
else:
line_start = b
arrow_end = a
arrow_start = arrow_end - arrow_size*self.direction
centre = (a + b - self.direction*arrow_size)/2
centre_start = centre - self.direction*r
centre_end = centre + self.direction*r
group = draw.Group(transform="translate({} {})".format(x, y))
group.append(draw.Circle(centre, h/2, r, fill='ivory', stroke='black'))
group.append(draw.Line(a, h/2, centre_start, h/2, stroke='black'))
group.append(draw.Line(line_start, h/2, centre_start, h/2, stroke='black'))
group.append(draw.Line(centre_end, h/2, arrow_start, h/2, stroke='black'))
group.append(draw.Lines(arrow_end, h/2,
arrow_start, (h + arrow_size)/2,
arrow_start, (h - arrow_size)/2,
fill='black'))
group.append(draw.Text(self.label,
15,
centre, h/2,
font_size,
centre, h / 2,
text_anchor='middle',
dy="0.3em"))
dy="0.35em"))
return group


class ArrowGroup(Element):
def __init__(self, arrows: typing.Sequence[Arrow], gap=3):
self.arrows = []
coordinates = [] # [(start, end, index)]
x1 = x2 = None
for i, arrow in enumerate(arrows):
self.arrows.append([0, arrow])
arrow_x2 = arrow.x + arrow.w
coordinates.append((arrow.x, arrow_x2, i))
if x1 is None:
x1, x2 = arrow.x, arrow_x2
else:
x1 = min(x1, arrow.x)
x2 = max(x2, arrow_x2)
h = 0
while coordinates:
if h > 0:
h += gap
row_end = coordinates[0][0]-1
row_height = 0
for key in coordinates[:]:
x1, x2, i = key
if x1 < row_end:
continue
coordinates.remove(key)
row = self.arrows[i]
arrow_h = row[1].h
row[0] = h + arrow_h
row_height = max(row_height, arrow_h)
row_end = x2
h += row_height
for row in self.arrows:
row[0] -= h
super().__init__(x1, 0, w=x2-x1, h=h)

def draw(self, x=0, y=0, xscale=1.0):
group = draw.Group(transform="translate({} {})".format(x, y))
for i, (child_y, arrow) in enumerate(self.arrows):
group.append(arrow.draw(y=child_y, xscale=xscale))
return group


Expand Down
3 changes: 2 additions & 1 deletion micall/core/remap.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
PARTIAL_CONTIG_SUFFIX = 'partial'
REVERSED_CONTIG_SUFFIX = 'reversed'
EXCLUDED_CONTIG_SUFFIX = 'excluded'
ARE_CONTIGS_MERGED = False

# SAM file format
SAM_FIELDS = [
Expand Down Expand Up @@ -795,7 +796,7 @@ def read_contigs(contigs_csv, excluded_seeds=None):
match_fraction = float(row['match'])
is_match = 0.25 <= match_fraction
is_reversed = match_fraction < 0
if not is_match:
if not (ARE_CONTIGS_MERGED and is_match):
contig_name = get_contig_name(i,
row['ref'],
is_match,
Expand Down
181 changes: 161 additions & 20 deletions micall/tests/test_plot_contigs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from genetracks import Figure, Track, Multitrack, Label, Coverage

from micall.core.plot_contigs import summarize_figure, \
build_coverage_figure, SmoothCoverage, add_partial_banner, Arrow
build_coverage_figure, SmoothCoverage, add_partial_banner, Arrow, ArrowGroup

HCV_HEADER = ('C[342-915], E1[915-1491], E2[1491-2580], P7[2580-2769], '
'NS2[2769-3420], NS3[3420-5313], NS4A[5313-5475], '
Expand All @@ -24,28 +24,84 @@ class SvgDiffer:
def __init__(self):
self.work_dir: Path = Path(__file__).parent / 'svg_diffs'
self.work_dir.mkdir(exist_ok=True)
self.mismatch_found = False
for work_file in self.work_dir.iterdir():
if work_file.name == '.gitignore':
continue
assert work_file.suffix in ('.svg', '.png')
work_file.unlink()

def diff_pixel(self, actual_pixel, expected_pixel):
ar, ag, ab, aa = actual_pixel
er, eg, eb, ea = expected_pixel
if actual_pixel != expected_pixel:
self.mismatch_found = True
# Colour
dr = 0xff
dg = (ag + eg) // 5
db = (ab + eb) // 5

# Opacity
da = 0xff
else:
# Colour
dr, dg, db = ar, ag, ab

# Opacity
da = aa // 3
return dr, dg, db, da

def assert_equal(self,
svg_actual: Drawing,
svg_expected: Drawing,
name: str):
# Display image when in live turtle mode.
display_image = getattr(Turtle, 'display_image', None)
if display_image is not None:
encoded = standard_b64encode(svg_actual.rasterize().pngData)
display_image(0, 0, image=encoded.decode('UTF-8'))

png_actual = drawing_to_image(svg_actual)
png_expected = drawing_to_image(svg_expected)
png_diff = ImageChops.difference(png_actual, png_expected)
w = max(png_actual.width, png_expected.width)
h = max(png_actual.height, png_expected.height)

png_actual_padded = Image.new(png_actual.mode, (w, h))
png_expected_padded = Image.new(png_expected.mode, (w, h))
png_actual_padded.paste(png_actual)
png_expected_padded.paste(png_expected)
png_diff = Image.new(png_actual.mode, (w, h))
self.mismatch_found = False
png_diff.putdata([self.diff_pixel(actual_pixel, expected_pixel)
for actual_pixel, expected_pixel in zip(
png_actual_padded.getdata(),
png_expected_padded.getdata())])

extrema = png_diff.getextrema()
if extrema == ((0, 0), (0, 0), (0, 0), (0, 0)):
# Display image when in live turtle mode.
display_image = getattr(Turtle, 'display_image', None)
if display_image is not None:
t = Turtle()
try:
w = t.screen.cv.cget('width')
h = t.screen.cv.cget('height')
ox, oy = w/2, h/2
text_height = 20
t.penup()
t.goto(-ox, oy)
t.right(90)
t.forward(text_height)
t.write(f'Actual')
display_image(ox+t.xcor(), oy-t.ycor(),
image=encode_image(png_actual))
t.forward(png_actual.height)
t.forward(text_height)
t.write(f'Diff')
display_image(ox+t.xcor(), oy-t.ycor(),
image=encode_image(png_diff))
t.forward(png_diff.height)
t.forward(text_height)
t.write('Expected')
display_image(ox+t.xcor(), oy-t.ycor(),
image=encode_image(png_expected))
t.forward(png_expected.height)
except Exception as ex:
t.write(str(ex))

if not self.mismatch_found:
return
text_actual = svg_actual.asSvg()
(self.work_dir / (name+'_actual.svg')).write_text(text_actual)
Expand All @@ -63,6 +119,13 @@ def drawing_to_image(drawing: Drawing) -> Image:
return image


def encode_image(image: Image) -> bytes:
writer = BytesIO()
image.save(writer, format='PNG')
encoded = standard_b64encode(writer.getvalue())
return encoded.decode('UTF-8')


@pytest.fixture(scope='session')
def svg_differ():
return SvgDiffer()
Expand Down Expand Up @@ -453,23 +516,101 @@ def test_plot_genome_coverage_insertion():
assert expected_figure == summarize_figure(figure)


# noinspection DuplicatedCode
def test_arrow(svg_differ):
expected_svg = Drawing(200.0, 60.0, origin=(0, 0))
expected_svg.append(Circle(168/2, 40, 10, stroke='black', fill='ivory'))
expected_svg = Drawing(175.0, 35.0, origin=(0, 0))
expected_svg.append(Circle(168/2, 20, 10, stroke='black', fill='ivory'))
expected_svg.append(Text('X',
15,
168/2, 40,
168/2, 20,
text_anchor='middle',
dy="0.3em"))
expected_svg.append(Line(0, 40, 74, 40, stroke='black'))
expected_svg.append(Line(94, 40, 168, 40, stroke='black'))
expected_svg.append(Lines(175, 40,
168, 43.5,
168, 36.5,
175, 40,
expected_svg.append(Line(0, 20, 74, 20, stroke='black'))
expected_svg.append(Line(94, 20, 168, 20, stroke='black'))
expected_svg.append(Lines(175, 20,
168, 23.5,
168, 16.5,
175, 20,
fill='black'))
f = Figure()
f.add(Arrow(0, 175, label='X'))
f.add(Arrow(0, 175, h=20, label='X'))
svg = f.show()

svg_differ.assert_equal(svg, expected_svg, 'test_arrow')


# noinspection DuplicatedCode
def test_reverse_arrow(svg_differ):
expected_svg = Drawing(175.0, 35.0, origin=(0, 0))
expected_svg.append(Circle(7+168/2, 20, 10, stroke='black', fill='ivory'))
expected_svg.append(Text('X',
15,
7+168/2, 20,
text_anchor='middle',
dy="0.3em"))
expected_svg.append(Line(7, 20, 81, 20, stroke='black'))
expected_svg.append(Line(101, 20, 175, 20, stroke='black'))
expected_svg.append(Lines(0, 20,
7, 23.5,
7, 16.5,
0, 20,
fill='black'))
f = Figure()
f.add(Arrow(175, 0, h=20, label='X'))
svg = f.show()

svg_differ.assert_equal(svg, expected_svg, 'test_arrow')


# noinspection DuplicatedCode
def test_arrow_group(svg_differ):
expected_figure = Figure()
expected_figure.add(Track(1, 500, label='Header'))
h = 30
expected_figure.add(Arrow(1, 200, label='X', h=h), gap=-h)
expected_figure.add(Arrow(300, 500, label='Y', h=h))
expected_svg = expected_figure.show()

f = Figure()
f.add(Track(1, 500, label='Header'))
f.add(ArrowGroup([Arrow(1, 200, label='X', h=h),
Arrow(300, 500, label='Y', h=h)]))
svg = f.show()

svg_differ.assert_equal(svg, expected_svg, 'test_arrow_group')


# noinspection DuplicatedCode
def test_arrow_group_overlap(svg_differ):
expected_figure = Figure()
expected_figure.add(Track(1, 500, label='Header'))
h = 20
expected_figure.add(Arrow(1, 300, label='X', h=h), gap=3)
expected_figure.add(Arrow(1, 300, label='Y', h=h))
expected_svg = expected_figure.show()

f = Figure()
f.add(Track(1, 500, label='Header'))
f.add(ArrowGroup([Arrow(1, 300, label='X', h=h),
Arrow(1, 300, label='Y', h=h)]))
svg = f.show()

svg_differ.assert_equal(svg, expected_svg, 'test_arrow_group')


# noinspection DuplicatedCode
def test_arrow_group_reverse_overlap(svg_differ):
expected_figure = Figure()
expected_figure.add(Track(1, 500, label='Header'))
h = 20
expected_figure.add(Arrow(1, 300, label='X', h=h), gap=3)
expected_figure.add(Arrow(400, 250, label='Y', h=h))
expected_svg = expected_figure.show()

f = Figure()
f.add(Track(1, 500, label='Header'))
f.add(ArrowGroup([Arrow(1, 300, label='X', h=h),
Arrow(400, 250, label='Y', h=h)]))
svg = f.show()

svg_differ.assert_equal(svg, expected_svg, 'test_arrow_group')
Loading

0 comments on commit 6bcfe89

Please sign in to comment.