From 9bc2af5a13ca158b06a9733e8b4a40b38cc037b9 Mon Sep 17 00:00:00 2001 From: masahi Date: Wed, 26 Feb 2020 13:25:57 +0900 Subject: [PATCH 01/10] The initial import of refactored implementation, all tests passed --- python/tvm/relay/frontend/pytorch.py | 572 +++++++++--------- tests/python/frontend/pytorch/test_forward.py | 5 +- 2 files changed, 284 insertions(+), 293 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index edd6ad84ae3e..ee58de5421cf 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -18,6 +18,10 @@ # pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except # pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension """PT: PyTorch frontend.""" +import itertools +from packaging import version + +import torch import numpy as np import tvm @@ -484,10 +488,19 @@ def _impl(inputs, attrs, params): def _mean(): def _impl(inputs, input_types): data = inputs[0] - axis = _infer_shape(inputs[1]) - keepdims = int(inputs[2]) - exclude = int(inputs[3]) + if inputs[1]: + axis = _infer_shape(inputs[1]) + else: + axis = None + if len(inputs) > 2 and inputs[2]: + keepdims = int(inputs[2]) + else: + keepdims = False + if len(inputs) > 3 and inputs[3]: + exclude = int(inputs[3]) + else: + exclude = False return _op.mean(data, axis, keepdims, exclude) return _impl @@ -651,7 +664,7 @@ def _convert_elemwise_input(data, input_type): if isinstance(data, torch.Tensor): return _expr.const(data.item(), dtype=_convert_data_type(input_type)) elif not isinstance(data, _expr.Expr): - return _expr.const(int(data), dtype=_convert_data_type(input_type)) + return _expr.const(data, dtype=_convert_data_type(input_type)) else: return data @@ -718,293 +731,231 @@ def _convert_elemwise_input(data, input_type): "aten::sqrt" : _sqrt() } -# Internal graph for parsing -class Graph(object): - """ A helper class for parsing PyTorch model to Relay graph.""" +def is_int_seq(seq): + return len(seq) > 0 and all([isinstance(i, int) for i in seq]) - def __init__(self, script_module, input_shapes): - self._script_module = script_module - self._graph = script_module.graph.copy() +def parse_inputs(graph_inputs, input_shapes): + ir_inputs = list(graph_inputs) + ir_names = [i.debugName() for i in ir_inputs] + input_vars = {} - # TODO: Temporary fix to remove prim::CallMethod node introduced in PT 1.4 - import torch - from packaging import version - if version.parse(torch.__version__) >= version.parse("1.4.0"): - torch._C._jit_pass_inline(self._graph) - - self._inputs_r = {} - self._params = {} - self._param_tensors = {} - self._consts = {} - self._ops = {} - self._op_inputs_r = {} - self._op_inputs_types = {} - self._input_shapes = input_shapes if input_shapes else {} - self._parsed_node_names = {} - - def from_pytorch(self): - """ Construct relay nodes from PyTorch graph - - Currently only supports traced PyTorch format which means no control flow. - User must perform torch.jit.trace on a model and pass this in. - Future support should include support scripted models (torch.jit.script) which - preserves control flow. - - Returns - ------- - mod : tvm.relay.Module - The module that optimizations will be performed on. 
- - params : dict of str to tvm.runtime - Dict of converted parameters stored in tvm.runtime format - """ - # Check for missing ops - missing_operators = self._parse_import_prerequisites() - - if missing_operators: - raise tvm.error.OpNotImplemented( \ - "The following operators are not implemented: {}".format(missing_operators)) - - # Translate PyTorch graph to by decorating Graph with state dict and inputs into each op - self._parse_inputs() - self._parse_params() - self._parse_ops() - - outputs = [] - nid = 0 - - for op_name, op_node in self._ops.items(): - if op_node.kind() == "prim::ListConstruct": - if any(inp.debugName() in self._parsed_node_names.keys() \ - for inp in op_node.inputs()): - list_constr = [] - for i in op_node.inputs(): - if i.debugName() in self._parsed_node_names.keys(): - list_constr.append( \ - outputs[self._parsed_node_names[i.debugName()]]) - elif i.node().kind() == "prim::Constant": - list_constr.append(int(self._consts[i.debugName()])) - elif i.debugName() in self._inputs_r.keys(): - list_constr.append(int(self._inputs_r[i.debugName()])) - - # Unwrap for tensors - if len(list_constr) == 1: - list_constr = list_constr[0] - - outputs.append(list_constr) - self._parsed_node_names[op_name] = nid - nid = nid+1 - elif op_node.kind() != "prim::Constant": - for i in op_node.inputs(): - if i.debugName() in self._parsed_node_names.keys(): - for cnt in range(0, len(self._op_inputs_r[op_name])): - if isinstance(self._op_inputs_r[op_name][cnt], str): - if "call/var" in self._op_inputs_r[op_name][cnt]: - self._op_inputs_r[op_name][cnt] = \ - outputs[self._parsed_node_names[i.debugName()]] - break - - call = _convert_map[op_node.kind()](self._op_inputs_r[op_name], - self._op_inputs_types[op_name]) - - outputs.append(call) - self._parsed_node_names[op_name] = nid - nid = nid+1 - - func = tvm.relay.Function(_analysis.free_vars(outputs[-1]), outputs[-1]) - - param = {k: tvm.nd.array(v) for k, v in self._param_tensors.items()} - - return _module.IRModule.from_expr(func), param - - def _parse_inputs(self): - """ Map inputs to parser and inputs to graph. """ - # Get names and objects of inputs for IR - ir_inputs = [i for i in self._graph.inputs()] - - # Create corresponding shape and add to input - for input_name, ir_input in zip(self._input_shapes, ir_inputs[1:]): - input_shape = self._input_shapes[input_name] - ir_input.setDebugName(input_name) - - ir_dtype = _convert_data_type(ir_input.type().scalarType().lower()) - self._inputs_r[input_name] = _expr.var(input_name, - shape=self._input_shapes[input_name], - dtype=ir_dtype) - - # Add self (first input of a PyTorch graph) to inputs, the value doesn't matter here - input_name = ir_inputs[0].debugName() - self._inputs_r[input_name] = "self" - - def _parse_params(self): - """ Map state dictionary values to corresponding prim::GetAttr op node. """ - # Grab weights, biases, etc. 
from graph - state_dict = self._script_module.state_dict() - param_names = [] - for key, value in state_dict.items(): - param_str = str(key) - param_name = param_str.split(".")[-1] - param_names.append(param_name) - - # Get names of all inputs - input_names = [i for i in self._inputs_r.keys()] - - # Iterate through graph for getAttr nodes and match full state_dict name to nodes - node_weight_map = {} - for node in self._graph.nodes(): - if node.kind() == "prim::GetAttr": - - attribute_names = node.attributeNames() - assert len(attribute_names) == 1 - node_getattr_name = node.s(attribute_names[0]) - node_arg = node.input().debugName() - - if node.outputsSize() == 1: - node_name = node.output().debugName() - else: - node_name = [output.debugName() for output in node.outputs()][0] - - if node_arg in input_names: - node_weight_map[node_name] = node_getattr_name - else: - previous_map = node_weight_map[node_arg[:]] - node_weight_map[node_name] = previous_map+"."+node_getattr_name - - if node_getattr_name in param_names: - - value = state_dict[node_weight_map[node_name]] - tensor = tvm.nd.array(value.cpu().numpy()) - shape = tensor.shape - self._param_tensors[node_name] = tensor - - self._params[node_name] = _expr.var(node_name, - shape=shape, - dtype=_convert_data_type(str(value.dtype))) - - def _parse_ops(self): - """ Iterate through nodes and decorate graph with constants, operators, - and the inputs to each operator. """ - # Traverse nodes and add to graph - for node in self._graph.nodes(): - - if node.outputsSize() == 1: - node_name = node.output().debugName() - else: - node_name = [output.debugName() for output in node.outputs()][0] - - if node.kind() == "prim::Constant": - if node.hasAttributes(): - attribute_names = node.attributeNames() - attr_name = attribute_names[0] - ty = node.output().type().kind() - - if ty in ["IntType", "BoolType"]: - self._consts[node_name] = node.i(attr_name) - elif ty in ["FloatType", "LongType"]: - self._consts[node_name] = node.f(attr_name) - elif ty in ["TensorType", "CompleteTensorType"]: - self._consts[node_name] = node.output().toIValue() - else: - self._consts[node_name] = "0" - else: - self._consts[node_name] = "0" - elif node.kind() == "prim::ListConstruct": - list_shape = [] - for input_node in node.inputs(): - if input_node.debugName() in self._inputs_r.keys(): - c = self._inputs_r[input_node.debugName()] - assert isinstance(c, int) - list_shape.append(c) - elif input_node.debugName() in self._consts.keys(): - c = self._consts[input_node.debugName()] - assert isinstance(c, int) - list_shape.append(c) - self._inputs_r[node_name] = _expr.var(node_name, shape=list_shape) - - if node.kind() != "prim::GetAttr": - self._add_op(node_name, node) - - # Graph Helper Functions - - def _add_op(self, node_id, op_node): - """ Add an operator and its operators inputs to the graph and insert placeholders - where an input is a call node. 
- - Parameters - ---------- - node_id : string - The ID of the op node - - op_node : PyTorch Node object - The full Node object for the op node - - """ - self._ops[(node_id)] = op_node - input_list_r = [] - input_list_types = [] - for input_value in op_node.inputs(): - - inode_id = input_value.debugName() - inode = input_value.node() - - if inode_id in self._inputs_r.keys(): - input_list_r.append(self._inputs_r[inode_id]) - elif inode_id in self._params.keys(): - input_list_r.append(self._params[inode_id]) - elif inode.kind() == "prim::Constant": - input_list_r.append(self._consts[inode_id]) + for input_name, ir_input in zip(input_shapes, ir_inputs[1:]): + input_shape = input_shapes[input_name] + ir_input.setDebugName(input_name) + input_vars[input_name] = _expr.var(input_name, + shape=input_shapes[input_name]) + # Add self (first input of a PyTorch graph) to inputs + input_shape = [3] + tensor = tvm.nd.array(np.zeros(input_shape).astype(np.float32)) + input_name = ir_names[0] # self.1 + input_vars[input_name] = tensor + + return input_vars + + +def get_tensor_and_var(torch_tensor, name): + tensor = tvm.nd.array(torch_tensor.cpu().numpy()) + var = _expr.var(name, shape=tensor.shape) + return tensor, var + + +def get_output_name(node): + assert node.outputsSize() == 1 + return node.output().debugName() + + +def get_output_names(node): + return [output.debugName() for output in node.outputs()] + + +def get_input_names(node): + return [inp.debugName() for inp in node.inputs()] + + +def getattr_attr_name(node): + attribute_names = node.attributeNames() + assert(len(attribute_names) == 1) + attr_name = node.s(attribute_names[0]) + return attr_name + + +def get_use_chains(root_node, terminate=lambda _: False): + def concat_lists(lists): + return itertools.chain.from_iterable(lists) + + def inner(current, accum): + users = [] + for output in current.outputs(): + users += [use.user for use in output.uses()] + + if not users or terminate(users): + return [accum] + + return concat_lists([inner(nxt, accum + [nxt]) for nxt in users]) + + return inner(root_node, [root_node]) + + +def get_attr_chains(root_getattr_node): + """Returns chains of attribute access starting from root_getattr_node + + For example, given attribute "block", as in "self.block" when "self" points + to the top level torch.nn.Module, it returns lists of attribute "chains", + e.g. ['block', '2'], ['block', '1'], ['block', '0', '_packed_params'] + + These sets of attributes form full attribute accessors. For example, + "self.block.1", "self.block.2" will return the second and third submodule, + and "self.block.0._packed_params" will return the parameters of the first + submodule. 
+ """ + def terminate(users): + next_attrs = [user for user in users if user.kind() == "prim::GetAttr"] + return len(next_attrs) == 0 + + return get_use_chains(root_getattr_node, terminate) + + +def get_full_attr_name(getattrs): + return ".".join([getattr_attr_name(node) for node in getattrs]) + + +def parse_params(graph, state_dict): + getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True) + params = {} + param_tensors = {} + seen = set() + + for node in getattr_nodes: + if get_output_name(node) in seen: + continue + + for getattrs in get_attr_chains(node): + seen.update(map(get_output_name, getattrs)) + + full_attr = get_full_attr_name(getattrs) + full_attr_node_name = get_output_name(getattrs[-1]) + + if full_attr in state_dict: + torch_tensor = state_dict[full_attr] + tensor, var = get_tensor_and_var(torch_tensor, + full_attr_node_name) + param_tensors[full_attr_node_name] = tensor + params[full_attr_node_name] = var + + return params, param_tensors + + +def get_input_types(op_node): + input_list_types = [] + for input_node in op_node.inputs(): + in_ty = input_node.type() + input_node_kind = in_ty.kind() + if input_node_kind == 'TensorType': + if in_ty.scalarType() is None: + input_list_types.append('float') else: - input_list_r.append("call/var."+inode_id) - - # If the inputs of a ListConstruct op is a call or var, remove it from inputs - if op_node.kind() == "prim::ListConstruct": - if node_id in self._inputs_r.keys(): - self._inputs_r.pop(node_id) - - try: - input_value_kind = input_value.type().kind() - if input_value_kind in ["TensorType", "CompleteTensorType"]: - if input_value.type().scalarType() is None: - input_list_types.append("float") - else: - input_list_types.append(input_value.type().scalarType().lower()) - elif input_value_kind == "ListType": - input_list_types.append(str(input_value.type().getElementType()).lower()) - elif input_value_kind in ["IntType", "FloatType", "BoolType", "StringType", - "OptionalType"]: - input_list_types.append(str(input_value.type()).lower()) - else: - input_list_types.append("UnsupportedType") - print("UnsupportedType "+str(input_value.type())+" and "+str(input_value_kind)) - except Exception as e: - print("Internal PyTorch error. Failed to grab type.") - - if op_node.kind() in ["aten::ones", "aten::zeros"]: - node_type = op_node.output().type().scalarType() - input_list_types[0] = node_type.lower() - - self._op_inputs_r[node_id] = input_list_r - self._op_inputs_types[node_id] = input_list_types - - def _parse_import_prerequisites(self): - """ Calculate the named preconditions from PyTorch graph. - - Returns - ------- - missing_operators : set object - Set of operator names which don't have their mapping in TVM - i.e. 
which are not supported - - """ - missing_operators = set() - for node in self._graph.nodes(): - if not node.kind() in ["prim::Constant", "prim::ListConstruct", "prim::GetAttr"] \ - and not node.kind() in _convert_map: - missing_operators.add(node.kind()) - - return missing_operators + input_list_types.append(in_ty.scalarType().lower()) + elif input_node_kind == 'ListType': + input_list_types.append(str(in_ty.getElementType()).lower()) + elif input_node_kind in ['IntType', 'FloatType', 'BoolType', + 'StringType', 'OptionalType']: + input_list_types.append(str(in_ty).lower()) + else: + input_list_types.append('UnsupportedType') + + if op_node.kind() in ['aten::ones', 'aten::zeros']: + node_type = op_node.output().type() + scalar_type = node_type.scalarType() + if scalar_type: + input_list_types[0] = scalar_type.lower() + + return input_list_types + + +def get_constant(node): + attribute_names = node.attributeNames() + num_attributes = len(attribute_names) + + if num_attributes == 1: + attr_name = attribute_names[0] + ty = node.output().type().kind() + + if ty == "IntType" or ty == "BoolType": + return node.i(attr_name) + elif ty in ["FloatType", "LongType"]: + return node.f(attr_name) + elif ty in ["TensorType", "CompleteTensorType"]: + tensor = node.t(attr_name) + if len(tensor.shape) == 0: # tensor(0.1) + return float(tensor) + return tensor + elif ty == "DeviceObjType": + return node.s(attr_name) + elif ty == "FunctionType": + return None + else: + print(ty, node) + assert False # TODO: handle other types + else: + assert num_attributes == 0 + return None + + +def parse_ops(nodes): + ops = {} + # Traverse nodes and add to graph + for node in nodes: + if node.outputsSize() > 1: + node_name = "_".join(get_output_names(node)) + else: + node_name = get_output_name(node) + + if node.kind() != "prim::GetAttr": + ops[node_name] = node + + return ops + + +def get_input_node_names(op_node, output_index_map): + return [output_index_map[name] for name in get_input_names(op_node)] + + +def get_op_inputs(op_node, outputs, output_index_map): + input_names = get_input_node_names(op_node, output_index_map) + return [outputs[name] for name in input_names] + + +def run_jit_passes(graph): + if version.parse(torch.__version__) >= version.parse("1.4.0"): + torch._C._jit_pass_inline(graph) + + +def update_outputs_from_pairs(name_output_pairs, outputs, output_index_map): + for output_name, output in name_output_pairs: + output_index_map[output_name] = len(outputs) + outputs.append(output) + + +def get_all_op_names(graph): + nodes = list(graph.nodes()) + return set([node.kind() for node in nodes]) + + +def report_missing_conversion(graph): + known_ops = ["prim::Constant", "prim::GetAttr", + "prim::ListConstruct", "prim::ListUnpack", + "prim::TupleConstruct", "prim::TupleUnpack"] + known_ops += list(_convert_map.keys()) + + missing = [op_name for op_name in get_all_op_names(graph) + if op_name not in known_ops] + + if missing: + msg = "The following operators are not implemented: {}".format(missing) + raise NotImplementedError(msg) + def from_pytorch(script_module, input_shapes): """ Load PyTorch model in the form of a scripted PyTorch model and convert into relay. @@ -1024,9 +975,48 @@ def from_pytorch(script_module, input_shapes): mod : tvm.relay.Module The module that optimizations will be performed on. 
- params : dict of str to tvm.runtime - Dict of converted parameters stored in tvm.runtime format + params : dict of str to tvm.ndarray + Dict of converted parameters stored in tvm.ndarray format """ - g = Graph(script_module, input_shapes) - mod, params = g.from_pytorch() - return mod, params + graph = script_module.graph.copy() + run_jit_passes(graph) + report_missing_conversion(graph) + + params = script_module.state_dict() + input_vars = parse_inputs(graph.inputs(), input_shapes) + param_vars, tensors = parse_params(graph, params) + ops = parse_ops(graph.nodes()) + + input_vars.update(param_vars) + outputs = list(input_vars.values()) + output_index_map = dict(zip(input_vars.keys(), range(len(outputs)))) + + for node_name, op_node in ops.items(): + operator = op_node.kind() + inputs = get_op_inputs(op_node, outputs, output_index_map) + + if operator == "prim::Constant": + output_index_map[node_name] = len(outputs) + outputs.append(get_constant(op_node)) + elif operator == 'prim::ListConstruct' and is_int_seq(inputs): + output_index_map[node_name] = len(outputs) + outputs.append(_expr.var(node_name, shape=inputs)) + elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']: + output_index_map[node_name] = len(outputs) + outputs.append(inputs) + elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']: + assert len(inputs) == 1 + unpacked_names = get_output_names(op_node) + update_outputs_from_pairs(zip(unpacked_names, inputs[0]), + outputs, output_index_map) + else: + output_index_map[node_name] = len(outputs) + relay_op = _convert_map[operator] + outputs.append(relay_op(inputs, get_input_types(op_node))) + + ret_name = get_input_names(graph.return_node())[0] + body = outputs[output_index_map[ret_name]] + func = tvm.relay.Function(_analysis.free_vars(body), body) + tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()} + + return _module.IRModule.from_expr(func), tvm_params diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index ba1d7bbe67bc..8cb0f578a99c 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -94,6 +94,7 @@ def load_model(model_name): if hasattr(torchvision.models, model_name): return load_torchvision(model_name) try: + import pretrainedmodels if hasattr(pretrainedmodels, model_name): return load_pretrainedmodels(model_name) except ModuleNotFoundError: @@ -276,7 +277,7 @@ def forward(self, *args): class Multiply2(Module): def forward(self, *args): - return args[0] * 1 + return args[0] * 1.0 class Multiply3(Module): def forward(self, *args): @@ -507,7 +508,7 @@ def test_forward_size(): class Size1(Module): def forward(self, *args): - return args[0].size(0) * args[0] + return float(args[0].size(0)) * args[0] with torch.no_grad(): input_data = torch.rand(input_shape).float() From e1c9e20ef28251da779bfb7034e34a9a790149bd Mon Sep 17 00:00:00 2001 From: masahi Date: Wed, 26 Feb 2020 13:31:40 +0900 Subject: [PATCH 02/10] enable mobilenet v2 test --- tests/python/frontend/pytorch/test_forward.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 8cb0f578a99c..8d1e981d8e75 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -709,6 +709,10 @@ def test_mnasnet0_5(): torch.set_grad_enabled(False) verify_model("mnasnet0_5") +def test_mobilenet_v2(): + 
torch.set_grad_enabled(False) + verify_model("mobilenet_v2") + """ #TODO: Fix VGG and AlexNet issues (probably due to pooling) def test_alexnet(): @@ -722,13 +726,9 @@ def test_vgg11(): def test_vgg11_bn(): torch.set_grad_enabled(False) verify_model("vgg11_bn") - -#TODO: Need to update schedule in tophub file after PR #4787 updated workloads -def test_mobilenet_v2(): - torch.set_grad_enabled(False) - verify_model("mobilenet_v2") """ + if __name__ == "__main__": # Single operator tests test_forward_add() @@ -768,3 +768,4 @@ def test_mobilenet_v2(): test_inception_v3() test_googlenet() test_mnasnet0_5() + test_mobilenet_v2() From 186f2e2fb08e8dcea1fb8df8788e8ff1b40494e4 Mon Sep 17 00:00:00 2001 From: masahi Date: Wed, 26 Feb 2020 13:45:28 +0900 Subject: [PATCH 03/10] minor cleanup --- python/tvm/relay/frontend/pytorch.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index ee58de5421cf..bdc0bfdf7628 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -400,9 +400,11 @@ def _impl(inputs, input_types): def _size(): def _impl(inputs, input_types): - axis = int(inputs[1]) shape = _infer_shape(inputs[0]) - return shape[axis] + if len(inputs) > 1: + axis = int(inputs[1]) + return shape[axis] + return shape return _impl def _numtotensor(): @@ -738,20 +740,12 @@ def is_int_seq(seq): def parse_inputs(graph_inputs, input_shapes): ir_inputs = list(graph_inputs) - ir_names = [i.debugName() for i in ir_inputs] input_vars = {} for input_name, ir_input in zip(input_shapes, ir_inputs[1:]): - input_shape = input_shapes[input_name] ir_input.setDebugName(input_name) input_vars[input_name] = _expr.var(input_name, shape=input_shapes[input_name]) - # Add self (first input of a PyTorch graph) to inputs - input_shape = [3] - tensor = tvm.nd.array(np.zeros(input_shape).astype(np.float32)) - input_name = ir_names[0] # self.1 - input_vars[input_name] = tensor - return input_vars @@ -896,8 +890,7 @@ def get_constant(node): elif ty == "FunctionType": return None else: - print(ty, node) - assert False # TODO: handle other types + raise NotImplementedError("Unsupported type: %s" % ty) else: assert num_attributes == 0 return None @@ -975,7 +968,7 @@ def from_pytorch(script_module, input_shapes): mod : tvm.relay.Module The module that optimizations will be performed on. 
- params : dict of str to tvm.ndarray + params : dict of str to tvm.runtime.NDArray Dict of converted parameters stored in tvm.ndarray format """ graph = script_module.graph.copy() From 86a5a63113170077ecf8b2e06c0bb0c61ba4425b Mon Sep 17 00:00:00 2001 From: masahi Date: Wed, 26 Feb 2020 14:11:24 +0900 Subject: [PATCH 04/10] reorg --- python/tvm/relay/frontend/pytorch.py | 160 +++++++++++++-------------- 1 file changed, 79 insertions(+), 81 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index bdc0bfdf7628..80b710d3e438 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -734,19 +734,13 @@ def _convert_elemwise_input(data, input_type): } -def is_int_seq(seq): - return len(seq) > 0 and all([isinstance(i, int) for i in seq]) - +def run_jit_passes(graph): + if version.parse(torch.__version__) >= version.parse("1.4.0"): + torch._C._jit_pass_inline(graph) -def parse_inputs(graph_inputs, input_shapes): - ir_inputs = list(graph_inputs) - input_vars = {} - for input_name, ir_input in zip(input_shapes, ir_inputs[1:]): - ir_input.setDebugName(input_name) - input_vars[input_name] = _expr.var(input_name, - shape=input_shapes[input_name]) - return input_vars +def is_int_seq(seq): + return len(seq) > 0 and all([isinstance(i, int) for i in seq]) def get_tensor_and_var(torch_tensor, name): @@ -768,6 +762,37 @@ def get_input_names(node): return [inp.debugName() for inp in node.inputs()] +def get_op_inputs(op_node, outputs, output_index_map): + input_names = [output_index_map[name] + for name in get_input_names(op_node)] + return [outputs[name] for name in input_names] + + +def update_outputs_from_pairs(name_output_pairs, outputs, output_index_map): + for output_name, output in name_output_pairs: + output_index_map[output_name] = len(outputs) + outputs.append(output) + + +def get_all_op_names(graph): + nodes = list(graph.nodes()) + return set([node.kind() for node in nodes]) + + +def report_missing_conversion(op_names): + known_ops = ["prim::Constant", "prim::GetAttr", + "prim::ListConstruct", "prim::ListUnpack", + "prim::TupleConstruct", "prim::TupleUnpack"] + known_ops += list(_convert_map.keys()) + + missing = [op_name for op_name in op_names + if op_name not in known_ops] + + if missing: + msg = "The following operators are not implemented: {}".format(missing) + raise NotImplementedError(msg) + + def getattr_attr_name(node): attribute_names = node.attributeNames() assert(len(attribute_names) == 1) @@ -775,6 +800,10 @@ def getattr_attr_name(node): return attr_name +def get_full_attr_name(getattrs): + return ".".join([getattr_attr_name(node) for node in getattrs]) + + def get_use_chains(root_node, terminate=lambda _: False): def concat_lists(lists): return itertools.chain.from_iterable(lists) @@ -811,36 +840,6 @@ def terminate(users): return get_use_chains(root_getattr_node, terminate) -def get_full_attr_name(getattrs): - return ".".join([getattr_attr_name(node) for node in getattrs]) - - -def parse_params(graph, state_dict): - getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True) - params = {} - param_tensors = {} - seen = set() - - for node in getattr_nodes: - if get_output_name(node) in seen: - continue - - for getattrs in get_attr_chains(node): - seen.update(map(get_output_name, getattrs)) - - full_attr = get_full_attr_name(getattrs) - full_attr_node_name = get_output_name(getattrs[-1]) - - if full_attr in state_dict: - torch_tensor = state_dict[full_attr] - tensor, var = 
get_tensor_and_var(torch_tensor, - full_attr_node_name) - param_tensors[full_attr_node_name] = tensor - params[full_attr_node_name] = var - - return params, param_tensors - - def get_input_types(op_node): input_list_types = [] for input_node in op_node.inputs(): @@ -896,6 +895,43 @@ def get_constant(node): return None +def parse_inputs(graph_inputs, input_shapes): + ir_inputs = list(graph_inputs) + input_vars = {} + + for input_name, ir_input in zip(input_shapes, ir_inputs[1:]): + ir_input.setDebugName(input_name) + input_vars[input_name] = _expr.var(input_name, + shape=input_shapes[input_name]) + return input_vars + + +def parse_params(graph, state_dict): + getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True) + params = {} + param_tensors = {} + seen = set() + + for node in getattr_nodes: + if get_output_name(node) in seen: + continue + + for getattrs in get_attr_chains(node): + seen.update(map(get_output_name, getattrs)) + + full_attr = get_full_attr_name(getattrs) + full_attr_node_name = get_output_name(getattrs[-1]) + + if full_attr in state_dict: + torch_tensor = state_dict[full_attr] + tensor, var = get_tensor_and_var(torch_tensor, + full_attr_node_name) + param_tensors[full_attr_node_name] = tensor + params[full_attr_node_name] = var + + return params, param_tensors + + def parse_ops(nodes): ops = {} # Traverse nodes and add to graph @@ -911,45 +947,6 @@ def parse_ops(nodes): return ops -def get_input_node_names(op_node, output_index_map): - return [output_index_map[name] for name in get_input_names(op_node)] - - -def get_op_inputs(op_node, outputs, output_index_map): - input_names = get_input_node_names(op_node, output_index_map) - return [outputs[name] for name in input_names] - - -def run_jit_passes(graph): - if version.parse(torch.__version__) >= version.parse("1.4.0"): - torch._C._jit_pass_inline(graph) - - -def update_outputs_from_pairs(name_output_pairs, outputs, output_index_map): - for output_name, output in name_output_pairs: - output_index_map[output_name] = len(outputs) - outputs.append(output) - - -def get_all_op_names(graph): - nodes = list(graph.nodes()) - return set([node.kind() for node in nodes]) - - -def report_missing_conversion(graph): - known_ops = ["prim::Constant", "prim::GetAttr", - "prim::ListConstruct", "prim::ListUnpack", - "prim::TupleConstruct", "prim::TupleUnpack"] - known_ops += list(_convert_map.keys()) - - missing = [op_name for op_name in get_all_op_names(graph) - if op_name not in known_ops] - - if missing: - msg = "The following operators are not implemented: {}".format(missing) - raise NotImplementedError(msg) - - def from_pytorch(script_module, input_shapes): """ Load PyTorch model in the form of a scripted PyTorch model and convert into relay. The companion parameters will be handled automatically. 
@@ -973,7 +970,8 @@ def from_pytorch(script_module, input_shapes): """ graph = script_module.graph.copy() run_jit_passes(graph) - report_missing_conversion(graph) + op_names = get_all_op_names(graph) + report_missing_conversion(op_names) params = script_module.state_dict() input_vars = parse_inputs(graph.inputs(), input_shapes) From fb1998d68263a8744fc2100c27822cf750346672 Mon Sep 17 00:00:00 2001 From: masahi Date: Wed, 26 Feb 2020 14:43:00 +0900 Subject: [PATCH 05/10] fix lint --- python/tvm/relay/frontend/pytorch.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 80b710d3e438..5623a1db5400 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -21,7 +21,6 @@ import itertools from packaging import version -import torch import numpy as np import tvm @@ -735,6 +734,8 @@ def _convert_elemwise_input(data, input_type): def run_jit_passes(graph): + """ The inline pass is nessary to unwrap prim::CallMethod """ + import torch if version.parse(torch.__version__) >= version.parse("1.4.0"): torch._C._jit_pass_inline(graph) @@ -776,10 +777,11 @@ def update_outputs_from_pairs(name_output_pairs, outputs, output_index_map): def get_all_op_names(graph): nodes = list(graph.nodes()) - return set([node.kind() for node in nodes]) + return set(node.kind() for node in nodes) def report_missing_conversion(op_names): + """Check if all ops in an input graph are supported by TVM""" known_ops = ["prim::Constant", "prim::GetAttr", "prim::ListConstruct", "prim::ListUnpack", "prim::TupleConstruct", "prim::TupleUnpack"] @@ -795,7 +797,7 @@ def report_missing_conversion(op_names): def getattr_attr_name(node): attribute_names = node.attributeNames() - assert(len(attribute_names) == 1) + assert len(attribute_names) == 1 attr_name = node.s(attribute_names[0]) return attr_name @@ -805,6 +807,10 @@ def get_full_attr_name(getattrs): def get_use_chains(root_node, terminate=lambda _: False): + """ + Track a chain of users of this node forward, returning a list of chains + See get_attr_chains below for its usage + """ def concat_lists(lists): return itertools.chain.from_iterable(lists) @@ -841,6 +847,7 @@ def terminate(users): def get_input_types(op_node): + """Returns a torch type for each input nodes""" input_list_types = [] for input_node in op_node.inputs(): in_ty = input_node.type() @@ -868,6 +875,7 @@ def get_input_types(op_node): def get_constant(node): + """ Retrive a constant associated with this prim::Constant node""" attribute_names = node.attributeNames() num_attributes = len(attribute_names) @@ -875,7 +883,7 @@ def get_constant(node): attr_name = attribute_names[0] ty = node.output().type().kind() - if ty == "IntType" or ty == "BoolType": + if ty in ["IntType", "BoolType"]: return node.i(attr_name) elif ty in ["FloatType", "LongType"]: return node.f(attr_name) @@ -896,6 +904,7 @@ def get_constant(node): def parse_inputs(graph_inputs, input_shapes): + """ Return Relay vars from torch input vars""" ir_inputs = list(graph_inputs) input_vars = {} @@ -907,6 +916,10 @@ def parse_inputs(graph_inputs, input_shapes): def parse_params(graph, state_dict): + """ + Return Relay vars and TVM NDArrays for input parameters + A chain of prim::GetAttr nodes is processed one at a time + """ getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True) params = {} param_tensors = {} @@ -933,6 +946,7 @@ def parse_params(graph, state_dict): def parse_ops(nodes): + """ 
Returns torch IR nodes that need conversion to Relay """ ops = {} # Traverse nodes and add to graph for node in nodes: From 29385d596b5478b854b78d1314912c97c9197f1e Mon Sep 17 00:00:00 2001 From: masahi Date: Wed, 26 Feb 2020 17:04:10 +0900 Subject: [PATCH 06/10] use input names that come with torch IR --- python/tvm/relay/frontend/pytorch.py | 28 ++++++++++++------- tests/python/frontend/pytorch/test_forward.py | 9 +++--- tutorials/frontend/from_pytorch.py | 14 +++++----- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 5623a1db5400..c0d983e0d6e7 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -759,8 +759,8 @@ def get_output_names(node): return [output.debugName() for output in node.outputs()] -def get_input_names(node): - return [inp.debugName() for inp in node.inputs()] +def get_input_names(node_or_graph): + return [inp.debugName() for inp in node_or_graph.inputs()] def get_op_inputs(op_node, outputs, output_index_map): @@ -781,7 +781,7 @@ def get_all_op_names(graph): def report_missing_conversion(op_names): - """Check if all ops in an input graph are supported by TVM""" + """ Check if all ops in an input graph are supported by TVM """ known_ops = ["prim::Constant", "prim::GetAttr", "prim::ListConstruct", "prim::ListUnpack", "prim::TupleConstruct", "prim::TupleUnpack"] @@ -828,7 +828,7 @@ def inner(current, accum): def get_attr_chains(root_getattr_node): - """Returns chains of attribute access starting from root_getattr_node + """ Returns chains of attribute access starting from root_getattr_node For example, given attribute "block", as in "self.block" when "self" points to the top level torch.nn.Module, it returns lists of attribute "chains", @@ -847,7 +847,7 @@ def terminate(users): def get_input_types(op_node): - """Returns a torch type for each input nodes""" + """ Returns a torch type for each input nodes """ input_list_types = [] for input_node in op_node.inputs(): in_ty = input_node.type() @@ -875,7 +875,7 @@ def get_input_types(op_node): def get_constant(node): - """ Retrive a constant associated with this prim::Constant node""" + """ Retrive a constant associated with this prim::Constant node """ attribute_names = node.attributeNames() num_attributes = len(attribute_names) @@ -904,12 +904,11 @@ def get_constant(node): def parse_inputs(graph_inputs, input_shapes): - """ Return Relay vars from torch input vars""" + """ Return Relay vars from torch input vars """ ir_inputs = list(graph_inputs) input_vars = {} for input_name, ir_input in zip(input_shapes, ir_inputs[1:]): - ir_input.setDebugName(input_name) input_vars[input_name] = _expr.var(input_name, shape=input_shapes[input_name]) return input_vars @@ -961,6 +960,14 @@ def parse_ops(nodes): return ops +def get_graph_input_names(script_module): + """ Use this function to set the keys for input_shapes""" + # It seems variable names could change the first time a copy is made + # Use the copy of the graph here to prevent troubles later + ir_inputs = get_input_names(script_module.graph.copy()) + return ir_inputs[1:] # remove self at the 0th arg + + def from_pytorch(script_module, input_shapes): """ Load PyTorch model in the form of a scripted PyTorch model and convert into relay. The companion parameters will be handled automatically. 
@@ -971,8 +978,9 @@ def from_pytorch(script_module, input_shapes): TorchScripted PyTorch graph Note: We currently only support traces (ie: torch.jit.trace(model, input)) - shape : Dictionary of input dimensions + input_shape : Dictionary of input dimensions Graph level input shape dictionary + The keys should be the same one returned by get_graph_input_names(...) above Returns ------- @@ -980,7 +988,7 @@ def from_pytorch(script_module, input_shapes): The module that optimizations will be performed on. params : dict of str to tvm.runtime.NDArray - Dict of converted parameters stored in tvm.ndarray format + Dict of converted parameters stored in tvm.runtime.ndarray format """ graph = script_module.graph.copy() run_jit_passes(graph) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 8d1e981d8e75..831389b7ebf5 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -31,6 +31,8 @@ from tvm import relay from tvm.contrib import graph_runtime from tvm.relay.testing.config import ctx_list +from tvm.relay.frontend.pytorch import get_graph_input_names + sys.setrecursionlimit(10000) @@ -168,16 +170,15 @@ def verify_model(model_name, input_data=[]): baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs) else: baseline_outputs = (baseline_outputs.float().cpu().numpy(),) - output_shapes = [out.shape for out in baseline_outputs] - dtype = "float32" - input_name = "input0" - input_shapes = {input_name: list(baseline_input.shape)} trace = torch.jit.trace(baseline_model, baseline_input).float().eval() + if torch.cuda.is_available(): trace = trace.cuda() else: trace = trace.cpu() + input_name = get_graph_input_names(trace)[0] # only one input + input_shapes = {input_name: list(baseline_input.shape)} mod, params = relay.frontend.from_pytorch(trace, input_shapes) compiled_input = {input_name: tvm.nd.array(baseline_input.cpu().numpy())} diff --git a/tutorials/frontend/from_pytorch.py b/tutorials/frontend/from_pytorch.py index c280c259c1fe..503f64a4e7d9 100644 --- a/tutorials/frontend/from_pytorch.py +++ b/tutorials/frontend/from_pytorch.py @@ -41,14 +41,13 @@ be unstable. """ -# tvm, relay import tvm from tvm import relay -# numpy, packaging import numpy as np -from packaging import version + from tvm.contrib.download import download_testdata +from tvm.relay.frontend.pytorch import get_graph_input_names # PyTorch imports import torch @@ -91,7 +90,8 @@ # Import the graph to Relay # ------------------------- # Convert PyTorch graph to Relay graph. 
-shape_dict = {'img': img.shape} +input_name = get_graph_input_names(scripted_model)[0] # only one input +shape_dict = {input_name: img.shape} mod, params = relay.frontend.from_pytorch(scripted_model, shape_dict) @@ -116,12 +116,12 @@ dtype = 'float32' m = graph_runtime.create(graph, lib, ctx) # Set inputs -m.set_input('img', tvm.nd.array(img.astype(dtype))) +m.set_input(input_name, tvm.nd.array(img.astype(dtype))) m.set_input(**params) # Execute m.run() # Get outputs -tvm_output = m.get_output(0, tvm.nd.empty(((1, 1000)), 'float32')) +tvm_output = m.get_output(0) ##################################################################### # Look up synset name @@ -163,4 +163,4 @@ torch_class_key = class_id_to_key[top1_torch] print('Relay top-1 id: {}, class name: {}'.format(top1_tvm, key_to_classname[tvm_class_key])) -print('Torch top-1 id: {}, class name: {}'.format(top1_torch, key_to_classname[torch_class_key])) \ No newline at end of file +print('Torch top-1 id: {}, class name: {}'.format(top1_torch, key_to_classname[torch_class_key])) From 752685899e7bc57b76d4a7642e4d48e8cc304f09 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 27 Feb 2020 06:11:28 +0900 Subject: [PATCH 07/10] fix typo --- python/tvm/relay/frontend/pytorch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index c0d983e0d6e7..e6c34d823ec0 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -734,7 +734,7 @@ def _convert_elemwise_input(data, input_type): def run_jit_passes(graph): - """ The inline pass is nessary to unwrap prim::CallMethod """ + """ The inline pass is necessary to unwrap prim::CallMethod """ import torch if version.parse(torch.__version__) >= version.parse("1.4.0"): torch._C._jit_pass_inline(graph) @@ -875,7 +875,7 @@ def get_input_types(op_node): def get_constant(node): - """ Retrive a constant associated with this prim::Constant node """ + """ Retrieve a constant associated with this prim::Constant node """ attribute_names = node.attributeNames() num_attributes = len(attribute_names) @@ -978,7 +978,7 @@ def from_pytorch(script_module, input_shapes): TorchScripted PyTorch graph Note: We currently only support traces (ie: torch.jit.trace(model, input)) - input_shape : Dictionary of input dimensions + input_shapes : Dictionary of input dimensions Graph level input shape dictionary The keys should be the same one returned by get_graph_input_names(...) 
above From 31ef1e7ba3bfd71c8d03a79bbbb80351d0accca1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 27 Feb 2020 06:20:48 +0900 Subject: [PATCH 08/10] introduce parse_operators --- python/tvm/relay/frontend/pytorch.py | 80 +++++++++++++++------------- 1 file changed, 42 insertions(+), 38 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index e6c34d823ec0..26443765c5f0 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -903,6 +903,22 @@ def get_constant(node): return None +def get_operator_nodes(nodes): + """ Returns torch IR nodes that need conversion to Relay """ + ops = {} + # Traverse nodes and add to graph + for node in nodes: + if node.outputsSize() > 1: + node_name = "_".join(get_output_names(node)) + else: + node_name = get_output_name(node) + + if node.kind() != "prim::GetAttr": + ops[node_name] = node + + return ops + + def parse_inputs(graph_inputs, input_shapes): """ Return Relay vars from torch input vars """ ir_inputs = list(graph_inputs) @@ -944,20 +960,31 @@ def parse_params(graph, state_dict): return params, param_tensors -def parse_ops(nodes): - """ Returns torch IR nodes that need conversion to Relay """ - ops = {} - # Traverse nodes and add to graph - for node in nodes: - if node.outputsSize() > 1: - node_name = "_".join(get_output_names(node)) - else: - node_name = get_output_name(node) +def parse_operators(operators, outputs, output_index_map, ret_name): + for node_name, op_node in operators.items(): + operator = op_node.kind() + inputs = get_op_inputs(op_node, outputs, output_index_map) - if node.kind() != "prim::GetAttr": - ops[node_name] = node + if operator == "prim::Constant": + output_index_map[node_name] = len(outputs) + outputs.append(get_constant(op_node)) + elif operator == 'prim::ListConstruct' and is_int_seq(inputs): + output_index_map[node_name] = len(outputs) + outputs.append(_expr.var(node_name, shape=inputs)) + elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']: + output_index_map[node_name] = len(outputs) + outputs.append(inputs) + elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']: + assert len(inputs) == 1 + unpacked_names = get_output_names(op_node) + update_outputs_from_pairs(zip(unpacked_names, inputs[0]), + outputs, output_index_map) + else: + output_index_map[node_name] = len(outputs) + relay_op = _convert_map[operator] + outputs.append(relay_op(inputs, get_input_types(op_node))) - return ops + return outputs[output_index_map[ret_name]] def get_graph_input_names(script_module): @@ -998,37 +1025,14 @@ def from_pytorch(script_module, input_shapes): params = script_module.state_dict() input_vars = parse_inputs(graph.inputs(), input_shapes) param_vars, tensors = parse_params(graph, params) - ops = parse_ops(graph.nodes()) input_vars.update(param_vars) outputs = list(input_vars.values()) output_index_map = dict(zip(input_vars.keys(), range(len(outputs)))) - - for node_name, op_node in ops.items(): - operator = op_node.kind() - inputs = get_op_inputs(op_node, outputs, output_index_map) - - if operator == "prim::Constant": - output_index_map[node_name] = len(outputs) - outputs.append(get_constant(op_node)) - elif operator == 'prim::ListConstruct' and is_int_seq(inputs): - output_index_map[node_name] = len(outputs) - outputs.append(_expr.var(node_name, shape=inputs)) - elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']: - output_index_map[node_name] = len(outputs) - outputs.append(inputs) - elif operator in 
["prim::ListUnpack", 'prim::TupleUnpack']: - assert len(inputs) == 1 - unpacked_names = get_output_names(op_node) - update_outputs_from_pairs(zip(unpacked_names, inputs[0]), - outputs, output_index_map) - else: - output_index_map[node_name] = len(outputs) - relay_op = _convert_map[operator] - outputs.append(relay_op(inputs, get_input_types(op_node))) - ret_name = get_input_names(graph.return_node())[0] - body = outputs[output_index_map[ret_name]] + + body = parse_operators(get_operator_nodes(graph.nodes()), outputs, + output_index_map, ret_name) func = tvm.relay.Function(_analysis.free_vars(body), body) tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()} From 7351c42925aa63997a75b2385d614904fa2c7a77 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 27 Feb 2020 06:27:15 +0900 Subject: [PATCH 09/10] fix lint --- python/tvm/relay/frontend/pytorch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 26443765c5f0..6c3ef29e1cbd 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -961,6 +961,7 @@ def parse_params(graph, state_dict): def parse_operators(operators, outputs, output_index_map, ret_name): + """ Convert each Torch IR operators to Relay equivalent """ for node_name, op_node in operators.items(): operator = op_node.kind() inputs = get_op_inputs(op_node, outputs, output_index_map) From 3a2cc942c3c3c28553aa645631d6f85596b859db Mon Sep 17 00:00:00 2001 From: masahi Date: Fri, 28 Feb 2020 10:10:18 +0900 Subject: [PATCH 10/10] add _ prefix --- python/tvm/relay/frontend/pytorch.py | 165 ++++++++++++++------------- 1 file changed, 83 insertions(+), 82 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 6c3ef29e1cbd..fd66e3c1f367 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -733,54 +733,49 @@ def _convert_elemwise_input(data, input_type): } -def run_jit_passes(graph): +def _run_jit_passes(graph): """ The inline pass is necessary to unwrap prim::CallMethod """ import torch if version.parse(torch.__version__) >= version.parse("1.4.0"): torch._C._jit_pass_inline(graph) -def is_int_seq(seq): +def _is_int_seq(seq): return len(seq) > 0 and all([isinstance(i, int) for i in seq]) -def get_tensor_and_var(torch_tensor, name): +def _get_tensor_and_var(torch_tensor, name): tensor = tvm.nd.array(torch_tensor.cpu().numpy()) var = _expr.var(name, shape=tensor.shape) return tensor, var -def get_output_name(node): +def _get_output_name(node): assert node.outputsSize() == 1 return node.output().debugName() -def get_output_names(node): +def _get_output_names(node): return [output.debugName() for output in node.outputs()] -def get_input_names(node_or_graph): +def _get_input_names(node_or_graph): return [inp.debugName() for inp in node_or_graph.inputs()] -def get_op_inputs(op_node, outputs, output_index_map): +def _get_op_inputs(op_node, outputs, output_index_map): input_names = [output_index_map[name] - for name in get_input_names(op_node)] + for name in _get_input_names(op_node)] return [outputs[name] for name in input_names] -def update_outputs_from_pairs(name_output_pairs, outputs, output_index_map): +def _update_outputs_from_pairs(name_output_pairs, outputs, output_index_map): for output_name, output in name_output_pairs: output_index_map[output_name] = len(outputs) outputs.append(output) -def get_all_op_names(graph): - nodes = list(graph.nodes()) - return 
set(node.kind() for node in nodes) - - -def report_missing_conversion(op_names): +def _report_missing_conversion(op_names): """ Check if all ops in an input graph are supported by TVM """ known_ops = ["prim::Constant", "prim::GetAttr", "prim::ListConstruct", "prim::ListUnpack", @@ -795,58 +790,18 @@ def report_missing_conversion(op_names): raise NotImplementedError(msg) -def getattr_attr_name(node): +def _getattr_attr_name(node): attribute_names = node.attributeNames() assert len(attribute_names) == 1 attr_name = node.s(attribute_names[0]) return attr_name -def get_full_attr_name(getattrs): - return ".".join([getattr_attr_name(node) for node in getattrs]) - - -def get_use_chains(root_node, terminate=lambda _: False): - """ - Track a chain of users of this node forward, returning a list of chains - See get_attr_chains below for its usage - """ - def concat_lists(lists): - return itertools.chain.from_iterable(lists) - - def inner(current, accum): - users = [] - for output in current.outputs(): - users += [use.user for use in output.uses()] - - if not users or terminate(users): - return [accum] - - return concat_lists([inner(nxt, accum + [nxt]) for nxt in users]) - - return inner(root_node, [root_node]) - - -def get_attr_chains(root_getattr_node): - """ Returns chains of attribute access starting from root_getattr_node - - For example, given attribute "block", as in "self.block" when "self" points - to the top level torch.nn.Module, it returns lists of attribute "chains", - e.g. ['block', '2'], ['block', '1'], ['block', '0', '_packed_params'] - - These sets of attributes form full attribute accessors. For example, - "self.block.1", "self.block.2" will return the second and third submodule, - and "self.block.0._packed_params" will return the parameters of the first - submodule. 
- """ - def terminate(users): - next_attrs = [user for user in users if user.kind() == "prim::GetAttr"] - return len(next_attrs) == 0 - - return get_use_chains(root_getattr_node, terminate) +def _getattr_full_name(getattrs): + return ".".join([_getattr_attr_name(node) for node in getattrs]) -def get_input_types(op_node): +def _get_input_types(op_node): """ Returns a torch type for each input nodes """ input_list_types = [] for input_node in op_node.inputs(): @@ -854,7 +809,7 @@ def get_input_types(op_node): input_node_kind = in_ty.kind() if input_node_kind == 'TensorType': if in_ty.scalarType() is None: - input_list_types.append('float') + input_list_types.append(None) else: input_list_types.append(in_ty.scalarType().lower()) elif input_node_kind == 'ListType': @@ -874,7 +829,7 @@ def get_input_types(op_node): return input_list_types -def get_constant(node): +def _get_constant(node): """ Retrieve a constant associated with this prim::Constant node """ attribute_names = node.attributeNames() num_attributes = len(attribute_names) @@ -903,15 +858,15 @@ def get_constant(node): return None -def get_operator_nodes(nodes): +def _get_operator_nodes(nodes): """ Returns torch IR nodes that need conversion to Relay """ ops = {} # Traverse nodes and add to graph for node in nodes: if node.outputsSize() > 1: - node_name = "_".join(get_output_names(node)) + node_name = "_".join(_get_output_names(node)) else: - node_name = get_output_name(node) + node_name = _get_output_name(node) if node.kind() != "prim::GetAttr": ops[node_name] = node @@ -930,6 +885,46 @@ def parse_inputs(graph_inputs, input_shapes): return input_vars +def get_use_chains(root_node, terminate=lambda _: False): + """ + Track a chain of users of this node forward, returning a list of chains + See get_attr_chains below for its usage + """ + def concat_lists(lists): + return itertools.chain.from_iterable(lists) + + def inner(current, accum): + users = [] + for output in current.outputs(): + users += [use.user for use in output.uses()] + + if not users or terminate(users): + return [accum] + + return concat_lists([inner(nxt, accum + [nxt]) for nxt in users]) + + return inner(root_node, [root_node]) + + +def get_attr_chains(root_getattr_node): + """ Returns chains of attribute access starting from root_getattr_node + + For example, given attribute "block", as in "self.block" when "self" points + to the top level torch.nn.Module, it returns lists of attribute "chains", + e.g. ['block', '2'], ['block', '1'], ['block', '0', '_packed_params'] + + These sets of attributes form full attribute accessors. For example, + "self.block.1", "self.block.2" will return the second and third submodule, + and "self.block.0._packed_params" will return the parameters of the first + submodule. 
+ """ + def terminate(users): + next_attrs = [user for user in users if user.kind() == "prim::GetAttr"] + return len(next_attrs) == 0 + + return get_use_chains(root_getattr_node, terminate) + + def parse_params(graph, state_dict): """ Return Relay vars and TVM NDArrays for input parameters @@ -941,19 +936,19 @@ def parse_params(graph, state_dict): seen = set() for node in getattr_nodes: - if get_output_name(node) in seen: + if _get_output_name(node) in seen: continue for getattrs in get_attr_chains(node): - seen.update(map(get_output_name, getattrs)) + seen.update(map(_get_output_name, getattrs)) - full_attr = get_full_attr_name(getattrs) - full_attr_node_name = get_output_name(getattrs[-1]) + full_attr = _getattr_full_name(getattrs) + full_attr_node_name = _get_output_name(getattrs[-1]) if full_attr in state_dict: torch_tensor = state_dict[full_attr] - tensor, var = get_tensor_and_var(torch_tensor, - full_attr_node_name) + tensor, var = _get_tensor_and_var(torch_tensor, + full_attr_node_name) param_tensors[full_attr_node_name] = tensor params[full_attr_node_name] = var @@ -964,12 +959,12 @@ def parse_operators(operators, outputs, output_index_map, ret_name): """ Convert each Torch IR operators to Relay equivalent """ for node_name, op_node in operators.items(): operator = op_node.kind() - inputs = get_op_inputs(op_node, outputs, output_index_map) + inputs = _get_op_inputs(op_node, outputs, output_index_map) if operator == "prim::Constant": output_index_map[node_name] = len(outputs) - outputs.append(get_constant(op_node)) - elif operator == 'prim::ListConstruct' and is_int_seq(inputs): + outputs.append(_get_constant(op_node)) + elif operator == 'prim::ListConstruct' and _is_int_seq(inputs): output_index_map[node_name] = len(outputs) outputs.append(_expr.var(node_name, shape=inputs)) elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']: @@ -977,22 +972,28 @@ def parse_operators(operators, outputs, output_index_map, ret_name): outputs.append(inputs) elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']: assert len(inputs) == 1 - unpacked_names = get_output_names(op_node) - update_outputs_from_pairs(zip(unpacked_names, inputs[0]), - outputs, output_index_map) + unpacked_names = _get_output_names(op_node) + _update_outputs_from_pairs(zip(unpacked_names, inputs[0]), + outputs, output_index_map) else: output_index_map[node_name] = len(outputs) relay_op = _convert_map[operator] - outputs.append(relay_op(inputs, get_input_types(op_node))) + outputs.append(relay_op(inputs, _get_input_types(op_node))) return outputs[output_index_map[ret_name]] +def get_all_op_names(graph): + """ Return all operator names in the input graph """ + nodes = list(graph.nodes()) + return set(node.kind() for node in nodes) + + def get_graph_input_names(script_module): """ Use this function to set the keys for input_shapes""" # It seems variable names could change the first time a copy is made # Use the copy of the graph here to prevent troubles later - ir_inputs = get_input_names(script_module.graph.copy()) + ir_inputs = _get_input_names(script_module.graph.copy()) return ir_inputs[1:] # remove self at the 0th arg @@ -1019,9 +1020,9 @@ def from_pytorch(script_module, input_shapes): Dict of converted parameters stored in tvm.runtime.ndarray format """ graph = script_module.graph.copy() - run_jit_passes(graph) + _run_jit_passes(graph) op_names = get_all_op_names(graph) - report_missing_conversion(op_names) + _report_missing_conversion(op_names) params = script_module.state_dict() input_vars = 
parse_inputs(graph.inputs(), input_shapes) @@ -1030,9 +1031,9 @@ def from_pytorch(script_module, input_shapes): input_vars.update(param_vars) outputs = list(input_vars.values()) output_index_map = dict(zip(input_vars.keys(), range(len(outputs)))) - ret_name = get_input_names(graph.return_node())[0] + ret_name = _get_input_names(graph.return_node())[0] - body = parse_operators(get_operator_nodes(graph.nodes()), outputs, + body = parse_operators(_get_operator_nodes(graph.nodes()), outputs, output_index_map, ret_name) func = tvm.relay.Function(_analysis.free_vars(body), body) tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}