Skip to content

Commit

Permalink
[relay][frontend] Return Module from get_workload (apache#3483)
Browse files Browse the repository at this point in the history
* [relay][frontend] Return Module from get_workload

* pass entry_func to autotvm

* disable tune

* add property to module

* mod.entry_func to main

* .main -> mod["main"]

* fix
  • Loading branch information
zhiics authored and jroesch committed Jul 6, 2019
1 parent 33b2f06 commit 1cad4c6
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/vta/top/graphpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, transform.Pass)
mod = relay.Module.from_expr(expr)
mod = opt_pass(mod)
entry = mod[mod.entry_func]
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body

def _to_shape(shape):
Expand Down
2 changes: 1 addition & 1 deletion scripts/tune_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def compile_network(opt, env, target):
# Perform quantization in Relay
with relay.quantize.qconfig(global_scale=8.0,
skip_conv_layers=[0]):
relay_prog = relay.quantize.quantize(mod[mod.entry_func], params=params)
relay_prog = relay.quantize.quantize(mod["main"], params=params)

# Perform graph packing and constant folding for VTA target
if target.device_name == "vta":
Expand Down
2 changes: 1 addition & 1 deletion tutorials/autotvm/tune_relay_vta.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def compile_network(env, target, model, start_pack, stop_pack):
# Perform quantization in Relay
with relay.quantize.qconfig(global_scale=8.0,
skip_conv_layers=[0]):
relay_prog = relay.quantize.quantize(mod[mod.entry_func], params=params)
relay_prog = relay.quantize.quantize(mod["main"], params=params)

# Perform graph packing and constant folding for VTA target
if target.device_name == "vta":
Expand Down
2 changes: 1 addition & 1 deletion tutorials/frontend/deploy_resnet_on_vta.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@
# Perform quantization in Relay
with relay.quantize.qconfig(global_scale=8.0,
skip_conv_layers=[0]):
relay_prog = relay.quantize.quantize(mod[mod.entry_func], params=params)
relay_prog = relay.quantize.quantize(mod["main"], params=params)

# Perform graph packing and constant folding for VTA target
if target.device_name == "vta":
Expand Down

0 comments on commit 1cad4c6

Please sign in to comment.