From 1c80a6990fa5637919639d6ef72819e957b82450 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 6 Apr 2020 08:46:10 +0900 Subject: [PATCH 01/31] use funcs from prelude, pass around convert_map --- python/tvm/relay/frontend/pytorch.py | 366 ++++++++++++++++++--------- 1 file changed, 243 insertions(+), 123 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index a542ccc48af0..6b74e8040891 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -34,11 +34,13 @@ from .common import get_relay_op from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value +from ..prelude import Prelude, StaticTensorArrayOps from . import qnn_torch __all__ = ["from_pytorch"] + # operator implementation def _elemwise(name): def _impl(inputs, input_types): @@ -1077,6 +1079,49 @@ def _impl(inputs, input_types): return _op.cast(inputs[0], "float32") return _impl + +def _stack(): + def _impl(inputs, input_types): + if isinstance(inputs[0], list): + return _op.tensor.stack(inputs[0], 0) + else: + return _wrap_const(1) + return _impl + + +def _mm(): + def _impl(inputs, input_types): + return _op.nn.dense(inputs[0], inputs[1]) + return _impl + + +def _empty_list(prelude): + def _impl(inputs, input_types): + return prelude.nil() + return _impl + + +def _cons_list(prelude): + def _impl(inputs, input_types): + tensor2 = prelude.get_var('tensor2', "float32") + return prelude.cons(tensor2(inputs[0]), inputs[1]) + return _impl + + +def _rev_list(prelude): + def _impl(inputs, input_types): + return prelude.rev(inputs[0]) + return _impl + + +def _tensor_array_stack(prelude): + def _impl(inputs, input_types): + stack = prelude.get_var('tensor_array_stack', "float32") + stacked = stack(inputs[0]) + get_tensor_func = prelude.get_var("get_tensor2", "float32") + return get_tensor_func(stacked) + return _impl + # Helper functions for operator implementation def _convert_dtype_value(val): convert_torch_dtype_map = {7:"torch.float64", @@ -1153,107 +1198,114 @@ def _wrap_const(c): return c # Operator mappings - -_convert_map = { - "aten::device" : _none(), - "aten::add" : _elemwise("add"), - "aten::add_" : _elemwise("add"), - "aten::sub" : _elemwise("subtract"), - "aten::sub_" : _elemwise("subtract"), - "aten::max" : _elemwise("maximum"), - "aten::min" : _elemwise("minimum"), - "aten::mul" : _elemwise("multiply"), - "aten::mul_" : _elemwise("multiply"), - "aten::pow" : _elemwise("power"), - "aten::div" : _elemwise("divide"), - "aten::div_" : _elemwise("divide"), - "aten::abs" : _abs(), - "aten::arange" : _arange(), - "aten::ones" : _ones(), - "aten::zeros" : _zeros(), - "aten::reciprocal" : _reciprocal(), - "aten::repeat" : _repeat(), - "aten::repeat_interleave" : _repeat_interleave(), - "aten::to" : _to(), - "aten::squeeze" : _squeeze(), - "aten::unsqueeze" : _unsqueeze(), - "aten::cat" : _concatenate(), - "aten::slice" : _slice(), - "aten::split" : _split(), - "aten::split_with_sizes" : _split_with_sizes(), - "aten::select" : _select(), - "aten::relu" : _relu(), - "aten::relu_" : _relu(), - "aten::prelu" : _prelu(), - "aten::leaky_relu" : _leaky_relu(), - "aten::elu" : _elu(), - "aten::celu" : _celu(), - "aten::gelu" : _gelu(), - "aten::selu" : _selu(), - "aten::log_sigmoid" : _log_sigmoid(), - "aten::adaptive_avg_pool2d" : _adaptive_avg_pool_2d(), - "aten::adaptive_max_pool2d" : _adaptive_max_pool_2d(), - "aten::max_pool2d" : _maxpool_2d(), - "aten::max_pool2d_with_indices" : _maxpool_2d(), - 
"aten::max_pool1d" : _maxpool_1d(), - "aten::max_pool3d" : _maxpool_3d(), - "aten::hardtanh" : _hardtanh(), - "aten::hardtanh_" : _hardtanh(), - "aten::_convolution" : _convolution(), - "aten::softmax" : _softmax(), - "aten::threshold" : _threshold(), - "aten::threshold_" : _threshold(), - "aten::contiguous" : _contiguous(), - "aten::batch_norm" : _batch_norm(), - "aten::instance_norm" : _instance_norm(), - "aten::layer_norm" : _layer_norm(), - "aten::transpose" : _transpose(), - "aten::transpose_" : _transpose(), - "aten::t" : _transpose(), - "aten::flatten" : _flatten(), - "aten::addmm" : _dense(), - "aten::size" : _size(), - "aten::view" : _view(), - "aten::reshape" : _reshape(), - "aten::clone" : _clone(), - "aten::log_softmax" : _log_softmax(), - "aten::sigmoid" : _sigmoid(), - "aten::softplus" : _softplus(), - "aten::avg_pool2d" : _avg_pool2d(), - "aten::avg_pool3d" : _avg_pool3d(), - "aten::dropout" : _dropout(), - "aten::dropout_" : _dropout(), - "aten::feature_dropout" : _dropout(), - "aten::alpha_dropout" : _dropout(), - "aten::mean" : _mean(), - "aten::chunk" : _chunk(), - "aten::matmul" : _matmul(), - "aten::expand" : _expand(), - "aten::Int" : _int(), - "prim::NumToTensor" : _numtotensor(), - "prim::ListUnpack" : _identity(), - "aten::constant_pad_nd" : _pad(), - "aten::permute" : _transpose(), - "aten::sum" : _reduce("sum"), - "aten::prod" : _reduce("prod"), - "aten::sqrt" : _sqrt(), - 'aten::floor' : _floor(), - "aten::detach" : _identity(), - "aten::upsample_bilinear2d" : _upsample("bilinear"), - "aten::upsample_nearest2d" : _upsample("nearest_neighbor"), - "aten::expand_as" : _expand_as(), - "aten::lt" : _elemwise("less"), - "aten::gt" : _elemwise("greater"), - "aten::le" : _elemwise("less_equal"), - "aten::ge" : _elemwise("greater_equal"), - "aten::ne" : _elemwise("not_equal"), - "aten::Bool" : _Bool(), - "aten::Float" : _Float(), - "aten::neg" : _neg(), - "aten::tanh" : _tanh(), - "aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(), - "aten::adaptive_max_pool3d" : _adaptive_max_pool_3d() -} +def get_convert_map(prelude): + convert_map = { + "aten::device" : _none(), + "aten::add" : _elemwise("add"), + "aten::add_" : _elemwise("add"), + "aten::sub" : _elemwise("subtract"), + "aten::sub_" : _elemwise("subtract"), + "aten::max" : _elemwise("maximum"), + "aten::min" : _elemwise("minimum"), + "aten::mul" : _elemwise("multiply"), + "aten::mul_" : _elemwise("multiply"), + "aten::pow" : _elemwise("power"), + "aten::abs" : _abs(), + "aten::arange" : _arange(), + "aten::div" : _elemwise("divide"), + "aten::div_" : _elemwise("divide"), + "aten::ones" : _ones(), + "aten::zeros" : _zeros(), + "aten::reciprocal" : _reciprocal(), + "aten::repeat" : _repeat(), + "aten::repeat_interleave" : _repeat_interleave(), + "aten::to" : _to(), + "aten::squeeze" : _squeeze(), + "aten::unsqueeze" : _unsqueeze(), + "aten::cat" : _concatenate(), + "aten::slice" : _slice(), + "aten::split" : _split(), + "aten::split_with_sizes" : _split_with_sizes(), + "aten::select" : _select(), + "aten::relu" : _relu(), + "aten::relu_" : _relu(), + "aten::prelu" : _prelu(), + "aten::leaky_relu" : _leaky_relu(), + "aten::elu" : _elu(), + "aten::celu" : _celu(), + "aten::gelu" : _gelu(), + "aten::selu" : _selu(), + "aten::log_sigmoid" : _log_sigmoid(), + "aten::adaptive_avg_pool2d" : _adaptive_avg_pool_2d(), + "aten::adaptive_max_pool2d" : _adaptive_max_pool_2d(), + "aten::max_pool2d" : _maxpool_2d(), + "aten::max_pool2d_with_indices" : _maxpool_2d(), + "aten::max_pool1d" : _maxpool_1d(), + "aten::max_pool3d" : 
_maxpool_3d(), + "aten::hardtanh" : _hardtanh(), + "aten::hardtanh_" : _hardtanh(), + "aten::_convolution" : _convolution(), + "aten::softmax" : _softmax(), + "aten::threshold" : _threshold(), + "aten::threshold_" : _threshold(), + "aten::contiguous" : _contiguous(), + "aten::batch_norm" : _batch_norm(), + "aten::instance_norm" : _instance_norm(), + "aten::layer_norm" : _layer_norm(), + "aten::transpose" : _transpose(), + "aten::transpose_" : _transpose(), + "aten::t" : _transpose(), + "aten::flatten" : _flatten(), + "aten::addmm" : _dense(), + "aten::size" : _size(), + "aten::view" : _view(), + "aten::reshape" : _reshape(), + "aten::clone" : _clone(), + "aten::log_softmax" : _log_softmax(), + "aten::sigmoid" : _sigmoid(), + "aten::softplus" : _softplus(), + "aten::avg_pool2d" : _avg_pool2d(), + "aten::avg_pool3d" : _avg_pool3d(), + "aten::dropout" : _dropout(), + "aten::dropout_" : _dropout(), + "aten::feature_dropout" : _dropout(), + "aten::alpha_dropout" : _dropout(), + "aten::mean" : _mean(), + "aten::chunk" : _chunk(), + "aten::matmul" : _matmul(), + "aten::expand" : _expand(), + "aten::Int" : _int(), + "prim::NumToTensor" : _numtotensor(), + "prim::ListUnpack" : _identity(), + "aten::constant_pad_nd" : _pad(), + "aten::permute" : _transpose(), + "aten::sum" : _reduce("sum"), + "aten::prod" : _reduce("prod"), + "aten::sqrt" : _sqrt(), + 'aten::floor' : _floor(), + "aten::detach" : _identity(), + "aten::upsample_bilinear2d" : _upsample("bilinear"), + "aten::upsample_nearest2d" : _upsample("nearest_neighbor"), + "aten::expand_as" : _expand_as(), + "aten::lt" : _elemwise("less"), + "aten::gt" : _elemwise("greater"), + "aten::le" : _elemwise("less_equal"), + "aten::ge" : _elemwise("greater_equal"), + "aten::ne" : _elemwise("not_equal"), + "aten::Bool" : _Bool(), + "aten::Float" : _Float(), + "aten::neg" : _neg(), + "aten::tanh" : _tanh(), + "aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(), + "aten::adaptive_max_pool3d" : _adaptive_max_pool_3d(), + "aten::stack" : _stack(), + "aten::mm" : _matmul(), + "relay::empty_list" : _empty_list(prelude), + "relay::cons_list" : _cons_list(prelude), + "relay::rev_list" : _rev_list(prelude), + "relay::tensor_array_stack" : _tensor_array_stack(prelude), + } + return convert_map def _run_jit_passes(graph): @@ -1289,13 +1341,13 @@ def _get_op_inputs(op_node, outputs): return [outputs[name] for name in _get_input_names(op_node)] -def _report_missing_conversion(op_names): +def _report_missing_conversion(op_names, convert_map): """ Check if all ops in an input graph are supported by TVM """ known_ops = ["prim::Constant", "prim::GetAttr", "prim::ListConstruct", "prim::ListUnpack", "prim::TupleConstruct", "prim::TupleUnpack", "prim::If", "prim::Loop"] - known_ops += list(_convert_map.keys()) + known_ops += list(convert_map.keys()) known_ops += list(qnn_torch.convert_map.keys()) missing = [op_name for op_name in op_names @@ -1422,7 +1474,7 @@ def _get_operator_nodes(nodes): return ops -def _get_relay_input_vars(graph, input_shapes): +def _get_relay_input_vars(graph, input_shapes, input_types): """ Return Relay vars from input shapes and create entries based on expected graph inputs - to allow translation @@ -1437,6 +1489,68 @@ def _get_relay_input_vars(graph, input_shapes): return input_vars +def _rewrite_for_tensor_array(graph): + def has_kind(chain, kind): + return any([node.kind() == kind for node in chain]) + + def needs_rewrite(chain): + return has_kind(chain, "aten::stack") and has_kind(chain, "prim::Loop") + + def get_node(node_list, kind, 
filter_func=lambda node: True): + for node in node_list: + if node.kind() == kind and filter_func(node): + return node + assert False + return None + + def node_type(node): + return str(node.output().type()) + + list_construct_ops = graph.findAllNodes("prim::ListConstruct") + tensor_list_ops = [op for op in list_construct_ops + if node_type(op) == "List[Tensor]"] + chains = [] + for tensor_list_op in tensor_list_ops: + chains += get_use_chains(tensor_list_op) + + for chain in [chain for chain in chains if needs_rewrite(chain)]: + tensor_list_op = chain[0] + loop_op = get_node(chain, "prim::Loop") + + empty_list_node = graph.create("relay::empty_list") + empty_list_node.insertBefore(loop_op) + tensor_list_op.replaceAllUsesWith(empty_list_node) + tensor_list_op.destroy() + + rev_list_node = graph.create("relay::rev_list", + [loop_op.outputsAt(0)]) + rev_list_node.insertAfter(loop_op) + + stack_op = get_node(chain, "aten::stack") + tarray_stack_node = graph.create("relay::tensor_array_stack", + [rev_list_node.output()]) + tarray_stack_node.insertBefore(stack_op) + stack_op.replaceAllUsesWith(tarray_stack_node) + stack_op.destroy() + + loop_block = list(loop_op.blocks())[0] + loop_nodes = list(loop_block.nodes()) + + add_op = get_node(loop_nodes, "aten::add_", + lambda node: node_type(node) == "List[Tensor]") + + list_singlton_op = add_op.inputsAt(1).node() + list_singlton_op_input = list_singlton_op.inputsAt(0) + list_singlton_op.output().replaceAllUsesWith(list_singlton_op_input) + list_singlton_op.destroy() + + cons_list_node = graph.create("relay::cons_list", + list(reversed(list(add_op.inputs())))) + cons_list_node.insertBefore(add_op) + add_op.replaceAllUsesWith(cons_list_node) + add_op.destroy() + + def get_use_chains(root_node, terminate=lambda _: False): """ Track a chain of users of this node forward, returning a list of chains @@ -1512,24 +1626,24 @@ def convert_params(graph, state_dict): return params, param_tensors, packed_param_map -def convert_block(block, outputs): +def convert_block(block, outputs, convert_map): """ Translate Torch "Block", used for prim::If and prim::Loop """ ops = _get_operator_nodes(block.nodes()) ret_names = _get_input_names(block.returnNode()) - return convert_operators(ops, outputs, ret_names) + return convert_operators(ops, outputs, ret_names, convert_map) -def convert_if(if_node, outputs): +def convert_if(if_node, outputs, convert_map): """ Translate Torch prim::If to Relay If """ cond = outputs[if_node.inputsAt(0).debugName()] blocks = list(if_node.blocks()) - true_branch = convert_block(blocks[0], outputs) - false_branch = convert_block(blocks[1], outputs) + true_branch = convert_block(blocks[0], outputs, convert_map) + false_branch = convert_block(blocks[1], outputs, convert_map) assert len(true_branch) == 1 and len(false_branch) == 1 return _expr.If(cond, true_branch[0], false_branch[0]) -def convert_loop(loop_node, outputs): +def convert_loop(loop_node, outputs, convert_map): """ Translate Torch prim::Loop to Relay while_loop """ def get_input(index): ivalue = loop_node.inputsAt(index) @@ -1572,7 +1686,7 @@ def body(*current_vals): for (i, iname) in enumerate(block_input_names): outputs[iname] = current_vals[i] - block_outputs = convert_block(body_block, outputs) + block_outputs = convert_block(body_block, outputs, convert_map) if not is_while_loop: # iter var increment implicit in torch, so do it manually @@ -1614,7 +1728,7 @@ def get_var(name, val): return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)] -def 
convert_operators(operators, outputs, ret_names): +def convert_operators(operators, outputs, ret_names, convert_map): """ Convert each Torch IR operators to Relay equivalent """ for node_name, op_node in operators: operator = op_node.kind() @@ -1631,15 +1745,15 @@ def convert_operators(operators, outputs, ret_names): unpacked_names = _get_output_names(op_node) outputs.update(zip(unpacked_names, inputs[0])) elif operator == "prim::If": - if_out = convert_if(op_node, outputs) + if_out = convert_if(op_node, outputs, convert_map) outputs[node_name] = if_out elif operator == "prim::Loop": - loop_out = convert_loop(op_node, outputs) + loop_out = convert_loop(op_node, outputs, convert_map) unpacked_names = _get_output_names(op_node) assert len(loop_out) == len(unpacked_names) outputs.update(zip(unpacked_names, loop_out)) else: - relay_op = _convert_map[operator] + relay_op = convert_map[operator] relay_out = relay_op(inputs, _get_input_types(op_node)) if isinstance(relay_out, tuple): @@ -1674,7 +1788,8 @@ def _get_graph_input_names(graph): return ir_inputs[1:] # remove self at the 0th arg -def from_pytorch(script_module, input_shapes, custom_convert_map=None): +def from_pytorch(script_module, input_shapes, + input_types=[], custom_convert_map=None): """ Load PyTorch model in the form of a scripted PyTorch model and convert into relay. The companion parameters will be handled automatically. @@ -1700,18 +1815,23 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): params : dict of str to tvm.runtime.NDArray Dict of converted parameters stored in tvm.runtime.ndarray format """ + mod = tvm.IRModule() + p = Prelude(mod) + + convert_map = get_convert_map(p) + graph = script_module.graph.copy() _run_jit_passes(graph) if custom_convert_map: - _convert_map.update(custom_convert_map) + convert_map.update(custom_convert_map) op_names = get_all_op_names(graph) - _report_missing_conversion(op_names) + _report_missing_conversion(op_names, convert_map) _check_inputs(graph, input_shapes) params = script_module.state_dict() - outputs = _get_relay_input_vars(graph, input_shapes) + outputs = _get_relay_input_vars(graph, input_shapes, input_types) param_vars, tensors, packed_param_map = convert_params(graph, params) tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()} @@ -1726,14 +1846,14 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): packed_param_map, weight_quant_params) qnn_torch.add_quant_params(tvm_params, weight_quant_params) - _convert_map.update(qnn_torch.convert_map) + convert_map.update(qnn_torch.convert_map) ret = convert_operators(_get_operator_nodes(graph.nodes()), - outputs, ret_name) + outputs, ret_name, convert_map) if isinstance(ret[0], list): ret[0] = _expr.Tuple(ret[0]) - func = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0]) + mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0]) - return _module.IRModule.from_expr(func), tvm_params + return mod, tvm_params From ff440eaaaaa7b3fe0c23b6ed6e8b58c97d936324 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 6 Apr 2020 09:28:26 +0900 Subject: [PATCH 02/31] get relay input type from user ishape --- python/tvm/relay/frontend/pytorch.py | 37 +++++++++++++++++----------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 6b74e8040891..153e0486aee1 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -30,6 +30,7 @@ from .. 
import analysis as _analysis from .. import expr as _expr from .. import op as _op +from ..ty import TupleType, TensorType from ..loops import while_loop from .common import get_relay_op from .common import infer_shape as _infer_shape @@ -1474,15 +1475,32 @@ def _get_operator_nodes(nodes): return ops -def _get_relay_input_vars(graph, input_shapes, input_types): +def _get_graph_input_names(graph): + """ Get the graph input names (use after graph copy and run jit passes) """ + # Variable names could change the first time a copy is made and after + # _run_jit_passes is called, expected that those functions already invoked + ir_inputs = _get_input_names(graph) + return ir_inputs[1:] # remove self at the 0th arg + + +def _get_relay_input_vars(graph, input_shapes): """ Return Relay vars from input shapes and create entries based on expected graph inputs - to allow translation """ + def get_relay_ty(tup): + if _is_int_seq(tup): + return TensorType(tup) + elif isinstance(tup, tuple): + # tuple of tuple + return TupleType([get_relay_ty(elem) for elem in tup]) + raise NotImplementedError("Only int tuple or tuple of int tuple supported") + + input_types = [(tup[0], get_relay_ty(tup[1])) for tup in input_shapes] input_vars = {} ir_inputs = _get_graph_input_names(graph) - for ir_input, (name, shape) in zip(ir_inputs, input_shapes): - inp = _expr.var(name, shape=shape) + for ir_input, (name, itype) in zip(ir_inputs, input_types): + inp = _expr.var(name, type_annotation=itype) # Translate from graph input to user input name input_vars[ir_input] = inp @@ -1780,16 +1798,7 @@ def get_all_op_names(graph): return set(node.kind() for node in nodes) -def _get_graph_input_names(graph): - """ Get the graph input names (use after graph copy and run jit passes) """ - # Variable names could change the first time a copy is made and after - # _run_jit_passes is called, expected that those functions already invoked - ir_inputs = _get_input_names(graph) - return ir_inputs[1:] # remove self at the 0th arg - - -def from_pytorch(script_module, input_shapes, - input_types=[], custom_convert_map=None): +def from_pytorch(script_module, input_shapes, custom_convert_map=None): """ Load PyTorch model in the form of a scripted PyTorch model and convert into relay. The companion parameters will be handled automatically. @@ -1831,7 +1840,7 @@ def from_pytorch(script_module, input_shapes, _check_inputs(graph, input_shapes) params = script_module.state_dict() - outputs = _get_relay_input_vars(graph, input_shapes, input_types) + outputs = _get_relay_input_vars(graph, input_shapes) param_vars, tensors, packed_param_map = convert_params(graph, params) tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()} From 21be771c75ba0f820ea192ca67cc3705e104fd34 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 6 Apr 2020 11:56:05 +0900 Subject: [PATCH 03/31] handle tuple unpack --- python/tvm/relay/frontend/pytorch.py | 30 ++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 153e0486aee1..b654c53326b8 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -25,7 +25,6 @@ import numpy as np import tvm -from tvm.ir import module as _module from .. import analysis as _analysis from .. 
import expr as _expr @@ -1507,6 +1506,18 @@ def get_relay_ty(tup): return input_vars +def _unpack_tuple(tup): + def unpack(tup, num_fields): + return [_expr.TupleGetItem(tup, i) for i in range(num_fields)] + + if isinstance(tup, _expr.Tuple): + return unpack(tup, len(tup.fields)) + elif isinstance(tup.type_annotation, TupleType): + return unpack(tup, len(tup.type_annotation.fields)) + else: + assert False + + def _rewrite_for_tensor_array(graph): def has_kind(chain, kind): return any([node.kind() == kind for node in chain]) @@ -1718,6 +1729,10 @@ def body(*current_vals): def get_var(name, val): if isinstance(val, _expr.Constant): return _expr.var(name, shape=val.data.shape, dtype=val.data.dtype) + if isinstance(val, _expr.Var): + return _expr.var(name, type_annotation=val.type_annotation) + if isinstance(val, list): + assert False return _expr.var(name) if is_while_loop: @@ -1756,12 +1771,17 @@ def convert_operators(operators, outputs, ret_names, convert_map): outputs[node_name] = _get_constant(op_node) elif operator == 'prim::ListConstruct' and _is_int_seq(inputs): outputs[node_name] = _expr.var(node_name, shape=inputs) - elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']: + elif operator == 'prim::ListConstruct': outputs[node_name] = inputs + elif operator == 'prim::TupleConstruct': + outputs[node_name] = _expr.Tuple(inputs) elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']: assert len(inputs) == 1 - unpacked_names = _get_output_names(op_node) - outputs.update(zip(unpacked_names, inputs[0])) + if isinstance(inputs[0], list): + unpacked = inputs[0] + else: + unpacked = _unpack_tuple(inputs[0]) + outputs.update(zip(_get_output_names(op_node), unpacked)) elif operator == "prim::If": if_out = convert_if(op_node, outputs, convert_map) outputs[node_name] = if_out @@ -1831,6 +1851,8 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): graph = script_module.graph.copy() _run_jit_passes(graph) + _rewrite_for_tensor_array(graph) + print(graph) if custom_convert_map: convert_map.update(custom_convert_map) From cf0af1b3f0256eb57a2dc7192de11d0234501931 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 6 Apr 2020 12:43:58 +0900 Subject: [PATCH 04/31] experimenting with static tensor array --- python/tvm/relay/frontend/pytorch.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index b654c53326b8..38c20ebfd2d1 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1101,25 +1101,29 @@ def _impl(inputs, input_types): return _impl -def _cons_list(prelude): +def _rev_list(prelude): def _impl(inputs, input_types): - tensor2 = prelude.get_var('tensor2', "float32") - return prelude.cons(tensor2(inputs[0]), inputs[1]) + return prelude.rev(inputs[0]) return _impl -def _rev_list(prelude): +def _cons_list(prelude): def _impl(inputs, input_types): - return prelude.rev(inputs[0]) + shape = _infer_shape(inputs[0]) + static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) + static_tensor_array_ops.register() + tensor = prelude.get_var_static('tensor_constructor', "float32", shape) + return prelude.cons(tensor(inputs[0]), inputs[1]) return _impl def _tensor_array_stack(prelude): def _impl(inputs, input_types): - stack = prelude.get_var('tensor_array_stack', "float32") + # print(prelude.mod) + # TODO: how to get the fixed shape of static_tensor_array inputs[0]? 
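+        # NOTE: (2, 4) below is a hardcoded placeholder shape used while
+        # experimenting; a later commit in this series derives it with
+        # get_tensor_array_shape from the prelude instead.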
+ stack = prelude.get_var_static('tensor_array_stack', "float32", (2, 4)) stacked = stack(inputs[0]) - get_tensor_func = prelude.get_var("get_tensor2", "float32") - return get_tensor_func(stacked) + return stacked return _impl # Helper functions for operator implementation From 24c22f784d9add8757a4d5899134077691c49f33 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 6 Apr 2020 19:51:11 +0900 Subject: [PATCH 05/31] use prelude concat instead of cons + rev --- python/tvm/relay/frontend/pytorch.py | 118 +++++++-------------------- 1 file changed, 28 insertions(+), 90 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 38c20ebfd2d1..80348af715aa 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1080,15 +1080,6 @@ def _impl(inputs, input_types): return _impl -def _stack(): - def _impl(inputs, input_types): - if isinstance(inputs[0], list): - return _op.tensor.stack(inputs[0], 0) - else: - return _wrap_const(1) - return _impl - - def _mm(): def _impl(inputs, input_types): return _op.nn.dense(inputs[0], inputs[1]) @@ -1101,19 +1092,25 @@ def _impl(inputs, input_types): return _impl -def _rev_list(prelude): +def _add_(prelude): def _impl(inputs, input_types): - return prelude.rev(inputs[0]) - return _impl + if isinstance(inputs[1], list): + # list concat op + # inputs[0] is ADT list (the number of elem changes at runtime) + # inputs[1] is python list (static list) + if len(inputs[1]) == 0: + return inputs[0] + shape = _infer_shape(inputs[1][0]) + static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) + static_tensor_array_ops.register() + tensor_create = prelude.get_var_static('tensor_constructor', "float32", shape) -def _cons_list(prelude): - def _impl(inputs, input_types): - shape = _infer_shape(inputs[0]) - static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) - static_tensor_array_ops.register() - tensor = prelude.get_var_static('tensor_constructor', "float32", shape) - return prelude.cons(tensor(inputs[0]), inputs[1]) + rhs = prelude.nil() + for elem in reversed(inputs[1]): + rhs = prelude.cons(tensor_create(elem), rhs) + return prelude.concat(inputs[0], rhs) + return _elemwise("add")(inputs, input_types) return _impl @@ -1121,11 +1118,13 @@ def _tensor_array_stack(prelude): def _impl(inputs, input_types): # print(prelude.mod) # TODO: how to get the fixed shape of static_tensor_array inputs[0]? 
- stack = prelude.get_var_static('tensor_array_stack', "float32", (2, 4)) + shape = (2, 4) + stack = prelude.get_var_static('tensor_array_stack', "float32", shape) stacked = stack(inputs[0]) return stacked return _impl + # Helper functions for operator implementation def _convert_dtype_value(val): convert_torch_dtype_map = {7:"torch.float64", @@ -1206,7 +1205,6 @@ def get_convert_map(prelude): convert_map = { "aten::device" : _none(), "aten::add" : _elemwise("add"), - "aten::add_" : _elemwise("add"), "aten::sub" : _elemwise("subtract"), "aten::sub_" : _elemwise("subtract"), "aten::max" : _elemwise("maximum"), @@ -1280,7 +1278,6 @@ def get_convert_map(prelude): "aten::expand" : _expand(), "aten::Int" : _int(), "prim::NumToTensor" : _numtotensor(), - "prim::ListUnpack" : _identity(), "aten::constant_pad_nd" : _pad(), "aten::permute" : _transpose(), "aten::sum" : _reduce("sum"), @@ -1522,68 +1519,6 @@ def unpack(tup, num_fields): assert False -def _rewrite_for_tensor_array(graph): - def has_kind(chain, kind): - return any([node.kind() == kind for node in chain]) - - def needs_rewrite(chain): - return has_kind(chain, "aten::stack") and has_kind(chain, "prim::Loop") - - def get_node(node_list, kind, filter_func=lambda node: True): - for node in node_list: - if node.kind() == kind and filter_func(node): - return node - assert False - return None - - def node_type(node): - return str(node.output().type()) - - list_construct_ops = graph.findAllNodes("prim::ListConstruct") - tensor_list_ops = [op for op in list_construct_ops - if node_type(op) == "List[Tensor]"] - chains = [] - for tensor_list_op in tensor_list_ops: - chains += get_use_chains(tensor_list_op) - - for chain in [chain for chain in chains if needs_rewrite(chain)]: - tensor_list_op = chain[0] - loop_op = get_node(chain, "prim::Loop") - - empty_list_node = graph.create("relay::empty_list") - empty_list_node.insertBefore(loop_op) - tensor_list_op.replaceAllUsesWith(empty_list_node) - tensor_list_op.destroy() - - rev_list_node = graph.create("relay::rev_list", - [loop_op.outputsAt(0)]) - rev_list_node.insertAfter(loop_op) - - stack_op = get_node(chain, "aten::stack") - tarray_stack_node = graph.create("relay::tensor_array_stack", - [rev_list_node.output()]) - tarray_stack_node.insertBefore(stack_op) - stack_op.replaceAllUsesWith(tarray_stack_node) - stack_op.destroy() - - loop_block = list(loop_op.blocks())[0] - loop_nodes = list(loop_block.nodes()) - - add_op = get_node(loop_nodes, "aten::add_", - lambda node: node_type(node) == "List[Tensor]") - - list_singlton_op = add_op.inputsAt(1).node() - list_singlton_op_input = list_singlton_op.inputsAt(0) - list_singlton_op.output().replaceAllUsesWith(list_singlton_op_input) - list_singlton_op.destroy() - - cons_list_node = graph.create("relay::cons_list", - list(reversed(list(add_op.inputs())))) - cons_list_node.insertBefore(add_op) - add_op.replaceAllUsesWith(cons_list_node) - add_op.destroy() - - def get_use_chains(root_node, terminate=lambda _: False): """ Track a chain of users of this node forward, returning a list of chains @@ -1773,13 +1708,18 @@ def convert_operators(operators, outputs, ret_names, convert_map): if operator == "prim::Constant": outputs[node_name] = _get_constant(op_node) - elif operator == 'prim::ListConstruct' and _is_int_seq(inputs): + elif operator == "prim::ListConstruct" and _is_int_seq(inputs): outputs[node_name] = _expr.var(node_name, shape=inputs) - elif operator == 'prim::ListConstruct': + elif operator == "prim::ListConstruct" and len(inputs) > 0: # static + # 
This assumes that no more elements will be appended to this list outputs[node_name] = inputs - elif operator == 'prim::TupleConstruct': + elif operator == "prim::ListConstruct": # dynamic + # %outputs : Tensor[] = prim::ListConstruct() + relay_op = convert_map["relay::empty_list"] + outputs[node_name] = relay_op(inputs, _get_input_types(op_node)) + elif operator == "prim::TupleConstruct": outputs[node_name] = _expr.Tuple(inputs) - elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']: + elif operator in ["prim::ListUnpack", "prim::TupleUnpack"]: assert len(inputs) == 1 if isinstance(inputs[0], list): unpacked = inputs[0] @@ -1855,8 +1795,6 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): graph = script_module.graph.copy() _run_jit_passes(graph) - _rewrite_for_tensor_array(graph) - print(graph) if custom_convert_map: convert_map.update(custom_convert_map) From 8d59ae61cd9dfc3fc60479758b4c1feef29c4d0a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 6 Apr 2020 20:00:23 +0900 Subject: [PATCH 06/31] minor clean up --- python/tvm/relay/frontend/pytorch.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 80348af715aa..270845bd910e 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1093,6 +1093,17 @@ def _impl(inputs, input_types): def _add_(prelude): + def concat_list(lhs, rhs_static): + shape = _infer_shape(rhs_static[0]) + static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) + static_tensor_array_ops.register() + tensor_create = prelude.get_var_static('tensor_constructor', "float32", shape) + + rhs = prelude.nil() + for elem in reversed(rhs_static): + rhs = prelude.cons(tensor_create(elem), rhs) + return prelude.concat(lhs, rhs) + def _impl(inputs, input_types): if isinstance(inputs[1], list): # list concat op @@ -1100,16 +1111,7 @@ def _impl(inputs, input_types): # inputs[1] is python list (static list) if len(inputs[1]) == 0: return inputs[0] - - shape = _infer_shape(inputs[1][0]) - static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) - static_tensor_array_ops.register() - tensor_create = prelude.get_var_static('tensor_constructor', "float32", shape) - - rhs = prelude.nil() - for elem in reversed(inputs[1]): - rhs = prelude.cons(tensor_create(elem), rhs) - return prelude.concat(inputs[0], rhs) + return concat_list(inputs[0], inputs[1]) return _elemwise("add")(inputs, input_types) return _impl From edbc2a4354cd4e5484aa78db20dd43b34dbb6c48 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 7 Apr 2020 06:26:14 +0900 Subject: [PATCH 07/31] fix layer norm conversion bug, unwrap tensor array --- python/tvm/relay/frontend/pytorch.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 270845bd910e..fcffe39e6946 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -29,7 +29,7 @@ from .. import analysis as _analysis from .. import expr as _expr from .. 
import op as _op -from ..ty import TupleType, TensorType +from ..ty import TupleType, TensorType, Any from ..loops import while_loop from .common import get_relay_op from .common import infer_shape as _infer_shape @@ -1123,7 +1123,12 @@ def _impl(inputs, input_types): shape = (2, 4) stack = prelude.get_var_static('tensor_array_stack', "float32", shape) stacked = stack(inputs[0]) - return stacked + + stacked_shape = (Any(),) + shape + static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) + static_tensor_array_ops.define_tensor_get_data(stacked_shape) + get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape) + return get_tensor(stacked) return _impl @@ -1798,6 +1803,8 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): graph = script_module.graph.copy() _run_jit_passes(graph) + print(graph) + if custom_convert_map: convert_map.update(custom_convert_map) From f7ecc75646040bfd88aaae182322cb96b6ea9e3b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 7 Apr 2020 13:27:16 +0900 Subject: [PATCH 08/31] add infer shape on tensor array --- python/tvm/relay/prelude.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 243eace0fb94..f512b68bcac4 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -29,21 +29,16 @@ def get_tensor_array_shape(expr, dtype, prelude): """Get the static shape of a tensor array if it has fixed rank shape. - By design, static ADT tensor in TVM has type name in the format of static_tensor_dim0_dim1_..._dimN_t. - Parameters ---------- expr : Relay Expr Input expression. - dtype : str Data type. - prelude : Prelude Tensor array prelude - Returns ------- shape : tuple of (int, Any) or None @@ -70,6 +65,7 @@ def get_tensor_array_shape(expr, dtype, prelude): return tuple(shape) return None + def _get_name_static(canonical, dtype, shape): """Get name for static shape tensor array op corresponding to the canonical name""" From b7271827aca05faca88171ede5754eddbb12e78f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 7 Apr 2020 13:34:47 +0900 Subject: [PATCH 09/31] pass around prelude for now --- python/tvm/relay/frontend/pytorch.py | 89 ++++++++++++++++++---------- 1 file changed, 58 insertions(+), 31 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index fcffe39e6946..300e832ea7b7 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -29,12 +29,14 @@ from .. import analysis as _analysis from .. import expr as _expr from .. import op as _op +from ..function import Function +from .. import transform from ..ty import TupleType, TensorType, Any from ..loops import while_loop from .common import get_relay_op from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value -from ..prelude import Prelude, StaticTensorArrayOps +from ..prelude import Prelude, StaticTensorArrayOps, get_tensor_array_shape from . 
import qnn_torch @@ -896,7 +898,6 @@ def _impl(inputs, input_types): chunk_out = _op.transform.strided_slice(data, begin, end, stride) chunks.append(chunk_out) - if dim % num_chunks: begin = [0] * len(shape) end = shape[:] @@ -1092,16 +1093,26 @@ def _impl(inputs, input_types): return _impl +def _list_getitem(prelude): + def _impl(inputs, input_types): + return prelude.nth(inputs[0], _wrap_const(inputs[1])) + return _impl + + def _add_(prelude): def concat_list(lhs, rhs_static): - shape = _infer_shape(rhs_static[0]) + shape = (2, 4) # _infer_shape(rhs_static[0]) static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) static_tensor_array_ops.register() tensor_create = prelude.get_var_static('tensor_constructor', "float32", shape) rhs = prelude.nil() for elem in reversed(rhs_static): - rhs = prelude.cons(tensor_create(elem), rhs) + if isinstance(elem, _expr.Tuple): + tup = _expr.Tuple([tensor_create(tup_elem) for tup_elem in elem.fields]) + rhs = prelude.cons(tup, rhs) + else: + rhs = prelude.cons(tensor_create(elem), rhs) return prelude.concat(lhs, rhs) def _impl(inputs, input_types): @@ -1118,15 +1129,14 @@ def _impl(inputs, input_types): def _tensor_array_stack(prelude): def _impl(inputs, input_types): - # print(prelude.mod) - # TODO: how to get the fixed shape of static_tensor_array inputs[0]? - shape = (2, 4) + shape = get_tensor_array_shape(inputs[0], "float32", prelude) stack = prelude.get_var_static('tensor_array_stack', "float32", shape) stacked = stack(inputs[0]) stacked_shape = (Any(),) + shape static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) static_tensor_array_ops.define_tensor_get_data(stacked_shape) + # passing stacked_shape below gives "'Prelude' object has no attribute" error get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape) return get_tensor(stacked) return _impl @@ -1312,6 +1322,7 @@ def get_convert_map(prelude): "relay::cons_list" : _cons_list(prelude), "relay::rev_list" : _rev_list(prelude), "relay::tensor_array_stack" : _tensor_array_stack(prelude), + "aten::__getitem__" : _list_getitem(prelude), } return convert_map @@ -1490,18 +1501,24 @@ def _get_graph_input_names(graph): return ir_inputs[1:] # remove self at the 0th arg -def _get_relay_input_vars(graph, input_shapes): +def _get_relay_input_vars(graph, input_shapes, prelude): """ Return Relay vars from input shapes and create entries based on expected graph inputs - to allow translation """ - def get_relay_ty(tup): - if _is_int_seq(tup): - return TensorType(tup) - elif isinstance(tup, tuple): - # tuple of tuple - return TupleType([get_relay_ty(elem) for elem in tup]) - raise NotImplementedError("Only int tuple or tuple of int tuple supported") + def get_relay_ty(ishape): + if _is_int_seq(ishape): + return TensorType(ishape) + elif isinstance(ishape, tuple): + # ishapele of ishapele + return TupleType([get_relay_ty(elem) for elem in ishape]) + elif isinstance(ishape, list): + assert len(ishape) > 0 + elem_tys = [get_relay_ty(s) for s in ishape] + msg = "List elements should have identical types" + assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg + return prelude.l(elem_tys[0]) + raise NotImplementedError("unsupported input type") input_types = [(tup[0], get_relay_ty(tup[1])) for tup in input_shapes] input_vars = {} @@ -1523,6 +1540,8 @@ def unpack(tup, num_fields): elif isinstance(tup.type_annotation, TupleType): return unpack(tup, len(tup.type_annotation.fields)) else: + # print(type(tup), tup) + return unpack(tup, 2) assert False @@ 
-1601,24 +1620,24 @@ def convert_params(graph, state_dict): return params, param_tensors, packed_param_map -def convert_block(block, outputs, convert_map): +def convert_block(block, outputs, convert_map, prelude): """ Translate Torch "Block", used for prim::If and prim::Loop """ ops = _get_operator_nodes(block.nodes()) ret_names = _get_input_names(block.returnNode()) - return convert_operators(ops, outputs, ret_names, convert_map) + return convert_operators(ops, outputs, ret_names, convert_map, prelude) -def convert_if(if_node, outputs, convert_map): +def convert_if(if_node, outputs, convert_map, prelude): """ Translate Torch prim::If to Relay If """ cond = outputs[if_node.inputsAt(0).debugName()] blocks = list(if_node.blocks()) - true_branch = convert_block(blocks[0], outputs, convert_map) - false_branch = convert_block(blocks[1], outputs, convert_map) + true_branch = convert_block(blocks[0], outputs, convert_map, prelude) + false_branch = convert_block(blocks[1], outputs, convert_map, prelude) assert len(true_branch) == 1 and len(false_branch) == 1 return _expr.If(cond, true_branch[0], false_branch[0]) -def convert_loop(loop_node, outputs, convert_map): +def convert_loop(loop_node, outputs, convert_map, prelude): """ Translate Torch prim::Loop to Relay while_loop """ def get_input(index): ivalue = loop_node.inputsAt(index) @@ -1661,7 +1680,7 @@ def body(*current_vals): for (i, iname) in enumerate(block_input_names): outputs[iname] = current_vals[i] - block_outputs = convert_block(body_block, outputs, convert_map) + block_outputs = convert_block(body_block, outputs, convert_map, prelude) if not is_while_loop: # iter var increment implicit in torch, so do it manually @@ -1677,8 +1696,15 @@ def get_var(name, val): return _expr.var(name, shape=val.data.shape, dtype=val.data.dtype) if isinstance(val, _expr.Var): return _expr.var(name, type_annotation=val.type_annotation) - if isinstance(val, list): - assert False + + # print("loop var type:", name, val, type(val)) + mod = prelude.mod + func = Function([], val) + mod["main"] = func + mod = transform.InferType()(mod) + checked_type = mod["main"].body.checked_type + print("checked type:", checked_type) + return _expr.var(name) if is_while_loop: @@ -1688,6 +1714,7 @@ def get_var(name, val): if isinstance(init_cond, _expr.Constant): init_cond = _op.cast(init_cond, "bool") init_loop_iter_val = init_cond + else: loop_iter_dtype = "int32" # always count from 0 @@ -1707,7 +1734,7 @@ def get_var(name, val): return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)] -def convert_operators(operators, outputs, ret_names, convert_map): +def convert_operators(operators, outputs, ret_names, convert_map, prelude): """ Convert each Torch IR operators to Relay equivalent """ for node_name, op_node in operators: operator = op_node.kind() @@ -1734,10 +1761,10 @@ def convert_operators(operators, outputs, ret_names, convert_map): unpacked = _unpack_tuple(inputs[0]) outputs.update(zip(_get_output_names(op_node), unpacked)) elif operator == "prim::If": - if_out = convert_if(op_node, outputs, convert_map) + if_out = convert_if(op_node, outputs, convert_map, prelude) outputs[node_name] = if_out elif operator == "prim::Loop": - loop_out = convert_loop(op_node, outputs, convert_map) + loop_out = convert_loop(op_node, outputs, convert_map, prelude) unpacked_names = _get_output_names(op_node) assert len(loop_out) == len(unpacked_names) outputs.update(zip(unpacked_names, loop_out)) @@ -1796,9 +1823,9 @@ def from_pytorch(script_module, input_shapes, 
custom_convert_map=None): Dict of converted parameters stored in tvm.runtime.ndarray format """ mod = tvm.IRModule() - p = Prelude(mod) + prelude = Prelude(mod) - convert_map = get_convert_map(p) + convert_map = get_convert_map(prelude) graph = script_module.graph.copy() _run_jit_passes(graph) @@ -1813,7 +1840,7 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): _check_inputs(graph, input_shapes) params = script_module.state_dict() - outputs = _get_relay_input_vars(graph, input_shapes) + outputs = _get_relay_input_vars(graph, input_shapes, prelude) param_vars, tensors, packed_param_map = convert_params(graph, params) tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()} @@ -1831,7 +1858,7 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): convert_map.update(qnn_torch.convert_map) ret = convert_operators(_get_operator_nodes(graph.nodes()), - outputs, ret_name, convert_map) + outputs, ret_name, convert_map, prelude) if isinstance(ret[0], list): ret[0] = _expr.Tuple(ret[0]) From a3319f39079378c223dad0d4e23a42356d9a89fd Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 7 Apr 2020 17:42:17 +0900 Subject: [PATCH 10/31] compile worked but runtime error --- python/tvm/relay/frontend/pytorch.py | 34 ++++++++++++++++++---------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 300e832ea7b7..11b067032b35 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -43,6 +43,14 @@ __all__ = ["from_pytorch"] +def _infer_type_with_prelude(val, prelude): + mod = prelude.mod + func = Function([], val) + mod["main"] = func + mod = transform.InferType()(mod) + return mod["main"].body.checked_type + + # operator implementation def _elemwise(name): def _impl(inputs, input_types): @@ -1107,12 +1115,21 @@ def concat_list(lhs, rhs_static): tensor_create = prelude.get_var_static('tensor_constructor', "float32", shape) rhs = prelude.nil() + elem_ty = _infer_type_with_prelude(rhs_static[0], prelude) + print("elem_ty:", elem_ty) + for elem in reversed(rhs_static): - if isinstance(elem, _expr.Tuple): - tup = _expr.Tuple([tensor_create(tup_elem) for tup_elem in elem.fields]) + if isinstance(elem_ty, TensorType): + rhs = prelude.cons(tensor_create(elem), rhs) + elif isinstance(elem_ty, TupleType): + print("ty fields:", elem_ty.fields) + msg = "Only a tuple of tensors supported for now" + #assert all(map(lambda ty: ty == TensorType, elem_ty.fields)), msg + tup = _expr.Tuple([tensor_create(_expr.TupleGetItem(elem, i)) + for i in range(len(elem_ty.fields))]) rhs = prelude.cons(tup, rhs) else: - rhs = prelude.cons(tensor_create(elem), rhs) + assert False return prelude.concat(lhs, rhs) def _impl(inputs, input_types): @@ -1540,8 +1557,6 @@ def unpack(tup, num_fields): elif isinstance(tup.type_annotation, TupleType): return unpack(tup, len(tup.type_annotation.fields)) else: - # print(type(tup), tup) - return unpack(tup, 2) assert False @@ -1697,15 +1712,10 @@ def get_var(name, val): if isinstance(val, _expr.Var): return _expr.var(name, type_annotation=val.type_annotation) - # print("loop var type:", name, val, type(val)) - mod = prelude.mod - func = Function([], val) - mod["main"] = func - mod = transform.InferType()(mod) - checked_type = mod["main"].body.checked_type + checked_type = _infer_type_with_prelude(val, prelude) print("checked type:", checked_type) - return _expr.var(name) + return _expr.var(name, 
type_annotation=checked_type) if is_while_loop: loop_iter_dtype = "bool" From 58e2908f8b9b2430bc6ece46442b1a89708b1fd3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 9 Apr 2020 08:34:41 +0900 Subject: [PATCH 11/31] fix tensor array wrapping --- python/tvm/relay/frontend/pytorch.py | 53 ++++++++++++++++------------ 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 11b067032b35..a8840797b7cc 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -51,6 +51,29 @@ def _infer_type_with_prelude(val, prelude): return mod["main"].body.checked_type +def _convert_to_list_adt(py_lst, prelude): + elem_tys = [_infer_type_with_prelude(elem, prelude) for elem in py_lst] + msg = "List elements should have identical types" + assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg + + adt_lst = prelude.nil() + for elem in reversed(py_lst): + adt_lst = prelude.cons(elem, adt_lst) + return adt_lst + + +def _convert_to_tensor_array(adt_lst, prelude): + if prelude.length(adt_lst) == 0: + return prelude.nil() + + shape = _infer_type_with_prelude(prelude.hd(adt_lst), prelude).shape + static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) + static_tensor_array_ops.register() + tensor_create = prelude.get_var_static('tensor_constructor', "float32", shape) + + return prelude.map(tensor_create, adt_lst) + + # operator implementation def _elemwise(name): def _impl(inputs, input_types): @@ -1109,27 +1132,8 @@ def _impl(inputs, input_types): def _add_(prelude): def concat_list(lhs, rhs_static): - shape = (2, 4) # _infer_shape(rhs_static[0]) - static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) - static_tensor_array_ops.register() - tensor_create = prelude.get_var_static('tensor_constructor', "float32", shape) - - rhs = prelude.nil() - elem_ty = _infer_type_with_prelude(rhs_static[0], prelude) - print("elem_ty:", elem_ty) - - for elem in reversed(rhs_static): - if isinstance(elem_ty, TensorType): - rhs = prelude.cons(tensor_create(elem), rhs) - elif isinstance(elem_ty, TupleType): - print("ty fields:", elem_ty.fields) - msg = "Only a tuple of tensors supported for now" - #assert all(map(lambda ty: ty == TensorType, elem_ty.fields)), msg - tup = _expr.Tuple([tensor_create(_expr.TupleGetItem(elem, i)) - for i in range(len(elem_ty.fields))]) - rhs = prelude.cons(tup, rhs) - else: - assert False + # TODO: check lhs is an ADT list + rhs = _convert_to_list_adt(rhs_static, prelude) return prelude.concat(lhs, rhs) def _impl(inputs, input_types): @@ -1146,9 +1150,12 @@ def _impl(inputs, input_types): def _tensor_array_stack(prelude): def _impl(inputs, input_types): - shape = get_tensor_array_shape(inputs[0], "float32", prelude) + # TODO: check inputs[0] is List[TensorType] + # assert type_equal(inputs[0], prelude.l(TensorType)) + tensor_array = _convert_to_tensor_array(inputs[0], prelude) + shape = get_tensor_array_shape(tensor_array, "float32", prelude) stack = prelude.get_var_static('tensor_array_stack', "float32", shape) - stacked = stack(inputs[0]) + stacked = stack(tensor_array) stacked_shape = (Any(),) + shape static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) From a973954ca461b02055ae7840f7fd728e955f2ad3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 9 Apr 2020 08:48:12 +0900 Subject: [PATCH 12/31] begin list dynamic test --- python/tvm/relay/frontend/pytorch.py | 26 +++++++++++--------------- 1 file 
changed, 11 insertions(+), 15 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index a8840797b7cc..3a3f456cd667 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1118,12 +1118,6 @@ def _impl(inputs, input_types): return _impl -def _empty_list(prelude): - def _impl(inputs, input_types): - return prelude.nil() - return _impl - - def _list_getitem(prelude): def _impl(inputs, input_types): return prelude.nth(inputs[0], _wrap_const(inputs[1])) @@ -1340,12 +1334,10 @@ def get_convert_map(prelude): "aten::tanh" : _tanh(), "aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(), "aten::adaptive_max_pool3d" : _adaptive_max_pool_3d(), - "aten::stack" : _stack(), "aten::mm" : _matmul(), - "relay::empty_list" : _empty_list(prelude), - "relay::cons_list" : _cons_list(prelude), - "relay::rev_list" : _rev_list(prelude), "relay::tensor_array_stack" : _tensor_array_stack(prelude), + "aten::add_" : _add_(prelude), + "aten::stack" : _tensor_array_stack(prelude), "aten::__getitem__" : _list_getitem(prelude), } return convert_map @@ -1607,6 +1599,10 @@ def terminate(users): return get_use_chains(root_getattr_node, terminate) +def is_list_dynamic(list_construct_node): + return False + + def convert_params(graph, state_dict): """ Return Relay vars and TVM NDArrays for input parameters @@ -1761,13 +1757,13 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude): outputs[node_name] = _get_constant(op_node) elif operator == "prim::ListConstruct" and _is_int_seq(inputs): outputs[node_name] = _expr.var(node_name, shape=inputs) - elif operator == "prim::ListConstruct" and len(inputs) > 0: # static + elif operator == "prim::ListConstruct" and is_list_dynamic(op_node): + outputs[node_name] = _convert_to_list_adt(inputs, prelude) + elif operator == "prim::ListConstruct": + assert len(inputs) > 0, "An empty static list found" # This assumes that no more elements will be appended to this list + # In this case, we keep the Python list outputs[node_name] = inputs - elif operator == "prim::ListConstruct": # dynamic - # %outputs : Tensor[] = prim::ListConstruct() - relay_op = convert_map["relay::empty_list"] - outputs[node_name] = relay_op(inputs, _get_input_types(op_node)) elif operator == "prim::TupleConstruct": outputs[node_name] = _expr.Tuple(inputs) elif operator in ["prim::ListUnpack", "prim::TupleUnpack"]: From 0bef1fa2a18f9507d806745e427f35ee036eca04 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 9 Apr 2020 10:09:40 +0900 Subject: [PATCH 13/31] is_list_dynamic first version --- python/tvm/relay/frontend/pytorch.py | 40 +++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 3a3f456cd667..02ceff4a4d91 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1376,6 +1376,22 @@ def _get_op_inputs(op_node, outputs): return [outputs[name] for name in _get_input_names(op_node)] +def _get_node_type(node): + assert node.outputsSize() == 1 + return node.output().type().kind() + + +def _get_uses(node): + uses = [] + for output in node.outputs(): + uses += output.uses() + return uses + + +def _get_users(node): + return [use.user for use in _get_uses(node)] + + def _report_missing_conversion(op_names, convert_map): """ Check if all ops in an input graph are supported by TVM """ known_ops = ["prim::Constant", "prim::GetAttr", @@ -1568,9 +1584,7 @@ 
def concat_lists(lists): return itertools.chain.from_iterable(lists) def inner(current, accum): - users = [] - for output in current.outputs(): - users += [use.user for use in output.uses()] + users = _get_users(current) if not users or terminate(users): return [accum] @@ -1600,6 +1614,25 @@ def terminate(users): def is_list_dynamic(list_construct_node): + uses = _get_uses(list_construct_node) + + for loop_use in filter(lambda use: use.user.kind() == "prim::Loop", uses): + block_input_index = loop_use.offset - 1 + block = list(loop_use.user.blocks())[0] + list_loop_var = list(block.inputs())[block_input_index] + uses += _get_uses(list_loop_var.node()) + + op_names = set(use.user.kind() for use in uses) + list_ops = set(["aten::add_", "aten::__getitem__", "aten::stack"]) + intersect = list_ops.intersection(op_names) + + if len(intersect) > 0 and intersect != set(["aten::add_"]): + print("list op", list_construct_node) + return True + if intersect == set(["aten::add_"]) and _get_node_type(list_construct_node) == "ListType": + print("add_ found and it is list", list_construct_node) + return True + return False @@ -1760,6 +1793,7 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude): elif operator == "prim::ListConstruct" and is_list_dynamic(op_node): outputs[node_name] = _convert_to_list_adt(inputs, prelude) elif operator == "prim::ListConstruct": + print(op_node) assert len(inputs) > 0, "An empty static list found" # This assumes that no more elements will be appended to this list # In this case, we keep the Python list From 0c56041d07fb2a5a57a5b822fcf7ace7f0950d62 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 9 Apr 2020 11:04:39 +0900 Subject: [PATCH 14/31] finish dynamic list test --- python/tvm/relay/frontend/pytorch.py | 32 ++++++++++------------------ 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 02ceff4a4d91..aa14b4afc8d0 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1125,20 +1125,15 @@ def _impl(inputs, input_types): def _add_(prelude): - def concat_list(lhs, rhs_static): - # TODO: check lhs is an ADT list - rhs = _convert_to_list_adt(rhs_static, prelude) - return prelude.concat(lhs, rhs) - + """ + add_ is overloaded for list concat, like below + %17 : Tensor[] = prim::ListConstruct(%out.1) + %outputs.3 : Tensor[] = aten::add_(%outputs.6, %17) + """ def _impl(inputs, input_types): - if isinstance(inputs[1], list): - # list concat op - # inputs[0] is ADT list (the number of elem changes at runtime) - # inputs[1] is python list (static list) - if len(inputs[1]) == 0: - return inputs[0] - return concat_list(inputs[0], inputs[1]) - return _elemwise("add")(inputs, input_types) + return prelude.concat(inputs[0], inputs[1]) + # TODO: could inputs[0], and inputs[1] be tensors? 
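+        # If they can, _impl would need to inspect input_types and fall
+        # back to the elementwise add below instead of always list-concat.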
+ # return _elemwise("add")(inputs, input_types) return _impl @@ -1627,10 +1622,10 @@ def is_list_dynamic(list_construct_node): intersect = list_ops.intersection(op_names) if len(intersect) > 0 and intersect != set(["aten::add_"]): - print("list op", list_construct_node) return True - if intersect == set(["aten::add_"]) and _get_node_type(list_construct_node) == "ListType": - print("add_ found and it is list", list_construct_node) + + output_type = _get_node_type(list_construct_node) + if intersect == set(["aten::add_"]) and output_type == "ListType": return True return False @@ -1749,7 +1744,6 @@ def get_var(name, val): return _expr.var(name, type_annotation=val.type_annotation) checked_type = _infer_type_with_prelude(val, prelude) - print("checked type:", checked_type) return _expr.var(name, type_annotation=checked_type) @@ -1793,8 +1787,6 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude): elif operator == "prim::ListConstruct" and is_list_dynamic(op_node): outputs[node_name] = _convert_to_list_adt(inputs, prelude) elif operator == "prim::ListConstruct": - print(op_node) - assert len(inputs) > 0, "An empty static list found" # This assumes that no more elements will be appended to this list # In this case, we keep the Python list outputs[node_name] = inputs @@ -1877,8 +1869,6 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): graph = script_module.graph.copy() _run_jit_passes(graph) - print(graph) - if custom_convert_map: convert_map.update(custom_convert_map) From bb8550414de55e0a90ba3884e97fe757e6e9b84c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 9 Apr 2020 11:24:52 +0900 Subject: [PATCH 15/31] a few fix --- python/tvm/relay/frontend/pytorch.py | 98 ++++++++----------- tests/python/frontend/pytorch/test_forward.py | 28 +++--- 2 files changed, 56 insertions(+), 70 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index aa14b4afc8d0..375584a9cc8f 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -43,6 +43,7 @@ __all__ = ["from_pytorch"] +# List ADT utilities def _infer_type_with_prelude(val, prelude): mod = prelude.mod func = Function([], val) @@ -74,6 +75,30 @@ def _convert_to_tensor_array(adt_lst, prelude): return prelude.map(tensor_create, adt_lst) +def _should_construct_dynamic_list(list_construct_node): + # if this list is element-accessed or modified at runtime, generate List ADT + uses = _get_uses(list_construct_node) + + for loop_use in filter(lambda use: use.user.kind() == "prim::Loop", uses): + block_input_index = loop_use.offset - 1 + block = list(loop_use.user.blocks())[0] + list_loop_var = list(block.inputs())[block_input_index] + uses += _get_uses(list_loop_var.node()) + + op_names = set(use.user.kind() for use in uses) + list_ops = set(["aten::add_", "aten::__getitem__", "aten::stack"]) + intersect = list_ops.intersection(op_names) + + if len(intersect) > 0 and intersect != set(["aten::add_"]): + return True + + output_type = _get_node_type(list_construct_node) + if intersect == set(["aten::add_"]) and output_type == "ListType": + return True + + return False + + # operator implementation def _elemwise(name): def _impl(inputs, input_types): @@ -165,7 +190,7 @@ def _impl(inputs, input_types): else: end = data.shape - begin = [0]*len(end) + begin = [0] * len(end) dim = int(inputs[1]) begin[dim] = int(inputs[2]) @@ -406,7 +431,7 @@ def _impl(inputs, input_types): ceil_mode = int(inputs[5]) if dilation != (1, 1): - 
msg = "MaxPool2d with dilation %s is not implemented" % (str(dilation), ) + msg = "MaxPool2d with dilation %s is not implemented" % (str(dilation)) raise NotImplementedError(msg) return _op.nn.max_pool2d(data, pool_size, strides, padding, "NCHW", ceil_mode) @@ -423,7 +448,7 @@ def _impl(inputs, input_types): ceil_mode = int(inputs[5]) if dilation != (1,): - msg = "MaxPool1d with dilation %s is not implemented" % (str(dilation), ) + msg = "MaxPool1d with dilation %s is not implemented" % (str(dilation)) raise NotImplementedError(msg) return _op.nn.max_pool1d(data, pool_size, strides, padding, "NCW", ceil_mode) @@ -439,7 +464,7 @@ def _impl(inputs, input_types): dilation = _infer_shape(inputs[4]) ceil_mode = int(inputs[5]) if dilation != (1, 1, 1): - msg = "MaxPool3d with dilation %s is not implemented" % (str(dilation), ) + msg = "MaxPool3d with dilation %s is not implemented" % (str(dilation)) raise NotImplementedError(msg) return _op.nn.max_pool3d(data, @@ -1125,22 +1150,17 @@ def _impl(inputs, input_types): def _add_(prelude): - """ - add_ is overloaded for list concat, like below - %17 : Tensor[] = prim::ListConstruct(%out.1) - %outputs.3 : Tensor[] = aten::add_(%outputs.6, %17) - """ + # add_ is overloaded for tensor add and list concat def _impl(inputs, input_types): - return prelude.concat(inputs[0], inputs[1]) - # TODO: could inputs[0], and inputs[1] be tensors? - # return _elemwise("add")(inputs, input_types) + if input_types[0] == "ListType": + return prelude.concat(inputs[0], inputs[1]) + return _elemwise("add")(inputs, input_types) return _impl def _tensor_array_stack(prelude): def _impl(inputs, input_types): - # TODO: check inputs[0] is List[TensorType] - # assert type_equal(inputs[0], prelude.l(TensorType)) + # TODO: check inputs[0] is a ADT List[TensorType] tensor_array = _convert_to_tensor_array(inputs[0], prelude) shape = get_tensor_array_shape(tensor_array, "float32", prelude) stack = prelude.get_var_static('tensor_array_stack', "float32", shape) @@ -1231,7 +1251,7 @@ def _wrap_const(c): return c # Operator mappings -def get_convert_map(prelude): +def _get_convert_map(prelude): convert_map = { "aten::device" : _none(), "aten::add" : _elemwise("add"), @@ -1459,7 +1479,7 @@ def _get_input_types(op_node): input_list_types.append(in_ty.scalarType().lower()) elif input_node_kind == 'ListType': - input_list_types.append(str(in_ty.getElementType()).lower()) + input_list_types.append("ListType") elif input_node_kind in ['IntType', 'FloatType', 'BoolType', 'StringType', 'OptionalType']: input_list_types.append(str(in_ty).lower()) @@ -1534,10 +1554,9 @@ def _get_relay_input_vars(graph, input_shapes, prelude): expected graph inputs - to allow translation """ def get_relay_ty(ishape): - if _is_int_seq(ishape): + if _is_int_seq(ishape) or len(ishape) == 0: return TensorType(ishape) elif isinstance(ishape, tuple): - # ishapele of ishapele return TupleType([get_relay_ty(elem) for elem in ishape]) elif isinstance(ishape, list): assert len(ishape) > 0 @@ -1566,8 +1585,8 @@ def unpack(tup, num_fields): return unpack(tup, len(tup.fields)) elif isinstance(tup.type_annotation, TupleType): return unpack(tup, len(tup.type_annotation.fields)) - else: - assert False + # shouldn't happen + assert False def get_use_chains(root_node, terminate=lambda _: False): @@ -1608,29 +1627,6 @@ def terminate(users): return get_use_chains(root_getattr_node, terminate) -def is_list_dynamic(list_construct_node): - uses = _get_uses(list_construct_node) - - for loop_use in filter(lambda use: use.user.kind() == 
"prim::Loop", uses): - block_input_index = loop_use.offset - 1 - block = list(loop_use.user.blocks())[0] - list_loop_var = list(block.inputs())[block_input_index] - uses += _get_uses(list_loop_var.node()) - - op_names = set(use.user.kind() for use in uses) - list_ops = set(["aten::add_", "aten::__getitem__", "aten::stack"]) - intersect = list_ops.intersection(op_names) - - if len(intersect) > 0 and intersect != set(["aten::add_"]): - return True - - output_type = _get_node_type(list_construct_node) - if intersect == set(["aten::add_"]) and output_type == "ListType": - return True - - return False - - def convert_params(graph, state_dict): """ Return Relay vars and TVM NDArrays for input parameters @@ -1738,13 +1734,7 @@ def body(*current_vals): return block_outputs def get_var(name, val): - if isinstance(val, _expr.Constant): - return _expr.var(name, shape=val.data.shape, dtype=val.data.dtype) - if isinstance(val, _expr.Var): - return _expr.var(name, type_annotation=val.type_annotation) - checked_type = _infer_type_with_prelude(val, prelude) - return _expr.var(name, type_annotation=checked_type) if is_while_loop: @@ -1754,7 +1744,6 @@ def get_var(name, val): if isinstance(init_cond, _expr.Constant): init_cond = _op.cast(init_cond, "bool") init_loop_iter_val = init_cond - else: loop_iter_dtype = "int32" # always count from 0 @@ -1784,7 +1773,7 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude): outputs[node_name] = _get_constant(op_node) elif operator == "prim::ListConstruct" and _is_int_seq(inputs): outputs[node_name] = _expr.var(node_name, shape=inputs) - elif operator == "prim::ListConstruct" and is_list_dynamic(op_node): + elif operator == "prim::ListConstruct" and _should_construct_dynamic_list(op_node): outputs[node_name] = _convert_to_list_adt(inputs, prelude) elif operator == "prim::ListConstruct": # This assumes that no more elements will be appended to this list @@ -1794,7 +1783,7 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude): outputs[node_name] = _expr.Tuple(inputs) elif operator in ["prim::ListUnpack", "prim::TupleUnpack"]: assert len(inputs) == 1 - if isinstance(inputs[0], list): + if isinstance(inputs[0], (list, _expr.TupleWrapper)): unpacked = inputs[0] else: unpacked = _unpack_tuple(inputs[0]) @@ -1864,7 +1853,7 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): mod = tvm.IRModule() prelude = Prelude(mod) - convert_map = get_convert_map(prelude) + convert_map = _get_convert_map(prelude) graph = script_module.graph.copy() _run_jit_passes(graph) @@ -1897,9 +1886,6 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): ret = convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name, convert_map, prelude) - if isinstance(ret[0], list): - ret[0] = _expr.Tuple(ret[0]) - mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0]) return mod, tvm_params diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index d60ab9eeec5f..01879115a9c1 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -526,13 +526,13 @@ def test_forward_maxpool2d(): input_data = torch.rand(input_shape).float() verify_model(torch.nn.MaxPool2d(kernel_size=[1, 1]).eval(), - input_data) + input_data) verify_model(torch.nn.MaxPool2d(kernel_size=[10, 10]).eval(), - input_data) + input_data) verify_model(torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2).eval(), - 
input_data) + input_data) def test_forward_maxpool1d(): torch.set_grad_enabled(False) @@ -540,13 +540,13 @@ def test_forward_maxpool1d(): input_data = torch.rand(input_shape).float() verify_model(torch.nn.MaxPool1d(kernel_size=1).eval(), - input_data) + input_data) verify_model(torch.nn.MaxPool1d(kernel_size=10).eval(), - input_data) - verify_model( torch.nn.MaxPool1d(kernel_size=4, + input_data) + verify_model(torch.nn.MaxPool1d(kernel_size=4, padding=2, stride=2).eval(), - input_data) + input_data) def test_forward_maxpool3d(): torch.set_grad_enabled(False) @@ -554,13 +554,13 @@ def test_forward_maxpool3d(): input_data = torch.rand(input_shape).float() verify_model(torch.nn.MaxPool3d(kernel_size=[1, 1, 1]).eval(), - input_data) + input_data) verify_model(torch.nn.MaxPool3d(kernel_size=[10, 10, 10]).eval(), - input_data) + input_data) verify_model(torch.nn.MaxPool3d(kernel_size=[4, 4, 4], padding=2, stride=2).eval(), - input_data) + input_data) def test_forward_split(): torch.set_grad_enabled(False) @@ -577,13 +577,13 @@ def forward(self, *args): input_data = torch.rand(input_shape).float() verify_model(Split(2, 0).float().eval(), - input_data=input_data) + input_data=input_data) verify_model(Split(3, 1).float().eval(), - input_data=input_data) + input_data=input_data) verify_model(Split(4, 1).float().eval(), - input_data=input_data) + input_data=input_data) verify_model(Split([2, 3, 5], 1).float().eval(), - input_data=input_data) + input_data=input_data) def test_forward_avgpool(): torch.set_grad_enabled(False) From 002eb4eb1c6cfc85055d660989a73190aad7ed4c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 10 Apr 2020 10:50:39 +0900 Subject: [PATCH 16/31] use shape_of function if Any is found --- python/tvm/relay/frontend/common.py | 2 +- python/tvm/relay/frontend/pytorch.py | 22 +++++++++++++++------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index e86890f3639a..2790ba328799 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -487,7 +487,7 @@ def infer_shape(inputs, mod=None): checked_type = out_type.checked_type if hasattr(checked_type, 'shape'): # Regular operator that outputs tensors - return get_const_tuple(checked_type.shape) + return get_const_tuple(out_type.checked_type.shape) # The return type is not a tensor, for example List return checked_type diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 375584a9cc8f..528774befeae 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -753,9 +753,17 @@ def _impl(inputs, input_types): return dense_out return _impl -def _size(): +def _size(prelude): def _impl(inputs, input_types): - shape = _infer_shape(inputs[0]) + shape = _infer_shape(inputs[0], prelude.mod) + + if any(map(lambda s: isinstance(s, tvm.tir.expr.Any), shape)): + shape_dynamic = _op.shape_of(inputs[0]) + if len(inputs) > 1: + axis = int(inputs[1]) + return _op.take(shape_dynamic, _expr.const(axis), 0) + return shape_dynamic + if len(inputs) > 1: axis = int(inputs[1]) return shape[axis] @@ -922,7 +930,7 @@ def func(x): return _impl -def _chunk(): +def _chunk(prelude): def _impl(inputs, input_types): data = inputs[0] @@ -930,7 +938,7 @@ def _impl(inputs, input_types): axis = int(inputs[2]) if isinstance(data, _expr.Expr): - inferred_shape = _infer_shape(data) + inferred_shape = _infer_shape(data, prelude.mod) shape = [] for infer in inferred_shape: @@ 
-1246,7 +1254,7 @@ def _convert_elemwise_input(data, input_type): return data def _wrap_const(c): - if not isinstance(c, _expr.Expr) and not isinstance(c, list): + if not isinstance(c, _expr.Expr) and not isinstance(c, (list, tvm.tir.expr.Any)): return _expr.const(c) return c @@ -1309,7 +1317,7 @@ def _get_convert_map(prelude): "aten::t" : _transpose(), "aten::flatten" : _flatten(), "aten::addmm" : _dense(), - "aten::size" : _size(), + "aten::size" : _size(prelude), "aten::view" : _view(), "aten::reshape" : _reshape(), "aten::clone" : _clone(), @@ -1323,7 +1331,7 @@ def _get_convert_map(prelude): "aten::feature_dropout" : _dropout(), "aten::alpha_dropout" : _dropout(), "aten::mean" : _mean(), - "aten::chunk" : _chunk(), + "aten::chunk" : _chunk(prelude), "aten::matmul" : _matmul(), "aten::expand" : _expand(), "aten::Int" : _int(), From 3261466886769030e0075a3e3b370a9d20e0add2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 10 Apr 2020 11:19:36 +0900 Subject: [PATCH 17/31] improve size conversion --- python/tvm/relay/frontend/pytorch.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 528774befeae..a3c4aa34ac64 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -753,23 +753,30 @@ def _impl(inputs, input_types): return dense_out return _impl + def _size(prelude): + def _impl_dynamic(inp, axis): + shape_dynamic = _op.shape_of(inp) + if axis is not None: + return _op.take(shape_dynamic, _expr.const(axis), 0) + return shape_dynamic + def _impl(inputs, input_types): shape = _infer_shape(inputs[0], prelude.mod) + axis = None + if len(inputs) > 1: + axis = int(inputs[1]) if any(map(lambda s: isinstance(s, tvm.tir.expr.Any), shape)): - shape_dynamic = _op.shape_of(inputs[0]) - if len(inputs) > 1: - axis = int(inputs[1]) - return _op.take(shape_dynamic, _expr.const(axis), 0) - return shape_dynamic + if axis is None or isinstance(shape[axis], tvm.tir.expr.Any): + return _impl_dynamic(inputs[0], axis) - if len(inputs) > 1: - axis = int(inputs[1]) + if axis is not None: return shape[axis] return shape return _impl + def _numtotensor(): def _impl(inputs, input_types): val = inputs[0] From 0a14c19e246a242321088fa249e613ae512de2a4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 10 Apr 2020 13:01:02 +0900 Subject: [PATCH 18/31] working on adding free vars to loop block --- python/tvm/relay/frontend/pytorch.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index a3c4aa34ac64..9bbba9dcf713 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1604,6 +1604,20 @@ def unpack(tup, num_fields): assert False +def _get_free_vars_from_block(block): + block_inp_names = _get_input_names(block) + bound_names = block_inp_names + free_vars = set() + + for node in block.nodes(): + inp_names = _get_input_names(node) + list_diff = [name for name in inp_names if name not in bound_names] + free_vars.update(list_diff) + bound_names += _get_output_names(node) + + return list(free_vars) + + def get_use_chains(root_node, terminate=lambda _: False): """ Track a chain of users of this node forward, returning a list of chains @@ -1771,6 +1785,13 @@ def get_var(name, val): loop_iter_var = _expr.var(block_input_names[0], shape=(), dtype=loop_iter_dtype) loop_vars = [get_var(name, val) for name, val in 
name_val_pairs[1:]] + + # add free variable + free_vars = _get_free_vars_from_block(body_block) + additional_vars = [var for var in free_vars + if var in outputs and + not isinstance(outputs[var], (_expr.Constant, int, float))] + loop = while_loop(cond, [loop_iter_var] + loop_vars, body) loop_val = loop(init_loop_iter_val, *init_vals) @@ -1887,6 +1908,7 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): outputs.update(param_vars) ret_name = _get_input_names(graph.return_node()) + print(graph) # For quantized models if "aten::quantize_per_tensor" in op_names: From f2d8bd23b2992c63413483f735eaa1530471ab8d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 10 Apr 2020 19:03:20 +0900 Subject: [PATCH 19/31] fixed inlined inner loop issue --- python/tvm/relay/frontend/pytorch.py | 95 +++++++++++++++++----------- 1 file changed, 57 insertions(+), 38 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 9bbba9dcf713..97ff2b3e023f 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1610,12 +1610,11 @@ def _get_free_vars_from_block(block): free_vars = set() for node in block.nodes(): - inp_names = _get_input_names(node) - list_diff = [name for name in inp_names if name not in bound_names] - free_vars.update(list_diff) + new_vars = [n for n in node.inputs() if n.debugName() not in bound_names] + free_vars.update(new_vars) bound_names += _get_output_names(node) - return list(free_vars) + return free_vars def get_use_chains(root_node, terminate=lambda _: False): @@ -1734,8 +1733,47 @@ def get_input(index): is_while_loop = (isinstance(max_loop_count, _expr.Constant) and _get_constant(loop_node.inputsAt(0).node()) == sys.maxsize) + if is_while_loop: + loop_iter_dtype = "bool" + # while loop with non input dependent condition such as while i < 10: + # init_cond is int, need to cast to bool to type check + if isinstance(init_cond, _expr.Constant): + init_cond = _op.cast(init_cond, "bool") + init_loop_iter_val = init_cond + else: + loop_iter_dtype = "int32" + # always count from 0 + init_loop_iter_val = _expr.const(0, dtype="int32") + body_block = list(loop_node.blocks())[0] block_input_names = _get_input_names(body_block) + num_block_inputs = len(block_input_names) + name_val_pairs = list(zip(block_input_names, + [init_loop_iter_val] + init_vals)) + outputs.update(name_val_pairs) + + def get_var(name, val): + checked_type = _infer_type_with_prelude(val, prelude) + return _expr.var(name, type_annotation=checked_type) + + loop_iter_var = _expr.var(block_input_names[0], shape=(), + dtype=loop_iter_dtype) + loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]] + + # add free variables to loop variables + free_vars = _get_free_vars_from_block(body_block) + additional_vars = [var for var in free_vars + if var.debugName() in outputs and + not isinstance(outputs[var.debugName()], (_expr.Constant, int, float))] + prev_outputs = {} + for var in additional_vars: + name = var.debugName() + prev_output = outputs[name] + new_loop_var = get_var(name, prev_output) + prev_outputs[name] = prev_output + outputs[name] = new_loop_var + loop_vars.append(new_loop_var) + init_vals.append(prev_output) def cond(*current_vals): i = current_vals[0] @@ -1747,12 +1785,19 @@ def cond(*current_vals): def body(*current_vals): # Update loop variables using the prev iteration outputs - assert len(current_vals) == len(block_input_names) - for (i, iname) in enumerate(block_input_names): - 
outputs[iname] = current_vals[i] + assert len(current_vals) == num_block_inputs + len(additional_vars) + + for i in range(len(current_vals)): + if i < num_block_inputs: + outputs[block_input_names[i]] = current_vals[i] + else: + outputs[additional_vars[i-num_block_inputs].debugName()] = current_vals[i] block_outputs = convert_block(body_block, outputs, convert_map, prelude) + for var in additional_vars: + block_outputs.append(outputs[var.debugName()]) + if not is_while_loop: # iter var increment implicit in torch, so do it manually # for while loop, block_outputs[0] is already a boolean, @@ -1762,39 +1807,14 @@ def body(*current_vals): return block_outputs - def get_var(name, val): - checked_type = _infer_type_with_prelude(val, prelude) - return _expr.var(name, type_annotation=checked_type) - - if is_while_loop: - loop_iter_dtype = "bool" - # while loop with non input dependent condition such as while i < 10: - # init_cond is int, need to cast to bool to type check - if isinstance(init_cond, _expr.Constant): - init_cond = _op.cast(init_cond, "bool") - init_loop_iter_val = init_cond - else: - loop_iter_dtype = "int32" - # always count from 0 - init_loop_iter_val = _expr.const(0, dtype="int32") - - name_val_pairs = list(zip(block_input_names, - [init_loop_iter_val] + init_vals)) - outputs.update(name_val_pairs) - - loop_iter_var = _expr.var(block_input_names[0], shape=(), - dtype=loop_iter_dtype) - loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]] - - # add free variable - free_vars = _get_free_vars_from_block(body_block) - additional_vars = [var for var in free_vars - if var in outputs and - not isinstance(outputs[var], (_expr.Constant, int, float))] - loop = while_loop(cond, [loop_iter_var] + loop_vars, body) loop_val = loop(init_loop_iter_val, *init_vals) + # restore original output values for free vars + for var in additional_vars: + name = var.debugName() + outputs[name] = prev_outputs[name] + # The first element is a loop counter or boolean condition, ignore it return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)] @@ -1908,7 +1928,6 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): outputs.update(param_vars) ret_name = _get_input_names(graph.return_node()) - print(graph) # For quantized models if "aten::quantize_per_tensor" in op_names: From fd297ae3f8c2dd5b4764cae100b9b59a79d564a3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 10 Apr 2020 19:15:47 +0900 Subject: [PATCH 20/31] clean up free var handling --- python/tvm/relay/frontend/pytorch.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 97ff2b3e023f..a1b031a6e224 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1610,8 +1610,9 @@ def _get_free_vars_from_block(block): free_vars = set() for node in block.nodes(): - new_vars = [n for n in node.inputs() if n.debugName() not in bound_names] - free_vars.update(new_vars) + inp_names = _get_input_names(node) + list_diff = [name for name in inp_names if name not in bound_names] + free_vars.update(list_diff) bound_names += _get_output_names(node) return free_vars @@ -1761,13 +1762,11 @@ def get_var(name, val): loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]] # add free variables to loop variables - free_vars = _get_free_vars_from_block(body_block) - additional_vars = [var for var in free_vars - if var.debugName() in outputs 
and - not isinstance(outputs[var.debugName()], (_expr.Constant, int, float))] + free_vars = [var for var in _get_free_vars_from_block(body_block) + if var in outputs and not isinstance(outputs[var], (_expr.Constant, int, float))] + prev_outputs = {} - for var in additional_vars: - name = var.debugName() + for name in free_vars: prev_output = outputs[name] new_loop_var = get_var(name, prev_output) prev_outputs[name] = prev_output @@ -1785,18 +1784,16 @@ def cond(*current_vals): def body(*current_vals): # Update loop variables using the prev iteration outputs - assert len(current_vals) == num_block_inputs + len(additional_vars) + assert len(current_vals) == num_block_inputs + len(free_vars) for i in range(len(current_vals)): if i < num_block_inputs: outputs[block_input_names[i]] = current_vals[i] else: - outputs[additional_vars[i-num_block_inputs].debugName()] = current_vals[i] + outputs[free_vars[i-num_block_inputs]] = current_vals[i] block_outputs = convert_block(body_block, outputs, convert_map, prelude) - - for var in additional_vars: - block_outputs.append(outputs[var.debugName()]) + block_outputs += [outputs[name] for name in free_vars] if not is_while_loop: # iter var increment implicit in torch, so do it manually @@ -1811,9 +1808,7 @@ def body(*current_vals): loop_val = loop(init_loop_iter_val, *init_vals) # restore original output values for free vars - for var in additional_vars: - name = var.debugName() - outputs[name] = prev_outputs[name] + outputs.update(prev_outputs) # The first element is a loop counter or boolean condition, ignore it return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)] From 18aad7f64ee8a66ae108f2a55dff19b7fafe70fb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 11 Apr 2020 00:45:58 +0900 Subject: [PATCH 21/31] add support for tensor array concat --- python/tvm/relay/frontend/pytorch.py | 74 ++++++++++++++++++++++------ 1 file changed, 58 insertions(+), 16 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index a1b031a6e224..d15e002cc5b6 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -77,6 +77,23 @@ def _convert_to_tensor_array(adt_lst, prelude): def _should_construct_dynamic_list(list_construct_node): # if this list is element-accessed or modified at runtime, generate List ADT + + def is_used_by_list_add(uses): + for use in uses: + op_name = use.user.kind() + if op_name == "prim::Loop": + continue + output_type = _get_node_type(use.user) + if op_name in ["aten::add", "aten::add_"] and output_type == "ListType": + return True + return False + + def inplace_add_to_add(op_name): + if op_name == "aten::add_": + return "aten::add" + else: + return op_name + uses = _get_uses(list_construct_node) for loop_use in filter(lambda use: use.user.kind() == "prim::Loop", uses): @@ -85,15 +102,15 @@ def _should_construct_dynamic_list(list_construct_node): list_loop_var = list(block.inputs())[block_input_index] uses += _get_uses(list_loop_var.node()) - op_names = set(use.user.kind() for use in uses) - list_ops = set(["aten::add_", "aten::__getitem__", "aten::stack"]) + op_names = map(inplace_add_to_add, set(use.user.kind() for use in uses)) + + list_ops = set(["aten::add", "aten::__getitem__", "aten::stack"]) intersect = list_ops.intersection(op_names) - if len(intersect) > 0 and intersect != set(["aten::add_"]): + if len(intersect) > 0 and intersect != set(["aten::add"]): return True - output_type = _get_node_type(list_construct_node) - if 
intersect == set(["aten::add_"]) and output_type == "ListType": + if is_used_by_list_add(uses): return True return False @@ -163,11 +180,28 @@ def _impl(inputs, input_types): return _op.transform.expand_dims(data, int(axis), 1) return _impl -def _concatenate(): + +def _concatenate(prelude): + def tensor_array_concat(lst, axis): + # assert axis == 0 + tensor_array = _convert_to_tensor_array(lst, prelude) + shape = get_tensor_array_shape(tensor_array, "float32", prelude) + print("tensor array concat shape:", shape) + concat = prelude.get_var_static('tensor_array_concat', "float32", shape) + concatenated = concat(tensor_array) + + static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) + static_tensor_array_ops.define_tensor_get_data(shape) + get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape) + return get_tensor(concatenated) + def _impl(inputs, input_types): data = inputs[0] axis = inputs[1] + if input_types[0] == "ListType": + return tensor_array_concat(data, axis) + if isinstance(data, _expr.Expr): data = [data] @@ -678,13 +712,13 @@ def _impl(inputs, input_types): scale=True) return _impl -def _transpose(): +def _transpose(prelude): def _impl(inputs, input_types): data = inputs[0] import torch if isinstance(data, _expr.Expr): - ndims = len(_infer_shape(data)) + ndims = len(_infer_shape(data, prelude.mod)) elif isinstance(data, list): ndims = data elif isinstance(data, (torch.Tensor, np.ndarray)): @@ -1164,7 +1198,13 @@ def _impl(inputs, input_types): return _impl -def _add_(prelude): +def _list_len(prelude): + def _impl(inputs, input_types): + return prelude.length(inputs[0]) + return _impl + + +def _add(prelude): # add_ is overloaded for tensor add and list concat def _impl(inputs, input_types): if input_types[0] == "ListType": @@ -1269,7 +1309,6 @@ def _wrap_const(c): def _get_convert_map(prelude): convert_map = { "aten::device" : _none(), - "aten::add" : _elemwise("add"), "aten::sub" : _elemwise("subtract"), "aten::sub_" : _elemwise("subtract"), "aten::max" : _elemwise("maximum"), @@ -1289,7 +1328,7 @@ def _get_convert_map(prelude): "aten::to" : _to(), "aten::squeeze" : _squeeze(), "aten::unsqueeze" : _unsqueeze(), - "aten::cat" : _concatenate(), + "aten::cat" : _concatenate(prelude), "aten::slice" : _slice(), "aten::split" : _split(), "aten::split_with_sizes" : _split_with_sizes(), @@ -1319,9 +1358,9 @@ def _get_convert_map(prelude): "aten::batch_norm" : _batch_norm(), "aten::instance_norm" : _instance_norm(), "aten::layer_norm" : _layer_norm(), - "aten::transpose" : _transpose(), - "aten::transpose_" : _transpose(), - "aten::t" : _transpose(), + "aten::transpose" : _transpose(prelude), + "aten::transpose_" : _transpose(prelude), + "aten::t" : _transpose(prelude), "aten::flatten" : _flatten(), "aten::addmm" : _dense(), "aten::size" : _size(prelude), @@ -1344,7 +1383,7 @@ def _get_convert_map(prelude): "aten::Int" : _int(), "prim::NumToTensor" : _numtotensor(), "aten::constant_pad_nd" : _pad(), - "aten::permute" : _transpose(), + "aten::permute" : _transpose(prelude), "aten::sum" : _reduce("sum"), "aten::prod" : _reduce("prod"), "aten::sqrt" : _sqrt(), @@ -1366,9 +1405,11 @@ def _get_convert_map(prelude): "aten::adaptive_max_pool3d" : _adaptive_max_pool_3d(), "aten::mm" : _matmul(), "relay::tensor_array_stack" : _tensor_array_stack(prelude), - "aten::add_" : _add_(prelude), + "aten::add" : _add(prelude), + "aten::add_" : _add(prelude), "aten::stack" : _tensor_array_stack(prelude), "aten::__getitem__" : _list_getitem(prelude), + "aten::len" : 
_list_len(prelude), } return convert_map @@ -1923,6 +1964,7 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): outputs.update(param_vars) ret_name = _get_input_names(graph.return_node()) + print(graph) # For quantized models if "aten::quantize_per_tensor" in op_names: From a7c59ed365779449d2bc4b4757fac527d35180c5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 11 Apr 2020 06:41:55 +0900 Subject: [PATCH 22/31] adding ta concat on last axis --- python/tvm/relay/frontend/pytorch.py | 3 +- python/tvm/relay/prelude.py | 76 ++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index d15e002cc5b6..dbcc4175ec79 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -68,6 +68,7 @@ def _convert_to_tensor_array(adt_lst, prelude): return prelude.nil() shape = _infer_type_with_prelude(prelude.hd(adt_lst), prelude).shape + print("register shape:", shape) static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) static_tensor_array_ops.register() tensor_create = prelude.get_var_static('tensor_constructor', "float32", shape) @@ -187,7 +188,7 @@ def tensor_array_concat(lst, axis): tensor_array = _convert_to_tensor_array(lst, prelude) shape = get_tensor_array_shape(tensor_array, "float32", prelude) print("tensor array concat shape:", shape) - concat = prelude.get_var_static('tensor_array_concat', "float32", shape) + concat = prelude.get_var_static('tensor_array_concat_last', "float32", shape) concatenated = concat(tensor_array) static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index f512b68bcac4..bb3d9094e18f 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -200,6 +200,39 @@ def define_tensor_concatenate(self): self.prelude.mod[concat_var] = \ Function([x, y], Match(x, [case], False), tensor_type_var(), []) + def define_tensor_concatenate_last(self): + """Defines a function to concatenate two tensor_t on axis -1. + tensor_concatenate(t) : tensor_t -> tensor_t -> tensor_t + """ + # We don't register concatenate for scalar tensor. 
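# [A rough sketch, assumed shape values] How the frontend consumes these prelude
# definitions: register the static tensor array ops for a concrete shape, then
# look up the shape-specialized global by name, exactly as the converter above
# does with get_var_static.
import tvm
from tvm.relay.prelude import Prelude, StaticTensorArrayOps

mod = tvm.IRModule()
prelude = Prelude(mod)
ops = StaticTensorArrayOps(prelude, "float32", (2, 4))  # (2, 4) is an example shape
ops.register()
stack_fn = prelude.get_var_static('tensor_array_stack', "float32", (2, 4))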
+ ndim = len(self.shape) + if ndim == 0: + return + + concat_name = self.get_name("tensor_concatenate_last") + concat_var = self._create_global_var(concat_name) + setattr(self.prelude, concat_name, concat_var) + output_shape = list(self.shape[:-1]) + [Any(),] + print("Tensor concat output_shape:", output_shape) + tensor_type_var, tensor_constructor = \ + self._get_adt_by_shape(output_shape) + + origin_tensor_constructor = self.get_var('tensor_constructor') + origin_tensor_type_var = self.get_var('tensor_t') + x = Var("x", origin_tensor_type_var()) + y = Var("y", origin_tensor_type_var()) + t1 = Var("t1") + t2 = Var("t2") + + case = Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(t1)]), + Match(y, + [Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(t2)]), + tensor_constructor(op.concatenate([t1, t2], axis=-1)))], + False)) + + self.prelude.mod[concat_var] = \ + Function([x, y], Match(x, [case], False), tensor_type_var(), []) + def define_tensor_expand_dims(self): """Defines a function to grow a tensor_t's rank by adding one dimension in front @@ -483,6 +516,47 @@ def define_tensor_array_concat(self): Function([tensor_array], Match(tensor_array, [nil_case, cons_case], False), tensor_type_var(), []) + def define_tensor_array_concat_last(self): + """Defines a function to return the values in the tensor array as concatenated tensor_t. + tensor_array_concat(ta) : list[tensor_t] -> tensor_t + """ + # We don't register concat for scalar tensor array. + ndim = len(self.shape) + if ndim == 0: + return + + concat_name = self.get_name("tensor_array_concat_last") + concat_var = self._create_global_var(concat_name) + setattr(self.prelude, concat_name, concat_var) + + output_shape = list(self.shape[:-1]) + [Any(),] + print("output shape", self.shape, output_shape) + tensor_type_var, _ = self._get_adt_by_shape(output_shape) + + # Register tensor concatenate and get tensor_nil var for output shape + origin_shape = self.shape + self.shape = output_shape + self.define_tensor_concatenate() + print(self.prelude.mod) + print(self.shape) + tensor_concat_var = self.get_var('tensor_concatenate_last') + tensor_nil_var = self.get_var('tensor_nil') + self.shape = origin_shape + + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + hd = Var("hd") + tl = Var("tl") + nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var()) + cons_case = Clause(PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]), + Match(tl, [ + Clause(PatternConstructor(self.prelude.nil), hd), + Clause(PatternWildcard(), + tensor_concat_var(hd, concat_var(tl))) + ], False)) + self.prelude.mod[concat_var] = \ + Function([tensor_array], + Match(tensor_array, [nil_case, cons_case], False), tensor_type_var(), []) + def define_tensor_array_stack(self): """Defines a function to get the values in the tensor array as a stack tensor_t. 
tensor_array_stack(l) : list[tensor_t] -> tensor_t @@ -572,6 +646,7 @@ def register(self): self.define_tensor_adt() self.define_tensor_take() self.define_tensor_concatenate() + self.define_tensor_concatenate_last() self.define_tensor_expand_dims() self.define_tensor_array() self.define_tensor_array_read() @@ -580,6 +655,7 @@ def register(self): self.define_tensor_array_scatter() self.define_tensor_array_split() self.define_tensor_array_concat() + self.define_tensor_array_concat_last() self.define_tensor_array_stack() self.define_tensor_array_gather() From e9cb1a77065170c908f7d034893c851e595e638e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 11 Apr 2020 09:23:29 +0900 Subject: [PATCH 23/31] fix concat, but got runtime error --- python/tvm/relay/frontend/pytorch.py | 35 ++++++++++++++++------------ python/tvm/relay/prelude.py | 7 +----- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index dbcc4175ec79..b685a1f20307 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -63,12 +63,7 @@ def _convert_to_list_adt(py_lst, prelude): return adt_lst -def _convert_to_tensor_array(adt_lst, prelude): - if prelude.length(adt_lst) == 0: - return prelude.nil() - - shape = _infer_type_with_prelude(prelude.hd(adt_lst), prelude).shape - print("register shape:", shape) +def _map_tensor_array_constructor(adt_lst, prelude, shape): static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) static_tensor_array_ops.register() tensor_create = prelude.get_var_static('tensor_constructor', "float32", shape) @@ -76,6 +71,14 @@ def _convert_to_tensor_array(adt_lst, prelude): return prelude.map(tensor_create, adt_lst) +def _convert_to_tensor_array(adt_lst, prelude): + if prelude.length(adt_lst) == 0: + return prelude.nil() + + shape = _infer_type_with_prelude(prelude.hd(adt_lst), prelude).shape + return _map_tensor_array_constructor(adt_lst, prelude, shape) + + def _should_construct_dynamic_list(list_construct_node): # if this list is element-accessed or modified at runtime, generate List ADT @@ -184,16 +187,18 @@ def _impl(inputs, input_types): def _concatenate(prelude): def tensor_array_concat(lst, axis): - # assert axis == 0 - tensor_array = _convert_to_tensor_array(lst, prelude) - shape = get_tensor_array_shape(tensor_array, "float32", prelude) - print("tensor array concat shape:", shape) - concat = prelude.get_var_static('tensor_array_concat_last', "float32", shape) - concatenated = concat(tensor_array) + # TODO for axis == 0 case + assert axis == -1 + shape = _infer_type_with_prelude(prelude.hd(lst), prelude).shape + concat_shape = tuple(shape[:-1]) + (Any(),) - static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) - static_tensor_array_ops.define_tensor_get_data(shape) - get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape) + tensor_array = _map_tensor_array_constructor(lst, prelude, concat_shape) + static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", concat_shape) + static_tensor_array_ops.define_tensor_get_data(concat_shape) + + concat = prelude.get_var_static('tensor_array_concat_last', "float32", concat_shape) + concatenated = concat(tensor_array) + get_tensor = prelude.get_var_static('tensor_get_data', "float32", concat_shape) return get_tensor(concatenated) def _impl(inputs, input_types): diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index bb3d9094e18f..6e56be792a3a 100644 
--- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -213,7 +213,6 @@ def define_tensor_concatenate_last(self): concat_var = self._create_global_var(concat_name) setattr(self.prelude, concat_name, concat_var) output_shape = list(self.shape[:-1]) + [Any(),] - print("Tensor concat output_shape:", output_shape) tensor_type_var, tensor_constructor = \ self._get_adt_by_shape(output_shape) @@ -423,7 +422,6 @@ def define_tensor_array_split(self, take_var = self.get_var('tensor_take') self.shape = origin_shape - ta1 = Var("tensor_array", self.prelude.l(output_tensor_type_var())) value1 = Var('value1', value_type_var()) offset1 = Var('offset1', scalar_type('int32')) @@ -530,15 +528,12 @@ def define_tensor_array_concat_last(self): setattr(self.prelude, concat_name, concat_var) output_shape = list(self.shape[:-1]) + [Any(),] - print("output shape", self.shape, output_shape) tensor_type_var, _ = self._get_adt_by_shape(output_shape) # Register tensor concatenate and get tensor_nil var for output shape origin_shape = self.shape self.shape = output_shape - self.define_tensor_concatenate() - print(self.prelude.mod) - print(self.shape) + self.define_tensor_concatenate_last() tensor_concat_var = self.get_var('tensor_concatenate_last') tensor_nil_var = self.get_var('tensor_nil') self.shape = origin_shape From a2b0da4c2e671adf638d0df08aada0e39e96731d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 11 Apr 2020 14:04:48 +0900 Subject: [PATCH 24/31] disable concat on axis -1 for now --- python/tvm/relay/frontend/pytorch.py | 6 ++---- src/relay/backend/contrib/dnnl/codegen.cc | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index b685a1f20307..259e7692d9f7 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -187,10 +187,9 @@ def _impl(inputs, input_types): def _concatenate(prelude): def tensor_array_concat(lst, axis): - # TODO for axis == 0 case - assert axis == -1 + assert axis == 0, "Tensor array concat supported only for axis 0" shape = _infer_type_with_prelude(prelude.hd(lst), prelude).shape - concat_shape = tuple(shape[:-1]) + (Any(),) + concat_shape = (Any(), ) + tuple(shape) tensor_array = _map_tensor_array_constructor(lst, prelude, concat_shape) static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", concat_shape) @@ -1970,7 +1969,6 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): outputs.update(param_vars) ret_name = _get_input_names(graph.return_node()) - print(graph) # For quantized models if "aten::quantize_per_tensor" in op_names: diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 7f3aabf6e016..2bf3a3284ee1 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -194,10 +194,8 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { } out_.clear(); - for (size_t i = 0; i < ret.outputs.size(); ++i) { - buf_decl_.push_back(ret.buffers[i]); - out_.push_back(ret.outputs[i]); - } + buf_decl_.insert(buf_decl_.end(), ret.buffers.begin(), ret.buffers.end()); + out_.insert(out_.end(), ret.outputs.begin(), ret.outputs.end()); ext_func_body.push_back(ret.decl); } From eb70587ff1e8c43e6d024621d97794cd689a672e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 11 Apr 2020 15:41:04 +0900 Subject: [PATCH 25/31] add lstm tests --- python/tvm/relay/frontend/pytorch.py | 34 +- 
python/tvm/relay/prelude.py | 73 ----
 tests/python/frontend/pytorch/lstm_test.py | 348 ++++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py | 5 +
 4 files changed, 374 insertions(+), 86 deletions(-)
 create mode 100644 tests/python/frontend/pytorch/lstm_test.py

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 259e7692d9f7..c0003e3ae2c5 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -67,7 +67,6 @@ def _map_tensor_array_constructor(adt_lst, prelude, shape):
     static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape)
     static_tensor_array_ops.register()
     tensor_create = prelude.get_var_static('tensor_constructor', "float32", shape)
-
     return prelude.map(tensor_create, adt_lst)
@@ -81,7 +80,6 @@ def _convert_to_tensor_array(adt_lst, prelude):

 def _should_construct_dynamic_list(list_construct_node):
     # if this list is element-accessed or modified at runtime, generate List ADT
-
     def is_used_by_list_add(uses):
         for use in uses:
             op_name = use.user.kind()
@@ -189,13 +187,13 @@ def _concatenate(prelude):
     def tensor_array_concat(lst, axis):
         assert axis == 0, "Tensor array concat supported only for axis 0"
         shape = _infer_type_with_prelude(prelude.hd(lst), prelude).shape
-        concat_shape = (Any(), ) + tuple(shape)
+        concat_shape = (Any(), ) + tuple(shape[1:])

-        tensor_array = _map_tensor_array_constructor(lst, prelude, concat_shape)
+        tensor_array = _map_tensor_array_constructor(lst, prelude, shape)
         static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", concat_shape)
         static_tensor_array_ops.define_tensor_get_data(concat_shape)

-        concat = prelude.get_var_static('tensor_array_concat_last', "float32", concat_shape)
+        concat = prelude.get_var_static('tensor_array_concat', "float32", concat_shape)
         concatenated = concat(tensor_array)
         get_tensor = prelude.get_var_static('tensor_get_data', "float32", concat_shape)
         return get_tensor(concatenated)
@@ -204,7 +202,7 @@ def _impl(inputs, input_types):
         data = inputs[0]
         axis = inputs[1]

-        if input_types[0] == "ListType":
+        if not isinstance(data, list):
             return tensor_array_concat(data, axis)

         if isinstance(data, _expr.Expr):
@@ -1800,16 +1798,26 @@ def get_input(index):
         outputs.update(name_val_pairs)

     def get_var(name, val):
-        checked_type = _infer_type_with_prelude(val, prelude)
-        return _expr.var(name, type_annotation=checked_type)
+        if val is not None:
+            print(val)
+            checked_type = _infer_type_with_prelude(val, prelude)
+            return _expr.var(name, type_annotation=checked_type)
+        return _expr.var(name)

     loop_iter_var = _expr.var(block_input_names[0], shape=(),
                               dtype=loop_iter_dtype)
     loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]]

-    # add free variables to loop variables
+    # Add non-constant free variables to loop variables to prevent code blow-up.
+    # Without this, if there are two for loops in a row, which often happens
+    # if the outer loop is unrolled, the computation corresponding to the first for loop
+    # is inlined inside the loop body, turning O(N) + O(N) computation into O(N^2).
+    # This issue was found when converting the stacked LSTM test. Torch does not add the output
+    # of the earlier loop into loop variables of the next loop.
+    # So the variable corresponding to the first loop output appears free in the second loop body.
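# [Hypothetical TorchScript sketch] The two-loop situation the comment above
# describes: `first` is produced by one loop and appears only as a free variable
# in the next loop's body, so the converter must lift it into the second
# while_loop's variables to avoid re-inlining its computation on every iteration.
import torch

@torch.jit.script
def two_loops(x: torch.Tensor) -> torch.Tensor:
    first = x
    for i in range(10):
        first = first + 1.0
    second = x
    for j in range(10):
        second = second + first   # `first` is free in this loop body
    return second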
free_vars = [var for var in _get_free_vars_from_block(body_block) - if var in outputs and not isinstance(outputs[var], (_expr.Constant, int, float))] + if var in outputs and not isinstance(outputs[var], (_expr.Constant, int, float)) + and outputs[var]] prev_outputs = {} for name in free_vars: @@ -1832,11 +1840,11 @@ def body(*current_vals): # Update loop variables using the prev iteration outputs assert len(current_vals) == num_block_inputs + len(free_vars) - for i in range(len(current_vals)): + for (i, val) in enumerate(current_vals): if i < num_block_inputs: - outputs[block_input_names[i]] = current_vals[i] + outputs[block_input_names[i]] = val else: - outputs[free_vars[i-num_block_inputs]] = current_vals[i] + outputs[free_vars[i-num_block_inputs]] = val block_outputs = convert_block(body_block, outputs, convert_map, prelude) block_outputs += [outputs[name] for name in free_vars] diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 6e56be792a3a..cea6ffdf8a1d 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -200,39 +200,6 @@ def define_tensor_concatenate(self): self.prelude.mod[concat_var] = \ Function([x, y], Match(x, [case], False), tensor_type_var(), []) - def define_tensor_concatenate_last(self): - """Defines a function to concatenate two tensor_t on axis -1. - tensor_concatenate(t) : tensor_t -> tensor_t -> tensor_t - """ - # We don't register concatenate for scalar tensor. - ndim = len(self.shape) - if ndim == 0: - return - - concat_name = self.get_name("tensor_concatenate_last") - concat_var = self._create_global_var(concat_name) - setattr(self.prelude, concat_name, concat_var) - output_shape = list(self.shape[:-1]) + [Any(),] - tensor_type_var, tensor_constructor = \ - self._get_adt_by_shape(output_shape) - - origin_tensor_constructor = self.get_var('tensor_constructor') - origin_tensor_type_var = self.get_var('tensor_t') - x = Var("x", origin_tensor_type_var()) - y = Var("y", origin_tensor_type_var()) - t1 = Var("t1") - t2 = Var("t2") - - case = Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(t1)]), - Match(y, - [Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(t2)]), - tensor_constructor(op.concatenate([t1, t2], axis=-1)))], - False)) - - self.prelude.mod[concat_var] = \ - Function([x, y], Match(x, [case], False), tensor_type_var(), []) - - def define_tensor_expand_dims(self): """Defines a function to grow a tensor_t's rank by adding one dimension in front of the original tensor_t. @@ -514,44 +481,6 @@ def define_tensor_array_concat(self): Function([tensor_array], Match(tensor_array, [nil_case, cons_case], False), tensor_type_var(), []) - def define_tensor_array_concat_last(self): - """Defines a function to return the values in the tensor array as concatenated tensor_t. - tensor_array_concat(ta) : list[tensor_t] -> tensor_t - """ - # We don't register concat for scalar tensor array. 
- ndim = len(self.shape) - if ndim == 0: - return - - concat_name = self.get_name("tensor_array_concat_last") - concat_var = self._create_global_var(concat_name) - setattr(self.prelude, concat_name, concat_var) - - output_shape = list(self.shape[:-1]) + [Any(),] - tensor_type_var, _ = self._get_adt_by_shape(output_shape) - - # Register tensor concatenate and get tensor_nil var for output shape - origin_shape = self.shape - self.shape = output_shape - self.define_tensor_concatenate_last() - tensor_concat_var = self.get_var('tensor_concatenate_last') - tensor_nil_var = self.get_var('tensor_nil') - self.shape = origin_shape - - tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) - hd = Var("hd") - tl = Var("tl") - nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var()) - cons_case = Clause(PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]), - Match(tl, [ - Clause(PatternConstructor(self.prelude.nil), hd), - Clause(PatternWildcard(), - tensor_concat_var(hd, concat_var(tl))) - ], False)) - self.prelude.mod[concat_var] = \ - Function([tensor_array], - Match(tensor_array, [nil_case, cons_case], False), tensor_type_var(), []) - def define_tensor_array_stack(self): """Defines a function to get the values in the tensor array as a stack tensor_t. tensor_array_stack(l) : list[tensor_t] -> tensor_t @@ -641,7 +570,6 @@ def register(self): self.define_tensor_adt() self.define_tensor_take() self.define_tensor_concatenate() - self.define_tensor_concatenate_last() self.define_tensor_expand_dims() self.define_tensor_array() self.define_tensor_array_read() @@ -650,7 +578,6 @@ def register(self): self.define_tensor_array_scatter() self.define_tensor_array_split() self.define_tensor_array_concat() - self.define_tensor_array_concat_last() self.define_tensor_array_stack() self.define_tensor_array_gather() diff --git a/tests/python/frontend/pytorch/lstm_test.py b/tests/python/frontend/pytorch/lstm_test.py new file mode 100644 index 000000000000..672cd44fc6cc --- /dev/null +++ b/tests/python/frontend/pytorch/lstm_test.py @@ -0,0 +1,348 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
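# [A minimal sketch, names illustrative] The input_shapes convention these tests
# exercise, as handled by _get_relay_input_vars earlier in the series: a tuple of
# ints declares a plain tensor input, a tuple of shapes a TupleType input, and a
# Python list of shapes a Relay List ADT input. "history" below is a made-up
# name; see custom_lstm_test further down for the real values.
input_shapes = [
    ("input", (5, 2, 3)),              # ints only -> TensorType((5, 2, 3))
    ("states", ((2, 4), (2, 4))),      # tuple of shapes -> TupleType of tensors
    ("history", [(2, 4), (2, 4)]),     # list of shapes -> Relay List ADT
]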
+""" Tests on torch lstm model conversion """ +# originally from https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py +# described in https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/ +import numpy as np +import torch +import torch.nn as nn +from torch.nn import Parameter +import torch.jit as jit +from typing import List, Tuple +from torch import Tensor + +import tvm +from tvm import relay +from tvm.relay.frontend.pytorch import from_pytorch +from tvm.relay.prelude import Prelude +from tvm.runtime.container import ADT, tuple_object + + +class LayerNormLSTMCell(jit.ScriptModule): + def __init__(self, input_size, hidden_size): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size)) + self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size)) + + ln = nn.LayerNorm + + self.layernorm_i = ln(4 * hidden_size) + self.layernorm_h = ln(4 * hidden_size) + self.layernorm_c = ln(hidden_size) + + @jit.script_method + def forward(self, input, state): + # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + hx, cx = state + igates = self.layernorm_i(torch.mm(input, self.weight_ih.t())) + hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t())) + gates = igates + hgates + ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) + + ingate = torch.sigmoid(ingate) + forgetgate = torch.sigmoid(forgetgate) + cellgate = torch.tanh(cellgate) + outgate = torch.sigmoid(outgate) + + cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate)) + hy = outgate * torch.tanh(cy) + + return hy, (hy, cy) + + +class LSTMLayer(jit.ScriptModule): + def __init__(self, cell, *cell_args): + super().__init__() + self.cell = cell(*cell_args) + + @jit.script_method + def forward(self, input, state): + # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + outputs = [] + for i in range(input.size(0)): + out, state = self.cell(input[i], state) + outputs += [out] + return torch.stack(outputs), state + + +class ReverseLSTMLayer(jit.ScriptModule): + def __init__(self, cell, *cell_args): + super(ReverseLSTMLayer, self).__init__() + self.cell = cell(*cell_args) + + @jit.script_method + def forward(self, inputs, state): + # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + outputs = jit.annotate(List[Tensor], []) + seq_len = inputs.size(0) + for i in range(seq_len): + out, state = self.cell(inputs[seq_len - i - 1], state) + # workaround for the lack of list rev support + outputs = [out] + outputs + return torch.stack(outputs), state + + +class BidirLSTMLayer(jit.ScriptModule): + __constants__ = ['directions'] + + def __init__(self, cell, *cell_args): + super(BidirLSTMLayer, self).__init__() + self.directions = nn.ModuleList([ + LSTMLayer(cell, *cell_args), + ReverseLSTMLayer(cell, *cell_args), + ]) + + @jit.script_method + def forward(self, input, states): + # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]] + # List[LSTMState]: [forward LSTMState, backward LSTMState] + outputs = jit.annotate(List[Tensor], []) + output_states = jit.annotate(List[Tuple[Tensor, Tensor]], []) + for (i, direction) in enumerate(self.directions): + state = states[i] + out, out_state = direction(input, state) + outputs += [out] + output_states += [out_state] + # tensor array concat assumes axis == 0 for now + # return torch.cat(outputs, -1), output_states + return 
torch.cat(outputs, 0), output_states + + +def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args): + layers = [layer(*first_layer_args)] + [layer(*other_layer_args) + for _ in range(num_layers - 1)] + return nn.ModuleList(layers) + + +class StackedLSTM(jit.ScriptModule): + __constants__ = ['layers'] # Necessary for iterating through self.layers + + def __init__(self, num_layers, layer, first_layer_args, other_layer_args): + super().__init__() + self.layers = init_stacked_lstm(num_layers, layer, first_layer_args, + other_layer_args) + + @jit.script_method + def forward(self, input, states): + # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]] + # List[LSTMState]: One state per layer + output_states = jit.annotate(List[Tuple[Tensor, Tensor]], []) + output = input + for (i, rnn_layer) in enumerate(self.layers): + state = states[i] + output, out_state = rnn_layer(output, state) + output_states += [out_state] + return output, output_states + + +class StackedBidirLSTM(jit.ScriptModule): + __constants__ = ['layers'] # Necessary for iterating through self.layers + + def __init__(self, num_layers, layer, first_layer_args, other_layer_args): + super(StackedBidirLSTM, self).__init__() + self.layers = init_stacked_lstm(num_layers, layer, first_layer_args, + other_layer_args) + + @jit.script_method + def forward(self, input, states): + # type: (Tensor, List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]] + # List[List[LSTMState]]: The outer list is for layers, + # inner list is for directions. + output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], []) + output = input + for (i, rnn_layer) in enumerate(self.layers): + state = states[i] + output, out_state = rnn_layer(output, state) + output_states += [out_state] + return output, output_states + + +def lstm(input_size, hidden_size): + return LSTMLayer(LayerNormLSTMCell, input_size, hidden_size) + + +def stacked_lstm(input_size, hidden_size, num_layers): + return StackedLSTM(num_layers, LSTMLayer, + first_layer_args=[LayerNormLSTMCell, input_size, hidden_size], + other_layer_args=[LayerNormLSTMCell, hidden_size, hidden_size]) + + +def bidir_lstm(input_size, hidden_size): + return BidirLSTMLayer(LayerNormLSTMCell, input_size, hidden_size) + + +def stacked_bidir_lstm(input_size, hidden_size, num_layers): + return StackedBidirLSTM(num_layers, BidirLSTMLayer, + first_layer_args=[LayerNormLSTMCell, input_size, hidden_size], + other_layer_args=[LayerNormLSTMCell, hidden_size * 2, hidden_size]) + + +def vmobj_to_list(o, dtype="float32"): + if isinstance(o, tvm.nd.NDArray): + return [o] + elif isinstance(o, tvm.runtime.container.ADT): + result = [] + for f in o: + result.extend(vmobj_to_list(f, dtype)) + return result + else: + raise RuntimeError("Unknown object type: %s" % type(o)) + + +def assert_equal(tvm_result, torch_result): + if isinstance(torch_result, (tuple, list)): + assert isinstance(tvm_result, list) + for tvm_res, pt_res in zip(tvm_result, torch_result): + assert_equal(tvm_res, pt_res) + elif isinstance(torch_result, torch.Tensor): + print(np.max(np.abs(tvm_result.asnumpy() - torch_result.numpy()))) + tvm.testing.assert_allclose(tvm_result.asnumpy(), torch_result.numpy(), + rtol=1e-5, atol=1e-5) + else: + tvm_res = np.asscalar(tvm_result.asnumpy()) + print(abs(torch_result - tvm_res)) + assert torch_result == tvm_res + + +def run_and_compare(mod, params, pt_result): + executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), 
target="llvm") + evaluator = executor.evaluate() + + exec_res = evaluator(**params) + + def flatten(nested): + res = [] + for r in nested: + if isinstance(r, torch.Tensor): + res.append(r) + else: + res.extend(flatten(r)) + return res + + if isinstance(exec_res, tvm.runtime.container.ADT): + assert not isinstance(pt_result, torch.Tensor) + tvm_res = vmobj_to_list(exec_res) + torch_res = flatten(pt_result) + else: + tvm_res = exec_res + torch_res = pt_result + + assert_equal(tvm_res, torch_res) + + +def convert_list_to_vmobj(py_lst): + def wrap_nd_array(arr): + return tvm.nd.array(arr, ctx=tvm.cpu(0)) + + mod = tvm.IRModule() + prelude = Prelude(mod) + adt_lst = ADT(prelude.nil.tag, []) + for elem in reversed(py_lst): + if isinstance(elem, np.ndarray): + vmobj = wrap_nd_array(elem) + elif isinstance(elem, tuple): + vmobj = tuple_object([wrap_nd_array(e) for e in elem]) + elif isinstance(elem, list): + vmobj = convert_list_to_vmobj(elem) + adt_lst = ADT(prelude.cons.tag, [vmobj, adt_lst]) + return adt_lst + + +def custom_lstm_test(): + input_name = "input" + states_name = "states" + seq_len = 5 + batch = 2 + input_size = 3 + hidden_size = 4 + num_layers = 3 + + inp = torch.randn(seq_len, batch, input_size) + + input_shapes = [(input_name, (seq_len, batch, input_size)), + (states_name, ((batch, hidden_size), (batch, hidden_size)))] + + input_shapes_stacked = [(input_name, (seq_len, batch, input_size)), + (states_name, [((batch, hidden_size), (batch, hidden_size)), + ((batch, hidden_size), (batch, hidden_size))])] + + input_shapes_stacked_bidir = [(input_name, (seq_len, batch, input_size)), + (states_name, [((batch, hidden_size), (batch, hidden_size)), + ((batch, hidden_size), (batch, hidden_size))])] + + input_shapes_stacked_bidir= [(input_name, (seq_len, batch, input_size)), + (states_name, [[((batch, hidden_size), + (batch, hidden_size)) + for _ in range(2)] + for _ in range(num_layers)])] + + states = [(torch.randn(batch, hidden_size), + torch.randn(batch, hidden_size)) + for _ in range(num_layers)] + + bidir_states = [(torch.randn(batch, hidden_size), + torch.randn(batch, hidden_size)) + for _ in range(2)] + + stacked_bidir_states = [[(torch.randn(batch, hidden_size), + torch.randn(batch, hidden_size)) + for _ in range(2)] + for _ in range(num_layers)] + + models = [ + (lstm(input_size, hidden_size).eval(), states[0], input_shapes), + (stacked_lstm(input_size, hidden_size, num_layers).eval(), states, input_shapes_stacked), + (bidir_lstm(input_size, hidden_size).eval(), bidir_states, input_shapes_stacked), + # (stacked_bidir_lstm(input_size, hidden_size, num_layers).eval(), + # stacked_bidir_states, input_shapes_stacked_bidir) + ] + + for (raw_model, states, input_shapes) in models: + script_module = torch.jit.script(raw_model) + mod, params = from_pytorch(script_module, input_shapes) + # print(mod["main"]) + + with torch.no_grad(): + pt_result = raw_model(inp.clone(), states) + + params[input_name] = inp.numpy() + + if isinstance(states, tuple): + states_np = tuple(st.numpy() for st in states) + elif isinstance(states, list) and isinstance(states[0], torch.Tensor): + states_np = [st.numpy() for st in states] + elif isinstance(states, list) and isinstance(states[0], tuple): + states_np = [tuple(st.numpy() for st in states[i]) + for i in range(len(states))] + elif isinstance(states, list) and isinstance(states[0], list): + states_np = [[tuple(st.numpy() for st in states[i]) + for i in range(len(states[layer]))] + for layer in range(num_layers)] + else: + assert False + + if 
isinstance(states_np, list): + params[states_name] = convert_list_to_vmobj(states_np) + else: + params[states_name] = states_np + + run_and_compare(mod, params, pt_result) + + +custom_lstm_test() diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 01879115a9c1..8e9928510220 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1363,3 +1363,8 @@ def forward(self, xs): # Test simple conditionals and loop test_control_flow() test_simple_rnn() + + # More complex recurrent models + from lstm_test import custom_lstm_test + + custom_lstm_test() From f53cc0caa052770c55377fda79f6135d1e613a9a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 11 Apr 2020 15:45:31 +0900 Subject: [PATCH 26/31] revert unrelated change --- python/tvm/relay/frontend/common.py | 2 +- python/tvm/relay/prelude.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 2790ba328799..e86890f3639a 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -487,7 +487,7 @@ def infer_shape(inputs, mod=None): checked_type = out_type.checked_type if hasattr(checked_type, 'shape'): # Regular operator that outputs tensors - return get_const_tuple(out_type.checked_type.shape) + return get_const_tuple(checked_type.shape) # The return type is not a tensor, for example List return checked_type diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index cea6ffdf8a1d..243eace0fb94 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -29,16 +29,21 @@ def get_tensor_array_shape(expr, dtype, prelude): """Get the static shape of a tensor array if it has fixed rank shape. + By design, a static ADT tensor in TVM has a type name of the form static_tensor_dim0_dim1_..._dimN_t. + Parameters ---------- expr : Relay Expr Input expression. + dtype : str Data type. + prelude : Prelude Tensor array prelude + Returns ------- shape : tuple of (int, Any) or None @@ -65,7 +70,6 @@ return tuple(shape) return None - def _get_name_static(canonical, dtype, shape): """Get name for static shape tensor array op corresponding to the canonical name""" @@ -200,6 +204,7 @@ def define_tensor_concatenate(self): self.prelude.mod[concat_var] = \ Function([x, y], Match(x, [case], False), tensor_type_var(), []) + def define_tensor_expand_dims(self): """Defines a function to grow a tensor_t's rank by adding one dimension in front of the original tensor_t.
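# For reference, the static_tensor_dim0_dim1_..._dimN_t convention documented
# above maps back to a shape along these lines (a minimal sketch only:
# parse_static_type_name is a hypothetical helper, and non-numeric tokens are
# assumed to encode a dynamic Any() dimension):
from tvm.relay.ty import Any

def parse_static_type_name(type_name):
    # e.g. "static_tensor_5_2_4_t" -> (5, 2, 4)
    assert type_name.startswith("static_tensor_") and type_name.endswith("_t")
    tokens = type_name[len("static_tensor_"):-len("_t")].split("_")
    return tuple(int(tok) if tok.isdigit() else Any() for tok in tokens)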
@@ -389,6 +394,7 @@ def define_tensor_array_split(self, take_var = self.get_var('tensor_take') self.shape = origin_shape + ta1 = Var("tensor_array", self.prelude.l(output_tensor_type_var())) value1 = Var('value1', value_type_var()) offset1 = Var('offset1', scalar_type('int32')) From 12740bb6bb44592ebae5a0e19c0f56eafacb1a2f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 11 Apr 2020 16:06:25 +0900 Subject: [PATCH 27/31] fix stacked bidir test --- python/tvm/relay/frontend/pytorch.py | 1 - tests/python/frontend/pytorch/lstm_test.py | 45 +++++++++------------- 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index c0003e3ae2c5..a604d16e5235 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1799,7 +1799,6 @@ def get_input(index): def get_var(name, val): if val is not None: - print(val) checked_type = _infer_type_with_prelude(val, prelude) return _expr.var(name, type_annotation=checked_type) return _expr.var(name) diff --git a/tests/python/frontend/pytorch/lstm_test.py b/tests/python/frontend/pytorch/lstm_test.py index 672cd44fc6cc..b72c07164336 100644 --- a/tests/python/frontend/pytorch/lstm_test.py +++ b/tests/python/frontend/pytorch/lstm_test.py @@ -190,7 +190,7 @@ def bidir_lstm(input_size, hidden_size): def stacked_bidir_lstm(input_size, hidden_size, num_layers): return StackedBidirLSTM(num_layers, BidirLSTMLayer, first_layer_args=[LayerNormLSTMCell, input_size, hidden_size], - other_layer_args=[LayerNormLSTMCell, hidden_size * 2, hidden_size]) + other_layer_args=[LayerNormLSTMCell, hidden_size, hidden_size]) def vmobj_to_list(o, dtype="float32"): @@ -272,36 +272,33 @@ def custom_lstm_test(): input_size = 3 hidden_size = 4 num_layers = 3 + state_tensor_shape = (batch, hidden_size) inp = torch.randn(seq_len, batch, input_size) input_shapes = [(input_name, (seq_len, batch, input_size)), - (states_name, ((batch, hidden_size), (batch, hidden_size)))] + (states_name, (state_tensor_shape, state_tensor_shape))] input_shapes_stacked = [(input_name, (seq_len, batch, input_size)), - (states_name, [((batch, hidden_size), (batch, hidden_size)), - ((batch, hidden_size), (batch, hidden_size))])] + (states_name, [(state_tensor_shape, state_tensor_shape), + (state_tensor_shape, state_tensor_shape)])] input_shapes_stacked_bidir = [(input_name, (seq_len, batch, input_size)), - (states_name, [((batch, hidden_size), (batch, hidden_size)), - ((batch, hidden_size), (batch, hidden_size))])] + (states_name, [[(state_tensor_shape, + state_tensor_shape) + for _ in range(2)] + for _ in range(num_layers)])] - input_shapes_stacked_bidir= [(input_name, (seq_len, batch, input_size)), - (states_name, [[((batch, hidden_size), - (batch, hidden_size)) - for _ in range(2)] - for _ in range(num_layers)])] - - states = [(torch.randn(batch, hidden_size), - torch.randn(batch, hidden_size)) + states = [(torch.randn(state_tensor_shape), + torch.randn(state_tensor_shape)) for _ in range(num_layers)] - bidir_states = [(torch.randn(batch, hidden_size), - torch.randn(batch, hidden_size)) + bidir_states = [(torch.randn(state_tensor_shape), + torch.randn(state_tensor_shape)) for _ in range(2)] - stacked_bidir_states = [[(torch.randn(batch, hidden_size), - torch.randn(batch, hidden_size)) + stacked_bidir_states = [[(torch.randn(state_tensor_shape), + torch.randn(state_tensor_shape)) for _ in range(2)] for _ in range(num_layers)] @@ -309,14 +306,13 @@ def custom_lstm_test(): (lstm(input_size, 
hidden_size).eval(), states[0], input_shapes), (stacked_lstm(input_size, hidden_size, num_layers).eval(), states, input_shapes_stacked), (bidir_lstm(input_size, hidden_size).eval(), bidir_states, input_shapes_stacked), - # (stacked_bidir_lstm(input_size, hidden_size, num_layers).eval(), - # stacked_bidir_states, input_shapes_stacked_bidir) + (stacked_bidir_lstm(input_size, hidden_size, num_layers).eval(), + stacked_bidir_states, input_shapes_stacked_bidir) ] for (raw_model, states, input_shapes) in models: script_module = torch.jit.script(raw_model) mod, params = from_pytorch(script_module, input_shapes) - # print(mod["main"]) with torch.no_grad(): pt_result = raw_model(inp.clone(), states) @@ -331,8 +327,8 @@ def custom_lstm_test(): states_np = [tuple(st.numpy() for st in states[i]) for i in range(len(states))] elif isinstance(states, list) and isinstance(states[0], list): - states_np = [[tuple(st.numpy() for st in states[i]) - for i in range(len(states[layer]))] + states_np = [[tuple(st.numpy() for st in states) + for states in states[layer]] for layer in range(num_layers)] else: assert False @@ -343,6 +339,3 @@ def custom_lstm_test(): params[states_name] = states_np run_and_compare(mod, params, pt_result) - - -custom_lstm_test() From b7bce2bbdc49d75a83aeaac3f97d346a4fe9b5e5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 11 Apr 2020 16:11:36 +0900 Subject: [PATCH 28/31] minor fix to test --- python/tvm/relay/frontend/pytorch.py | 11 ++++------- tests/python/frontend/pytorch/lstm_test.py | 6 ------ 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index a604d16e5235..05bb6ea100fb 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -83,8 +83,6 @@ def _should_construct_dynamic_list(list_construct_node): def is_used_by_list_add(uses): for use in uses: op_name = use.user.kind() - if op_name == "prim::Loop": - continue output_type = _get_node_type(use.user) if op_name in ["aten::add", "aten::add_"] and output_type == "ListType": return True @@ -112,7 +110,7 @@ def inplace_add_to_add(op_name): if len(intersect) > 0 and intersect != set(["aten::add"]): return True - if is_used_by_list_add(uses): + if is_used_by_list_add(filter(lambda use: use.user.kind() != "prim::Loop", uses)): return True return False @@ -187,7 +185,7 @@ def _concatenate(prelude): def tensor_array_concat(lst, axis): assert axis == 0, "Tensor array concat supported only for axis 0" shape = _infer_type_with_prelude(prelude.hd(lst), prelude).shape - concat_shape = (Any(), ) + tuple(shape[1:]) + concat_shape = (Any(),) + tuple(shape[1:]) tensor_array = _map_tensor_array_constructor(lst, prelude, shape) static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", concat_shape) @@ -1218,7 +1216,6 @@ def _impl(inputs, input_types): def _tensor_array_stack(prelude): def _impl(inputs, input_types): - # TODO: check inputs[0] is a ADT List[TensorType] tensor_array = _convert_to_tensor_array(inputs[0], prelude) shape = get_tensor_array_shape(tensor_array, "float32", prelude) stack = prelude.get_var_static('tensor_array_stack', "float32", shape) @@ -1304,7 +1301,7 @@ def _convert_elemwise_input(data, input_type): return data def _wrap_const(c): - if not isinstance(c, _expr.Expr) and not isinstance(c, (list, tvm.tir.expr.Any)): + if not isinstance(c, (_expr.Expr, list, tvm.tir.expr.Any)): return _expr.const(c) return c @@ -1798,7 +1795,7 @@ def get_input(index): outputs.update(name_val_pairs) 
def get_var(name, val): - if val is not None: + if val: checked_type = _infer_type_with_prelude(val, prelude) return _expr.var(name, type_annotation=checked_type) return _expr.var(name) diff --git a/tests/python/frontend/pytorch/lstm_test.py b/tests/python/frontend/pytorch/lstm_test.py index b72c07164336..95694ed30e99 100644 --- a/tests/python/frontend/pytorch/lstm_test.py +++ b/tests/python/frontend/pytorch/lstm_test.py @@ -211,19 +211,13 @@ def assert_equal(tvm_result, torch_result): for tvm_res, pt_res in zip(tvm_result, torch_result): assert_equal(tvm_res, pt_res) elif isinstance(torch_result, torch.Tensor): - print(np.max(np.abs(tvm_result.asnumpy() - torch_result.numpy()))) tvm.testing.assert_allclose(tvm_result.asnumpy(), torch_result.numpy(), rtol=1e-5, atol=1e-5) - else: - tvm_res = np.asscalar(tvm_result.asnumpy()) - print(abs(torch_result - tvm_res)) - assert torch_result == tvm_res def run_and_compare(mod, params, pt_result): executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm") evaluator = executor.evaluate() - exec_res = evaluator(**params) def flatten(nested): From 6bcf0f193a0a45c97faa64011b40ce701a7d81ee Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 12 Apr 2020 19:16:09 +0900 Subject: [PATCH 29/31] relax tol a bit, revert dnnl change to avoid conflict --- src/relay/backend/contrib/dnnl/codegen.cc | 6 ++++-- tests/python/frontend/pytorch/lstm_test.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 2bf3a3284ee1..7f3aabf6e016 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -194,8 +194,10 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { } out_.clear(); - buf_decl_.insert(buf_decl_.end(), ret.buffers.begin(), ret.buffers.end()); - out_.insert(out_.end(), ret.outputs.begin(), ret.outputs.end()); + for (size_t i = 0; i < ret.outputs.size(); ++i) { + buf_decl_.push_back(ret.buffers[i]); + out_.push_back(ret.outputs[i]); + } ext_func_body.push_back(ret.decl); } diff --git a/tests/python/frontend/pytorch/lstm_test.py b/tests/python/frontend/pytorch/lstm_test.py index 95694ed30e99..4616698c687f 100644 --- a/tests/python/frontend/pytorch/lstm_test.py +++ b/tests/python/frontend/pytorch/lstm_test.py @@ -212,7 +212,7 @@ def assert_equal(tvm_result, torch_result): assert_equal(tvm_res, pt_res) elif isinstance(torch_result, torch.Tensor): tvm.testing.assert_allclose(tvm_result.asnumpy(), torch_result.numpy(), - rtol=1e-5, atol=1e-5) + rtol=1e-4, atol=1e-4) def run_and_compare(mod, params, pt_result): From 3074c9a646418aa8957965fbee2b44966a766694 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 13 Apr 2020 09:43:22 +0900 Subject: [PATCH 30/31] simplify infer type, use input tensor shape rather than concat shape --- python/tvm/relay/frontend/pytorch.py | 29 ++++++++++++---------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 05bb6ea100fb..0c38aca1c47d 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -29,14 +29,13 @@ from .. import analysis as _analysis from .. import expr as _expr from .. import op as _op -from ..function import Function -from .. 
import transform from ..ty import TupleType, TensorType, Any from ..loops import while_loop from .common import get_relay_op from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value -from ..prelude import Prelude, StaticTensorArrayOps, get_tensor_array_shape +from .common import infer_type as _infer_type +from ..prelude import Prelude, StaticTensorArrayOps from . import qnn_torch @@ -45,11 +44,8 @@ # List ADT utilities def _infer_type_with_prelude(val, prelude): - mod = prelude.mod - func = Function([], val) - mod["main"] = func - mod = transform.InferType()(mod) - return mod["main"].body.checked_type + body = _infer_type(val, prelude.mod) + return body.checked_type def _convert_to_list_adt(py_lst, prelude): @@ -74,8 +70,10 @@ def _convert_to_tensor_array(adt_lst, prelude): if prelude.length(adt_lst) == 0: return prelude.nil() - shape = _infer_type_with_prelude(prelude.hd(adt_lst), prelude).shape - return _map_tensor_array_constructor(adt_lst, prelude, shape) + checked_type = _infer_type_with_prelude(prelude.hd(adt_lst), prelude) + shape = checked_type.shape + tensor_array = _map_tensor_array_constructor(adt_lst, prelude, shape) + return tensor_array, tuple(shape) def _should_construct_dynamic_list(list_construct_node): @@ -184,11 +182,9 @@ def _impl(inputs, input_types): def _concatenate(prelude): def tensor_array_concat(lst, axis): assert axis == 0, "Tensor array concat supported only for axis 0" - shape = _infer_type_with_prelude(prelude.hd(lst), prelude).shape - concat_shape = (Any(),) + tuple(shape[1:]) - - tensor_array = _map_tensor_array_constructor(lst, prelude, shape) - static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", concat_shape) + tensor_array, shape = _convert_to_tensor_array(lst, prelude) + concat_shape = (Any(),) + shape[1:] + static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) static_tensor_array_ops.define_tensor_get_data(concat_shape) concat = prelude.get_var_static('tensor_array_concat', "float32", concat_shape) @@ -1216,8 +1212,7 @@ def _impl(inputs, input_types): def _tensor_array_stack(prelude): def _impl(inputs, input_types): - tensor_array = _convert_to_tensor_array(inputs[0], prelude) - shape = get_tensor_array_shape(tensor_array, "float32", prelude) + tensor_array, shape = _convert_to_tensor_array(inputs[0], prelude) stack = prelude.get_var_static('tensor_array_stack', "float32", shape) stacked = stack(tensor_array) From c185b4e27862984597fcd789166138293b1bee2f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 13 Apr 2020 11:23:47 +0900 Subject: [PATCH 31/31] more shape fix --- python/tvm/relay/frontend/pytorch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 0c38aca1c47d..506f6ba3ceb7 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -187,9 +187,9 @@ def tensor_array_concat(lst, axis): static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) static_tensor_array_ops.define_tensor_get_data(concat_shape) - concat = prelude.get_var_static('tensor_array_concat', "float32", concat_shape) + concat = prelude.get_var_static('tensor_array_concat', "float32", shape) concatenated = concat(tensor_array) - get_tensor = prelude.get_var_static('tensor_get_data', "float32", concat_shape) + get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape) return get_tensor(concatenated) def _impl(inputs, input_types):
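
To make the final shape handling concrete, the lookup pattern these last two patches converge on can be exercised in isolation roughly as follows (a minimal sketch, assuming a fresh IRModule/Prelude and a float32 element shape of (2, 4); static_ops.register() is assumed to populate the per-shape operators before the extra getter is defined):

import tvm
from tvm.relay.prelude import Prelude, StaticTensorArrayOps
from tvm.relay.ty import Any

mod = tvm.IRModule()
prelude = Prelude(mod)

shape = (2, 4)                       # static shape of each tensor in the list
concat_shape = (Any(),) + shape[1:]  # leading dim becomes dynamic after concat

# Register the static ops under the input element shape, then define the
# data getter for the concatenated result whose first dimension is Any().
static_ops = StaticTensorArrayOps(prelude, "float32", shape)
static_ops.register()
static_ops.define_tensor_get_data(concat_shape)

# Both lookups are keyed by the element shape, as in the final hunk above.
concat = prelude.get_var_static('tensor_array_concat', "float32", shape)
get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape)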