From 6ff6741a025b62bf3d7929917c9e12af3675a472 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Thu, 20 Jun 2019 01:55:49 +0000 Subject: [PATCH] return mod from frontend for autotvm --- tutorials/autotvm/tune_relay_arm.py | 3 ++- tutorials/autotvm/tune_relay_cuda.py | 3 ++- tutorials/autotvm/tune_relay_mobile_gpu.py | 3 ++- tutorials/autotvm/tune_relay_x86.py | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tutorials/autotvm/tune_relay_arm.py b/tutorials/autotvm/tune_relay_arm.py index 2c1dca9921eb..290f9756f195 100644 --- a/tutorials/autotvm/tune_relay_arm.py +++ b/tutorials/autotvm/tune_relay_arm.py @@ -96,7 +96,8 @@ def get_network(name, batch_size): # an example for mxnet model from mxnet.gluon.model_zoo.vision import get_model block = get_model('resnet18_v1', pretrained=True) - net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) + mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) + net = mod[mod.entry_func] net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) else: raise ValueError("Unsupported network: " + name) diff --git a/tutorials/autotvm/tune_relay_cuda.py b/tutorials/autotvm/tune_relay_cuda.py index 571334e8c106..c158e4b9fe36 100644 --- a/tutorials/autotvm/tune_relay_cuda.py +++ b/tutorials/autotvm/tune_relay_cuda.py @@ -96,7 +96,8 @@ def get_network(name, batch_size): # an example for mxnet model from mxnet.gluon.model_zoo.vision import get_model block = get_model('resnet18_v1', pretrained=True) - net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) + mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) + net = mod[mod.entry_func] net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) else: raise ValueError("Unsupported network: " + name) diff --git a/tutorials/autotvm/tune_relay_mobile_gpu.py b/tutorials/autotvm/tune_relay_mobile_gpu.py index 1e4cf6d52ade..c011268fda51 100644 --- a/tutorials/autotvm/tune_relay_mobile_gpu.py +++ b/tutorials/autotvm/tune_relay_mobile_gpu.py @@ -97,7 +97,8 @@ def get_network(name, batch_size): # an example for mxnet model from mxnet.gluon.model_zoo.vision import get_model block = get_model('resnet18_v1', pretrained=True) - net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) + mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) + net = mod[mod.entry_func] net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) else: raise ValueError("Unsupported network: " + name) diff --git a/tutorials/autotvm/tune_relay_x86.py b/tutorials/autotvm/tune_relay_x86.py index ad35c198bc77..c8d9def206fe 100644 --- a/tutorials/autotvm/tune_relay_x86.py +++ b/tutorials/autotvm/tune_relay_x86.py @@ -64,7 +64,8 @@ def get_network(name, batch_size): # an example for mxnet model from mxnet.gluon.model_zoo.vision import get_model block = get_model('resnet18_v1', pretrained=True) - net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) + mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) + net = mod[mod.entry_func] net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) else: raise ValueError("Unsupported network: " + name)