Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
icemelon committed Feb 3, 2020
1 parent ac24536 commit 8dd4a6b
Show file tree
Hide file tree
Showing 5 changed files with 2 additions and 14 deletions.
7 changes: 0 additions & 7 deletions python/tvm/autotvm/graph_tuner/base_graph_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,6 @@
from ._base import INVALID_LAYOUT_TIME


# Setup topi_op_name -> layout function
# NOTE: To add more ops, change the following dictionary.
# OP2LAYOUT = {
# "topi_nn_conv2d": topi.nn.conv2d_infer_layout,
# "topi_nn_depthwise_conv2d_nchw": topi.nn.depthwise_conv2d_infer_layout,
# }

def get_infer_layout(task_name):
if task_name.startswith("conv2d"):
return topi.nn.conv2d_infer_layout
Expand Down
4 changes: 0 additions & 4 deletions python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,9 @@ def _traverse_expr(node):
_expr2graph_impl(node, target_ops, node_dict, node_list)
return
elif isinstance(node, TupleGetItem):
# TODO(@icemelon9): figure out why we need this?
node_entry["op"] = "TupleGetItem"
in_node_idx = node_dict[node.tuple_value]
node_entry["inputs"].append([in_node_idx, node.index, 0])
elif isinstance(node, Tuple):
# TODO(@icemelon9): figure out why we need this?
node_entry["op"] = "Tuple"
for tuple_item in node:
in_node_idx = node_dict[tuple_item]
if isinstance(tuple_item, TupleGetItem):
Expand Down
1 change: 0 additions & 1 deletion topi/python/topi/argwhere.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
"""Argwhere operator"""
import tvm
from tvm import hybrid

@hybrid.script
Expand Down
2 changes: 1 addition & 1 deletion vta/scripts/tune_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def tune_tasks(tasks,
print("Extracting tasks...")
tasks = extract_from_program(func=relay_prog,
params=params,
ops=(relay.op.get("nn.conv2d"),)
ops=(relay.op.get("nn.conv2d"),),
target=target,
target_host=env.target_host)

Expand Down
2 changes: 1 addition & 1 deletion vta/tutorials/autotvm/tune_relay_vta.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def tune_and_evaluate(tuning_opt):
relay_prog, params = compile_network(env, target, network, start_pack, stop_pack)
tasks = autotvm.task.extract_from_program(func=relay_prog,
params=params,
ops=(relay.op.get("nn.conv2d"),)
ops=(relay.op.get("nn.conv2d"),),
target=target,
target_host=env.target_host)

Expand Down

0 comments on commit 8dd4a6b

Please sign in to comment.