Skip to content

Commit

Permalink
Update TEDD about InferBound failure. 1. TEDD doesn't call inferbound…
Browse files Browse the repository at this point in the history
… for DFG. 2. Update tutorial about the InferBound failure.
  • Loading branch information
yongfeng-nv committed Feb 10, 2020
1 parent 0daac78 commit bb8752d
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 26 deletions.
40 changes: 21 additions & 19 deletions python/tvm/contrib/tedd.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def __init__(self, sch):
for rel_idx, rel in enumerate(stage.relations):
self.dict[rel] = [stage_idx, rel_idx]
for tensor_idx in range(stage.op.num_outputs):
self.dict[frozenset({stage.op.name, tensor_idx})] = [stage_idx, tensor_idx]
self.dict[frozenset({stage.op.name,
tensor_idx})] = [stage_idx, tensor_idx]

def get_dom_path(self, obj):
if obj is None:
Expand Down Expand Up @@ -187,8 +188,8 @@ def legend_dot(g):
subgraph.node('legend', label, shape='none', margin='0')


def extract_dom_for_viz(sch):
json_str = dump_json(sch)
def extract_dom_for_viz(sch, need_range=True):
json_str = dump_json(sch, need_range)
s = json.loads(json_str)
s = insert_dot_id(s)
return s
Expand All @@ -214,7 +215,7 @@ def dump_graph(dot_string,
return None


def dump_json(sch):
def dump_json(sch, need_range):
"""Serialize data for visualization from a schedule in JSON format.
Parameters
Expand All @@ -229,7 +230,8 @@ def dump_json(sch):
"""
def encode_itervar(itervar, stage, index, range_map):
"""Extract and encode IterVar visualization data to a dictionary"""
ivrange = range_map[itervar] if range_map is not None and itervar in range_map else None
ivrange = range_map[
itervar] if range_map is not None and itervar in range_map else None
bind_thread = None
tensor_intrin = None
if itervar in stage.iter_var_attrs:
Expand Down Expand Up @@ -273,8 +275,7 @@ def get_leaf_itervar_index(itervar, leaf_iv):
for itervar in stage.all_iter_vars:
leaf_index = get_leaf_itervar_index(itervar, stage.leaf_iter_vars)
itervars.append(
encode_itervar(itervar, stage, leaf_index,
range_map))
encode_itervar(itervar, stage, leaf_index, range_map))
return itervars

def encode_itervar_relation(obj_manager, rel):
Expand Down Expand Up @@ -329,8 +330,7 @@ def encode_tensors(obj_manager, stage):
tensors = []
for i in range(stage.op.num_outputs):
tensor = stage.op.output(i)
tensors.append(
encode_tensor(obj_manager, tensor, stage))
tensors.append(encode_tensor(obj_manager, tensor, stage))
tensors.sort(key=lambda tensor: tensor["value_index"])
return tensors

Expand Down Expand Up @@ -362,7 +362,7 @@ def encode_stage(obj_manager, stage, range_map):
}
return stage_dict

def encode_schedule(sch):
def encode_schedule(sch, need_range):
"""Extract and encode data from a schedule for visualization to a nested dictionary.
It is useful for JSON to serialize schedule.
Expand All @@ -378,13 +378,15 @@ def encode_schedule(sch):
"""
assert isinstance(sch, tvm.schedule.Schedule
), 'Input is not a tvm.schedule.Schedule object.'
try:
range_map = tvm.schedule.InferBound(sch)
except tvm._ffi.base.TVMError as expt:
warnings.warn(
'Ranges are not available, because InferBound fails with the following error:\n'
+ str(expt))
range_map = None
range_map = None
if need_range:
try:
range_map = tvm.schedule.InferBound(sch)
except tvm._ffi.base.TVMError as expt:
warnings.warn(
'Ranges are not available, because InferBound fails with the following error:\n'
+ str(expt))

obj_manager = ObjectManager(sch)
stages = []
for stage in sch.stages:
Expand All @@ -394,7 +396,7 @@ def encode_schedule(sch):
"stages": stages,
}

return json.dumps(sch, default=encode_schedule)
return json.dumps(sch, default=lambda s: encode_schedule(s, need_range))


def viz_schedule_tree(sch,
Expand Down Expand Up @@ -724,7 +726,7 @@ def dfg_dot(g, sch):
g.edge(src, dst)

graph = create_dataflow_graph("Dataflow Graph")
s = extract_dom_for_viz(sch)
s = extract_dom_for_viz(sch, need_range=False)
for stage in s['stages']:
stage_node_dot(graph, stage)
for tensor in stage["output_tensors"]:
Expand Down
27 changes: 20 additions & 7 deletions tutorials/language/tedd.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@
t_bias = topi.add(t_conv, B)
t_relu = topi.nn.relu(t_bias)
s = topi.generic.schedule_conv2d_hwcn([t_relu])
s = s.normalize()


######################################################################
# Render Graphs with TEDD
# -------------------
Expand All @@ -78,11 +77,7 @@
# to render SVG figures showing in notebook directly.

tedd.viz_dataflow_graph(s, dot_file_path = '/tmp/dfg.dot')
tedd.viz_schedule_tree(s, dot_file_path = '/tmp/scheduletree.dot')
tedd.viz_itervar_relationship_graph(s, dot_file_path = '/tmp/itervar.dot')
#tedd.viz_dataflow_graph(s, show_svg = True)
#tedd.viz_schedule_tree(s, show_svg = True)
#tedd.viz_itervar_relationship_graph(s, show_svg = True)

######################################################################
# .. image:: https://github.com/dmlc/web-data/raw/master/tvm/tutorial/tedd_dfg.png
Expand All @@ -92,11 +87,25 @@
# scope shown in the middle and inputs/outputs information on the sides.
# Edges show nodes' dependency.

tedd.viz_schedule_tree(s, dot_file_path = '/tmp/scheduletree.dot')
#tedd.viz_schedule_tree(s, show_svg = True)
######################################################################
# We just rendered the schedule tree graph. You may notice an warning about ranges not
# available.
# The message also suggests to call normalize() to infer range information. We will
# skip inspecting the first schedule tree and encourage you to compare the graphs before
# and and after normalize() for its impact.

s = s.normalize()
tedd.viz_schedule_tree(s, dot_file_path = '/tmp/scheduletree2.dot')
#tedd.viz_schedule_tree(s, show_svg = True)

######################################################################
# .. image:: https://github.com/dmlc/web-data/raw/master/tvm/tutorial/tedd_st.png
# :align: center
# :scale: 100%
# The second one is a schedule tree. Every block under ROOT represents a
# Now, let us take a close look at the second schedule tree. Every block under ROOT
# represents a
# stage. Stage name shows in the top row and compute shows in the bottom row.
# The middle rows are for IterVars, the higher the outer, the lower the inner.
# An IterVar row contains its index, name, type, and other optional information.
Expand All @@ -119,6 +128,9 @@
# a schedule tree. The edges among IterVars and compute within one stage are
# omitted, making every stage a block, for better readability.

tedd.viz_itervar_relationship_graph(s, dot_file_path = '/tmp/itervar.dot')
#tedd.viz_itervar_relationship_graph(s, show_svg = True)

######################################################################
# .. image:: https://github.com/dmlc/web-data/raw/master/tvm/tutorial/tedd_itervar_rel.png
# :align: center
Expand All @@ -131,6 +143,7 @@
# IterVars don't drive any transformation node and have non-negative indices,
# such as ax0.ax1.fused.ax2.fused.ax3.fused.outer with index of 0.


######################################################################
# Summary
# -------
Expand Down

0 comments on commit bb8752d

Please sign in to comment.