From 866c0df2b74147972e7cf94eba11ba7882414a17 Mon Sep 17 00:00:00 2001 From: masahi Date: Fri, 28 Feb 2020 11:59:15 +0900 Subject: [PATCH] [Relay, Torch] Clean up and refactor PyTorch frontend (#4944) * The initial import of refactored implementation, all tests passed * enable mobilenet v2 test * minor cleanup * reorg * fix lint * use input names that come with torch IR * fix typo * introduce parse_operators * fix lint * add _ prefix --- python/tvm/relay/frontend/pytorch.py | 597 +++++++++--------- tests/python/frontend/pytorch/test_forward.py | 25 +- tutorials/frontend/from_pytorch.py | 14 +- 3 files changed, 324 insertions(+), 312 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index edd6ad84ae3e..fd66e3c1f367 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -18,6 +18,9 @@ # pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except # pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension """PT: PyTorch frontend.""" +import itertools +from packaging import version + import numpy as np import tvm @@ -396,9 +399,11 @@ def _impl(inputs, input_types): def _size(): def _impl(inputs, input_types): - axis = int(inputs[1]) shape = _infer_shape(inputs[0]) - return shape[axis] + if len(inputs) > 1: + axis = int(inputs[1]) + return shape[axis] + return shape return _impl def _numtotensor(): @@ -484,10 +489,19 @@ def _impl(inputs, attrs, params): def _mean(): def _impl(inputs, input_types): data = inputs[0] - axis = _infer_shape(inputs[1]) - keepdims = int(inputs[2]) - exclude = int(inputs[3]) + if inputs[1]: + axis = _infer_shape(inputs[1]) + else: + axis = None + if len(inputs) > 2 and inputs[2]: + keepdims = int(inputs[2]) + else: + keepdims = False + if len(inputs) > 3 and inputs[3]: + exclude = int(inputs[3]) + else: + exclude = False return _op.mean(data, axis, keepdims, exclude) return _impl @@ -651,7 +665,7 @@ def _convert_elemwise_input(data, input_type): if isinstance(data, torch.Tensor): return _expr.const(data.item(), dtype=_convert_data_type(input_type)) elif not isinstance(data, _expr.Expr): - return _expr.const(int(data), dtype=_convert_data_type(input_type)) + return _expr.const(data, dtype=_convert_data_type(input_type)) else: return data @@ -718,293 +732,270 @@ def _convert_elemwise_input(data, input_type): "aten::sqrt" : _sqrt() } -# Internal graph for parsing -class Graph(object): - """ A helper class for parsing PyTorch model to Relay graph.""" +def _run_jit_passes(graph): + """ The inline pass is necessary to unwrap prim::CallMethod """ + import torch + if version.parse(torch.__version__) >= version.parse("1.4.0"): + torch._C._jit_pass_inline(graph) - def __init__(self, script_module, input_shapes): - self._script_module = script_module - self._graph = script_module.graph.copy() +def _is_int_seq(seq): + return len(seq) > 0 and all([isinstance(i, int) for i in seq]) - # TODO: Temporary fix to remove prim::CallMethod node introduced in PT 1.4 - import torch - from packaging import version - if version.parse(torch.__version__) >= version.parse("1.4.0"): - torch._C._jit_pass_inline(self._graph) - - self._inputs_r = {} - self._params = {} - self._param_tensors = {} - self._consts = {} - self._ops = {} - self._op_inputs_r = {} - self._op_inputs_types = {} - self._input_shapes = input_shapes if input_shapes else {} - self._parsed_node_names = {} - - def from_pytorch(self): - """ Construct relay nodes from PyTorch graph - - Currently only supports traced PyTorch format which means no control flow. - User must perform torch.jit.trace on a model and pass this in. - Future support should include support scripted models (torch.jit.script) which - preserves control flow. - - Returns - ------- - mod : tvm.relay.Module - The module that optimizations will be performed on. - - params : dict of str to tvm.runtime - Dict of converted parameters stored in tvm.runtime format - """ - # Check for missing ops - missing_operators = self._parse_import_prerequisites() - - if missing_operators: - raise tvm.error.OpNotImplemented( \ - "The following operators are not implemented: {}".format(missing_operators)) - - # Translate PyTorch graph to by decorating Graph with state dict and inputs into each op - self._parse_inputs() - self._parse_params() - self._parse_ops() - - outputs = [] - nid = 0 - - for op_name, op_node in self._ops.items(): - if op_node.kind() == "prim::ListConstruct": - if any(inp.debugName() in self._parsed_node_names.keys() \ - for inp in op_node.inputs()): - list_constr = [] - for i in op_node.inputs(): - if i.debugName() in self._parsed_node_names.keys(): - list_constr.append( \ - outputs[self._parsed_node_names[i.debugName()]]) - elif i.node().kind() == "prim::Constant": - list_constr.append(int(self._consts[i.debugName()])) - elif i.debugName() in self._inputs_r.keys(): - list_constr.append(int(self._inputs_r[i.debugName()])) - - # Unwrap for tensors - if len(list_constr) == 1: - list_constr = list_constr[0] - - outputs.append(list_constr) - self._parsed_node_names[op_name] = nid - nid = nid+1 - elif op_node.kind() != "prim::Constant": - for i in op_node.inputs(): - if i.debugName() in self._parsed_node_names.keys(): - for cnt in range(0, len(self._op_inputs_r[op_name])): - if isinstance(self._op_inputs_r[op_name][cnt], str): - if "call/var" in self._op_inputs_r[op_name][cnt]: - self._op_inputs_r[op_name][cnt] = \ - outputs[self._parsed_node_names[i.debugName()]] - break - - call = _convert_map[op_node.kind()](self._op_inputs_r[op_name], - self._op_inputs_types[op_name]) - - outputs.append(call) - self._parsed_node_names[op_name] = nid - nid = nid+1 - - func = tvm.relay.Function(_analysis.free_vars(outputs[-1]), outputs[-1]) - - param = {k: tvm.nd.array(v) for k, v in self._param_tensors.items()} - - return _module.IRModule.from_expr(func), param - - def _parse_inputs(self): - """ Map inputs to parser and inputs to graph. """ - # Get names and objects of inputs for IR - ir_inputs = [i for i in self._graph.inputs()] - - # Create corresponding shape and add to input - for input_name, ir_input in zip(self._input_shapes, ir_inputs[1:]): - input_shape = self._input_shapes[input_name] - ir_input.setDebugName(input_name) - - ir_dtype = _convert_data_type(ir_input.type().scalarType().lower()) - self._inputs_r[input_name] = _expr.var(input_name, - shape=self._input_shapes[input_name], - dtype=ir_dtype) - - # Add self (first input of a PyTorch graph) to inputs, the value doesn't matter here - input_name = ir_inputs[0].debugName() - self._inputs_r[input_name] = "self" - - def _parse_params(self): - """ Map state dictionary values to corresponding prim::GetAttr op node. """ - # Grab weights, biases, etc. from graph - state_dict = self._script_module.state_dict() - param_names = [] - for key, value in state_dict.items(): - param_str = str(key) - param_name = param_str.split(".")[-1] - param_names.append(param_name) - - # Get names of all inputs - input_names = [i for i in self._inputs_r.keys()] - - # Iterate through graph for getAttr nodes and match full state_dict name to nodes - node_weight_map = {} - for node in self._graph.nodes(): - if node.kind() == "prim::GetAttr": - - attribute_names = node.attributeNames() - assert len(attribute_names) == 1 - node_getattr_name = node.s(attribute_names[0]) - node_arg = node.input().debugName() - - if node.outputsSize() == 1: - node_name = node.output().debugName() - else: - node_name = [output.debugName() for output in node.outputs()][0] - - if node_arg in input_names: - node_weight_map[node_name] = node_getattr_name - else: - previous_map = node_weight_map[node_arg[:]] - node_weight_map[node_name] = previous_map+"."+node_getattr_name - - if node_getattr_name in param_names: - - value = state_dict[node_weight_map[node_name]] - tensor = tvm.nd.array(value.cpu().numpy()) - shape = tensor.shape - self._param_tensors[node_name] = tensor - - self._params[node_name] = _expr.var(node_name, - shape=shape, - dtype=_convert_data_type(str(value.dtype))) - - def _parse_ops(self): - """ Iterate through nodes and decorate graph with constants, operators, - and the inputs to each operator. """ - # Traverse nodes and add to graph - for node in self._graph.nodes(): - - if node.outputsSize() == 1: - node_name = node.output().debugName() - else: - node_name = [output.debugName() for output in node.outputs()][0] - - if node.kind() == "prim::Constant": - if node.hasAttributes(): - attribute_names = node.attributeNames() - attr_name = attribute_names[0] - ty = node.output().type().kind() - - if ty in ["IntType", "BoolType"]: - self._consts[node_name] = node.i(attr_name) - elif ty in ["FloatType", "LongType"]: - self._consts[node_name] = node.f(attr_name) - elif ty in ["TensorType", "CompleteTensorType"]: - self._consts[node_name] = node.output().toIValue() - else: - self._consts[node_name] = "0" - else: - self._consts[node_name] = "0" - elif node.kind() == "prim::ListConstruct": - list_shape = [] - for input_node in node.inputs(): - if input_node.debugName() in self._inputs_r.keys(): - c = self._inputs_r[input_node.debugName()] - assert isinstance(c, int) - list_shape.append(c) - elif input_node.debugName() in self._consts.keys(): - c = self._consts[input_node.debugName()] - assert isinstance(c, int) - list_shape.append(c) - self._inputs_r[node_name] = _expr.var(node_name, shape=list_shape) - - if node.kind() != "prim::GetAttr": - self._add_op(node_name, node) - - # Graph Helper Functions - - def _add_op(self, node_id, op_node): - """ Add an operator and its operators inputs to the graph and insert placeholders - where an input is a call node. - - Parameters - ---------- - node_id : string - The ID of the op node - - op_node : PyTorch Node object - The full Node object for the op node - - """ - self._ops[(node_id)] = op_node - input_list_r = [] - input_list_types = [] - for input_value in op_node.inputs(): - - inode_id = input_value.debugName() - inode = input_value.node() - - if inode_id in self._inputs_r.keys(): - input_list_r.append(self._inputs_r[inode_id]) - elif inode_id in self._params.keys(): - input_list_r.append(self._params[inode_id]) - elif inode.kind() == "prim::Constant": - input_list_r.append(self._consts[inode_id]) + +def _get_tensor_and_var(torch_tensor, name): + tensor = tvm.nd.array(torch_tensor.cpu().numpy()) + var = _expr.var(name, shape=tensor.shape) + return tensor, var + + +def _get_output_name(node): + assert node.outputsSize() == 1 + return node.output().debugName() + + +def _get_output_names(node): + return [output.debugName() for output in node.outputs()] + + +def _get_input_names(node_or_graph): + return [inp.debugName() for inp in node_or_graph.inputs()] + + +def _get_op_inputs(op_node, outputs, output_index_map): + input_names = [output_index_map[name] + for name in _get_input_names(op_node)] + return [outputs[name] for name in input_names] + + +def _update_outputs_from_pairs(name_output_pairs, outputs, output_index_map): + for output_name, output in name_output_pairs: + output_index_map[output_name] = len(outputs) + outputs.append(output) + + +def _report_missing_conversion(op_names): + """ Check if all ops in an input graph are supported by TVM """ + known_ops = ["prim::Constant", "prim::GetAttr", + "prim::ListConstruct", "prim::ListUnpack", + "prim::TupleConstruct", "prim::TupleUnpack"] + known_ops += list(_convert_map.keys()) + + missing = [op_name for op_name in op_names + if op_name not in known_ops] + + if missing: + msg = "The following operators are not implemented: {}".format(missing) + raise NotImplementedError(msg) + + +def _getattr_attr_name(node): + attribute_names = node.attributeNames() + assert len(attribute_names) == 1 + attr_name = node.s(attribute_names[0]) + return attr_name + + +def _getattr_full_name(getattrs): + return ".".join([_getattr_attr_name(node) for node in getattrs]) + + +def _get_input_types(op_node): + """ Returns a torch type for each input nodes """ + input_list_types = [] + for input_node in op_node.inputs(): + in_ty = input_node.type() + input_node_kind = in_ty.kind() + if input_node_kind == 'TensorType': + if in_ty.scalarType() is None: + input_list_types.append(None) else: - input_list_r.append("call/var."+inode_id) - - # If the inputs of a ListConstruct op is a call or var, remove it from inputs - if op_node.kind() == "prim::ListConstruct": - if node_id in self._inputs_r.keys(): - self._inputs_r.pop(node_id) - - try: - input_value_kind = input_value.type().kind() - if input_value_kind in ["TensorType", "CompleteTensorType"]: - if input_value.type().scalarType() is None: - input_list_types.append("float") - else: - input_list_types.append(input_value.type().scalarType().lower()) - elif input_value_kind == "ListType": - input_list_types.append(str(input_value.type().getElementType()).lower()) - elif input_value_kind in ["IntType", "FloatType", "BoolType", "StringType", - "OptionalType"]: - input_list_types.append(str(input_value.type()).lower()) - else: - input_list_types.append("UnsupportedType") - print("UnsupportedType "+str(input_value.type())+" and "+str(input_value_kind)) - except Exception as e: - print("Internal PyTorch error. Failed to grab type.") - - if op_node.kind() in ["aten::ones", "aten::zeros"]: - node_type = op_node.output().type().scalarType() - input_list_types[0] = node_type.lower() - - self._op_inputs_r[node_id] = input_list_r - self._op_inputs_types[node_id] = input_list_types - - def _parse_import_prerequisites(self): - """ Calculate the named preconditions from PyTorch graph. - - Returns - ------- - missing_operators : set object - Set of operator names which don't have their mapping in TVM - i.e. which are not supported - - """ - missing_operators = set() - for node in self._graph.nodes(): - if not node.kind() in ["prim::Constant", "prim::ListConstruct", "prim::GetAttr"] \ - and not node.kind() in _convert_map: - missing_operators.add(node.kind()) - - return missing_operators + input_list_types.append(in_ty.scalarType().lower()) + elif input_node_kind == 'ListType': + input_list_types.append(str(in_ty.getElementType()).lower()) + elif input_node_kind in ['IntType', 'FloatType', 'BoolType', + 'StringType', 'OptionalType']: + input_list_types.append(str(in_ty).lower()) + else: + input_list_types.append('UnsupportedType') + + if op_node.kind() in ['aten::ones', 'aten::zeros']: + node_type = op_node.output().type() + scalar_type = node_type.scalarType() + if scalar_type: + input_list_types[0] = scalar_type.lower() + + return input_list_types + + +def _get_constant(node): + """ Retrieve a constant associated with this prim::Constant node """ + attribute_names = node.attributeNames() + num_attributes = len(attribute_names) + + if num_attributes == 1: + attr_name = attribute_names[0] + ty = node.output().type().kind() + + if ty in ["IntType", "BoolType"]: + return node.i(attr_name) + elif ty in ["FloatType", "LongType"]: + return node.f(attr_name) + elif ty in ["TensorType", "CompleteTensorType"]: + tensor = node.t(attr_name) + if len(tensor.shape) == 0: # tensor(0.1) + return float(tensor) + return tensor + elif ty == "DeviceObjType": + return node.s(attr_name) + elif ty == "FunctionType": + return None + else: + raise NotImplementedError("Unsupported type: %s" % ty) + else: + assert num_attributes == 0 + return None + + +def _get_operator_nodes(nodes): + """ Returns torch IR nodes that need conversion to Relay """ + ops = {} + # Traverse nodes and add to graph + for node in nodes: + if node.outputsSize() > 1: + node_name = "_".join(_get_output_names(node)) + else: + node_name = _get_output_name(node) + + if node.kind() != "prim::GetAttr": + ops[node_name] = node + + return ops + + +def parse_inputs(graph_inputs, input_shapes): + """ Return Relay vars from torch input vars """ + ir_inputs = list(graph_inputs) + input_vars = {} + + for input_name, ir_input in zip(input_shapes, ir_inputs[1:]): + input_vars[input_name] = _expr.var(input_name, + shape=input_shapes[input_name]) + return input_vars + + +def get_use_chains(root_node, terminate=lambda _: False): + """ + Track a chain of users of this node forward, returning a list of chains + See get_attr_chains below for its usage + """ + def concat_lists(lists): + return itertools.chain.from_iterable(lists) + + def inner(current, accum): + users = [] + for output in current.outputs(): + users += [use.user for use in output.uses()] + + if not users or terminate(users): + return [accum] + + return concat_lists([inner(nxt, accum + [nxt]) for nxt in users]) + + return inner(root_node, [root_node]) + + +def get_attr_chains(root_getattr_node): + """ Returns chains of attribute access starting from root_getattr_node + + For example, given attribute "block", as in "self.block" when "self" points + to the top level torch.nn.Module, it returns lists of attribute "chains", + e.g. ['block', '2'], ['block', '1'], ['block', '0', '_packed_params'] + + These sets of attributes form full attribute accessors. For example, + "self.block.1", "self.block.2" will return the second and third submodule, + and "self.block.0._packed_params" will return the parameters of the first + submodule. + """ + def terminate(users): + next_attrs = [user for user in users if user.kind() == "prim::GetAttr"] + return len(next_attrs) == 0 + + return get_use_chains(root_getattr_node, terminate) + + +def parse_params(graph, state_dict): + """ + Return Relay vars and TVM NDArrays for input parameters + A chain of prim::GetAttr nodes is processed one at a time + """ + getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True) + params = {} + param_tensors = {} + seen = set() + + for node in getattr_nodes: + if _get_output_name(node) in seen: + continue + + for getattrs in get_attr_chains(node): + seen.update(map(_get_output_name, getattrs)) + + full_attr = _getattr_full_name(getattrs) + full_attr_node_name = _get_output_name(getattrs[-1]) + + if full_attr in state_dict: + torch_tensor = state_dict[full_attr] + tensor, var = _get_tensor_and_var(torch_tensor, + full_attr_node_name) + param_tensors[full_attr_node_name] = tensor + params[full_attr_node_name] = var + + return params, param_tensors + + +def parse_operators(operators, outputs, output_index_map, ret_name): + """ Convert each Torch IR operators to Relay equivalent """ + for node_name, op_node in operators.items(): + operator = op_node.kind() + inputs = _get_op_inputs(op_node, outputs, output_index_map) + + if operator == "prim::Constant": + output_index_map[node_name] = len(outputs) + outputs.append(_get_constant(op_node)) + elif operator == 'prim::ListConstruct' and _is_int_seq(inputs): + output_index_map[node_name] = len(outputs) + outputs.append(_expr.var(node_name, shape=inputs)) + elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']: + output_index_map[node_name] = len(outputs) + outputs.append(inputs) + elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']: + assert len(inputs) == 1 + unpacked_names = _get_output_names(op_node) + _update_outputs_from_pairs(zip(unpacked_names, inputs[0]), + outputs, output_index_map) + else: + output_index_map[node_name] = len(outputs) + relay_op = _convert_map[operator] + outputs.append(relay_op(inputs, _get_input_types(op_node))) + + return outputs[output_index_map[ret_name]] + + +def get_all_op_names(graph): + """ Return all operator names in the input graph """ + nodes = list(graph.nodes()) + return set(node.kind() for node in nodes) + + +def get_graph_input_names(script_module): + """ Use this function to set the keys for input_shapes""" + # It seems variable names could change the first time a copy is made + # Use the copy of the graph here to prevent troubles later + ir_inputs = _get_input_names(script_module.graph.copy()) + return ir_inputs[1:] # remove self at the 0th arg + def from_pytorch(script_module, input_shapes): """ Load PyTorch model in the form of a scripted PyTorch model and convert into relay. @@ -1016,17 +1007,35 @@ def from_pytorch(script_module, input_shapes): TorchScripted PyTorch graph Note: We currently only support traces (ie: torch.jit.trace(model, input)) - shape : Dictionary of input dimensions + input_shapes : Dictionary of input dimensions Graph level input shape dictionary + The keys should be the same one returned by get_graph_input_names(...) above Returns ------- mod : tvm.relay.Module The module that optimizations will be performed on. - params : dict of str to tvm.runtime - Dict of converted parameters stored in tvm.runtime format + params : dict of str to tvm.runtime.NDArray + Dict of converted parameters stored in tvm.runtime.ndarray format """ - g = Graph(script_module, input_shapes) - mod, params = g.from_pytorch() - return mod, params + graph = script_module.graph.copy() + _run_jit_passes(graph) + op_names = get_all_op_names(graph) + _report_missing_conversion(op_names) + + params = script_module.state_dict() + input_vars = parse_inputs(graph.inputs(), input_shapes) + param_vars, tensors = parse_params(graph, params) + + input_vars.update(param_vars) + outputs = list(input_vars.values()) + output_index_map = dict(zip(input_vars.keys(), range(len(outputs)))) + ret_name = _get_input_names(graph.return_node())[0] + + body = parse_operators(_get_operator_nodes(graph.nodes()), outputs, + output_index_map, ret_name) + func = tvm.relay.Function(_analysis.free_vars(body), body) + tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()} + + return _module.IRModule.from_expr(func), tvm_params diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index ba1d7bbe67bc..831389b7ebf5 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -31,6 +31,8 @@ from tvm import relay from tvm.contrib import graph_runtime from tvm.relay.testing.config import ctx_list +from tvm.relay.frontend.pytorch import get_graph_input_names + sys.setrecursionlimit(10000) @@ -94,6 +96,7 @@ def load_model(model_name): if hasattr(torchvision.models, model_name): return load_torchvision(model_name) try: + import pretrainedmodels if hasattr(pretrainedmodels, model_name): return load_pretrainedmodels(model_name) except ModuleNotFoundError: @@ -167,16 +170,15 @@ def verify_model(model_name, input_data=[]): baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs) else: baseline_outputs = (baseline_outputs.float().cpu().numpy(),) - output_shapes = [out.shape for out in baseline_outputs] - dtype = "float32" - input_name = "input0" - input_shapes = {input_name: list(baseline_input.shape)} trace = torch.jit.trace(baseline_model, baseline_input).float().eval() + if torch.cuda.is_available(): trace = trace.cuda() else: trace = trace.cpu() + input_name = get_graph_input_names(trace)[0] # only one input + input_shapes = {input_name: list(baseline_input.shape)} mod, params = relay.frontend.from_pytorch(trace, input_shapes) compiled_input = {input_name: tvm.nd.array(baseline_input.cpu().numpy())} @@ -276,7 +278,7 @@ def forward(self, *args): class Multiply2(Module): def forward(self, *args): - return args[0] * 1 + return args[0] * 1.0 class Multiply3(Module): def forward(self, *args): @@ -507,7 +509,7 @@ def test_forward_size(): class Size1(Module): def forward(self, *args): - return args[0].size(0) * args[0] + return float(args[0].size(0)) * args[0] with torch.no_grad(): input_data = torch.rand(input_shape).float() @@ -708,6 +710,10 @@ def test_mnasnet0_5(): torch.set_grad_enabled(False) verify_model("mnasnet0_5") +def test_mobilenet_v2(): + torch.set_grad_enabled(False) + verify_model("mobilenet_v2") + """ #TODO: Fix VGG and AlexNet issues (probably due to pooling) def test_alexnet(): @@ -721,13 +727,9 @@ def test_vgg11(): def test_vgg11_bn(): torch.set_grad_enabled(False) verify_model("vgg11_bn") - -#TODO: Need to update schedule in tophub file after PR #4787 updated workloads -def test_mobilenet_v2(): - torch.set_grad_enabled(False) - verify_model("mobilenet_v2") """ + if __name__ == "__main__": # Single operator tests test_forward_add() @@ -767,3 +769,4 @@ def test_mobilenet_v2(): test_inception_v3() test_googlenet() test_mnasnet0_5() + test_mobilenet_v2() diff --git a/tutorials/frontend/from_pytorch.py b/tutorials/frontend/from_pytorch.py index c280c259c1fe..503f64a4e7d9 100644 --- a/tutorials/frontend/from_pytorch.py +++ b/tutorials/frontend/from_pytorch.py @@ -41,14 +41,13 @@ be unstable. """ -# tvm, relay import tvm from tvm import relay -# numpy, packaging import numpy as np -from packaging import version + from tvm.contrib.download import download_testdata +from tvm.relay.frontend.pytorch import get_graph_input_names # PyTorch imports import torch @@ -91,7 +90,8 @@ # Import the graph to Relay # ------------------------- # Convert PyTorch graph to Relay graph. -shape_dict = {'img': img.shape} +input_name = get_graph_input_names(scripted_model)[0] # only one input +shape_dict = {input_name: img.shape} mod, params = relay.frontend.from_pytorch(scripted_model, shape_dict) @@ -116,12 +116,12 @@ dtype = 'float32' m = graph_runtime.create(graph, lib, ctx) # Set inputs -m.set_input('img', tvm.nd.array(img.astype(dtype))) +m.set_input(input_name, tvm.nd.array(img.astype(dtype))) m.set_input(**params) # Execute m.run() # Get outputs -tvm_output = m.get_output(0, tvm.nd.empty(((1, 1000)), 'float32')) +tvm_output = m.get_output(0) ##################################################################### # Look up synset name @@ -163,4 +163,4 @@ torch_class_key = class_id_to_key[top1_torch] print('Relay top-1 id: {}, class name: {}'.format(top1_tvm, key_to_classname[tvm_class_key])) -print('Torch top-1 id: {}, class name: {}'.format(top1_torch, key_to_classname[torch_class_key])) \ No newline at end of file +print('Torch top-1 id: {}, class name: {}'.format(top1_torch, key_to_classname[torch_class_key]))