From 8dd4a6b2e74ef6f6dffb9e80782a2f8c91badbfc Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Mon, 3 Feb 2020 10:54:16 -0800 Subject: [PATCH] fix --- python/tvm/autotvm/graph_tuner/base_graph_tuner.py | 7 ------- python/tvm/autotvm/graph_tuner/utils/traverse_graph.py | 4 ---- topi/python/topi/argwhere.py | 1 - vta/scripts/tune_resnet.py | 2 +- vta/tutorials/autotvm/tune_relay_vta.py | 2 +- 5 files changed, 2 insertions(+), 14 deletions(-) diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py index cfaf7b8d1dcc1..5802bd3d745d2 100644 --- a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py +++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py @@ -35,13 +35,6 @@ from ._base import INVALID_LAYOUT_TIME -# Setup topi_op_name -> layout function -# NOTE: To add more ops, change the following dictionary. -# OP2LAYOUT = { -# "topi_nn_conv2d": topi.nn.conv2d_infer_layout, -# "topi_nn_depthwise_conv2d_nchw": topi.nn.depthwise_conv2d_infer_layout, -# } - def get_infer_layout(task_name): if task_name.startswith("conv2d"): return topi.nn.conv2d_infer_layout diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index 6dbf8f1509726..a896bd1c51cf6 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -140,13 +140,9 @@ def _traverse_expr(node): _expr2graph_impl(node, target_ops, node_dict, node_list) return elif isinstance(node, TupleGetItem): - # TODO(@icemelon9): figure out why we need this? - node_entry["op"] = "TupleGetItem" in_node_idx = node_dict[node.tuple_value] node_entry["inputs"].append([in_node_idx, node.index, 0]) elif isinstance(node, Tuple): - # TODO(@icemelon9): figure out why we need this? - node_entry["op"] = "Tuple" for tuple_item in node: in_node_idx = node_dict[tuple_item] if isinstance(tuple_item, TupleGetItem): diff --git a/topi/python/topi/argwhere.py b/topi/python/topi/argwhere.py index 7ed728afb0a79..c2a9adea0c2ad 100644 --- a/topi/python/topi/argwhere.py +++ b/topi/python/topi/argwhere.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks """Argwhere operator""" -import tvm from tvm import hybrid @hybrid.script diff --git a/vta/scripts/tune_resnet.py b/vta/scripts/tune_resnet.py index 501ed404a8850..ca6466183b009 100644 --- a/vta/scripts/tune_resnet.py +++ b/vta/scripts/tune_resnet.py @@ -246,7 +246,7 @@ def tune_tasks(tasks, print("Extracting tasks...") tasks = extract_from_program(func=relay_prog, params=params, - ops=(relay.op.get("nn.conv2d"),) + ops=(relay.op.get("nn.conv2d"),), target=target, target_host=env.target_host) diff --git a/vta/tutorials/autotvm/tune_relay_vta.py b/vta/tutorials/autotvm/tune_relay_vta.py index 81885a6e027b5..1cfe969976ba9 100644 --- a/vta/tutorials/autotvm/tune_relay_vta.py +++ b/vta/tutorials/autotvm/tune_relay_vta.py @@ -350,7 +350,7 @@ def tune_and_evaluate(tuning_opt): relay_prog, params = compile_network(env, target, network, start_pack, stop_pack) tasks = autotvm.task.extract_from_program(func=relay_prog, params=params, - ops=(relay.op.get("nn.conv2d"),) + ops=(relay.op.get("nn.conv2d"),), target=target, target_host=env.target_host)