Skip to content

Commit

Permalink
Improved argument handling in text trees.
Browse files Browse the repository at this point in the history
Closes #55 by raising a ValueError for any unsupported arguments.
  • Loading branch information
jeromekelleher committed Jun 25, 2019
1 parent efdd7ef commit 33c6ec1
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 36 deletions.
31 changes: 18 additions & 13 deletions python/tests/test_drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,21 @@ def test_no_node_labels(self):
for u in t.nodes():
self.assertEqual(text.find(str(u)), -1)

def test_unused_args(self):
t = self.get_binary_tree()
self.assertRaises(ValueError, t.draw, format=self.drawing_format, width=300)
self.assertRaises(ValueError, t.draw, format=self.drawing_format, height=300)
self.assertRaises(
ValueError, t.draw, format=self.drawing_format, mutation_labels={})
self.assertRaises(
ValueError, t.draw, format=self.drawing_format, mutation_colours={})
self.assertRaises(
ValueError, t.draw, format=self.drawing_format, edge_colours={})
self.assertRaises(
ValueError, t.draw, format=self.drawing_format, node_colours={})
self.assertRaises(
ValueError, t.draw, format=self.drawing_format, max_tree_height=1234)


class TestDrawUnicode(TestDrawText):
"""
Expand Down Expand Up @@ -691,19 +706,6 @@ def test_simple_tree_sequence(self):
ts = tskit.load_text(nodes, edges, strict=False)
self.verify_text_rendering(ts.draw_text(), ts_drawing)

def test_tree_height_scale(self):
tree = msprime.simulate(4, random_seed=2).first()
with self.assertRaises(ValueError):
tree.draw_text(tree_height_scale="time")

t1 = tree.draw_text(tree_height_scale="rank")
t2 = tree.draw_text()
self.assertEqual(t1, t2)

for bad_scale in [0, "", "NOT A SCALE"]:
with self.assertRaises(ValueError):
tree.draw_text(tree_height_scale=bad_scale)

def test_max_tree_height(self):
nodes = io.StringIO("""\
id is_sample population individual time metadata
Expand Down Expand Up @@ -763,6 +765,9 @@ def test_max_tree_height(self):
"0 1 2 3\n")
t = ts.first()
self.verify_text_rendering(t.draw_text(max_tree_height="tree"), tree)
for bad_max_tree_height in [1, "sdfr", ""]:
with self.assertRaises(ValueError):
t.draw_text(max_tree_height=bad_max_tree_height)


class TestDrawSvg(TestTreeDraw):
Expand Down
44 changes: 21 additions & 23 deletions python/tskit/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,6 @@ def draw_tree(
mutation_labels=None, mutation_colours=None, format=None, edge_colours=None,
tree_height_scale=None, max_tree_height=None):

# We can't draw trees with zero roots.
if tree.num_roots == 0:
raise ValueError("Cannot draw a tree with zero roots")

# See tree.draw() for documentation on these arguments.
fmt = check_format(format)
if fmt == "svg":
Expand All @@ -72,11 +68,25 @@ def draw_tree(
max_tree_height=max_tree_height)
return td.draw()
else:
# TODO catch the unused args here and throw a ValueError
if width is not None:
raise ValueError("Text trees do not support width")
if height is not None:
raise ValueError("Text trees do not support height")
if mutation_labels is not None:
raise ValueError("Text trees do not support mutation_labels")
if mutation_colours is not None:
raise ValueError("Text trees do not support mutation_colours")
if node_colours is not None:
raise ValueError("Text trees do not support node_colours")
if edge_colours is not None:
raise ValueError("Text trees do not support edge_colours")
if max_tree_height is not None:
raise ValueError("Text trees do not support max_tree_height")

use_ascii = fmt == "ascii"
text_tree = TextTree(
tree, node_labels=node_labels, tree_height_scale=tree_height_scale,
max_tree_height=max_tree_height, use_ascii=use_ascii)
tree, node_labels=node_labels, max_tree_height=max_tree_height,
use_ascii=use_ascii)
return str(text_tree)


Expand Down Expand Up @@ -586,8 +596,7 @@ def __init__(self, ts):

def draw(self):
trees = [
TextTree(tree, tree_height_scale="rank", max_tree_height="ts")
for tree in self.ts.trees()]
TextTree(tree, max_tree_height="ts") for tree in self.ts.trees()]
self.width = sum(tree.width + 2 for tree in trees) - 1
self.height = max(tree.height for tree in trees)
self.canvas = np.zeros((self.height, self.width), dtype=str)
Expand Down Expand Up @@ -655,10 +664,8 @@ class TextTree(object):
to a 2D array.
"""
def __init__(
self, tree, node_labels=None, tree_height_scale=None, max_tree_height=None,
use_ascii=False):
self, tree, node_labels=None, max_tree_height=None, use_ascii=False):
self.tree = tree
self.tree_height_scale = tree_height_scale
self.max_tree_height = max_tree_height
self.__set_charset(use_ascii)
self.num_leaves = len(list(tree.leaves()))
Expand All @@ -676,13 +683,6 @@ def __init__(
# Labels for nodes
self.node_labels = {}

# TODO Clear up the logic here. What do we actually support?
if tree_height_scale not in [None, "time", "rank"]:
raise ValueError("tree_height_scale must be one of 'time' or 'rank'")
numeric_max_tree_height = max_tree_height not in [None, "tree", "ts"]
if tree_height_scale == "rank" and numeric_max_tree_height:
raise ValueError("Cannot specify numeric max_tree_height with rank scale")

# Set the node labels and colours.
for u in tree.nodes():
if node_labels is None:
Expand All @@ -700,15 +700,13 @@ def __init__(
self.__draw()

def __assign_time_positions(self):
if self.tree_height_scale == "time":
raise ValueError("time scaling not currently supported in text trees")
assert self.tree_height_scale in [None, "rank"]
assert self.max_tree_height in [None, "tree", "ts"]
tree = self.tree
if self.max_tree_height in [None, "tree"]:
times = {tree.time(u) for u in tree.nodes()}
elif self.max_tree_height == "ts":
times = {node.time for node in tree.tree_sequence.nodes()}
else:
raise ValueError("max_tree_height must be 'tree' or 'ts'")
depth = {t: 2 * j for j, t in enumerate(sorted(times, reverse=True))}
for u in self.tree.nodes():
self.time_position[u] = depth[self.tree.time(u)]
Expand Down

0 comments on commit 33c6ec1

Please sign in to comment.