Skip to content

Commit

Permalink
return mod from frontend for autotvm (apache#3401)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored and wweic committed Jun 27, 2019
1 parent eca3f0b commit d6e335a
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 4 deletions.
3 changes: 2 additions & 1 deletion tutorials/autotvm/tune_relay_arm.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def get_network(name, batch_size):
# an example for mxnet model
from mxnet.gluon.model_zoo.vision import get_model
block = get_model('resnet18_v1', pretrained=True)
net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = mod[mod.entry_func]
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
else:
raise ValueError("Unsupported network: " + name)
Expand Down
3 changes: 2 additions & 1 deletion tutorials/autotvm/tune_relay_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def get_network(name, batch_size):
# an example for mxnet model
from mxnet.gluon.model_zoo.vision import get_model
block = get_model('resnet18_v1', pretrained=True)
net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = mod[mod.entry_func]
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
else:
raise ValueError("Unsupported network: " + name)
Expand Down
3 changes: 2 additions & 1 deletion tutorials/autotvm/tune_relay_mobile_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def get_network(name, batch_size):
# an example for mxnet model
from mxnet.gluon.model_zoo.vision import get_model
block = get_model('resnet18_v1', pretrained=True)
net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = mod[mod.entry_func]
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
else:
raise ValueError("Unsupported network: " + name)
Expand Down
3 changes: 2 additions & 1 deletion tutorials/autotvm/tune_relay_x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def get_network(name, batch_size):
# an example for mxnet model
from mxnet.gluon.model_zoo.vision import get_model
block = get_model('resnet18_v1', pretrained=True)
net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = mod[mod.entry_func]
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
else:
raise ValueError("Unsupported network: " + name)
Expand Down

0 comments on commit d6e335a

Please sign in to comment.