diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index a542ccc48af0..506f6ba3ceb7 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -25,20 +25,95 @@ import numpy as np import tvm -from tvm.ir import module as _module from .. import analysis as _analysis from .. import expr as _expr from .. import op as _op +from ..ty import TupleType, TensorType, Any from ..loops import while_loop from .common import get_relay_op from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value +from .common import infer_type as _infer_type +from ..prelude import Prelude, StaticTensorArrayOps from . import qnn_torch __all__ = ["from_pytorch"] + +# List ADT utilities +def _infer_type_with_prelude(val, prelude): + body = _infer_type(val, prelude.mod) + return body.checked_type + + +def _convert_to_list_adt(py_lst, prelude): + elem_tys = [_infer_type_with_prelude(elem, prelude) for elem in py_lst] + msg = "List elements should have identical types" + assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg + + adt_lst = prelude.nil() + for elem in reversed(py_lst): + adt_lst = prelude.cons(elem, adt_lst) + return adt_lst + + +def _map_tensor_array_constructor(adt_lst, prelude, shape): + static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) + static_tensor_array_ops.register() + tensor_create = prelude.get_var_static('tensor_constructor', "float32", shape) + return prelude.map(tensor_create, adt_lst) + + +def _convert_to_tensor_array(adt_lst, prelude): + if prelude.length(adt_lst) == 0: + return prelude.nil() + + checked_type = _infer_type_with_prelude(prelude.hd(adt_lst), prelude) + shape = checked_type.shape + tensor_array = _map_tensor_array_constructor(adt_lst, prelude, shape) + return tensor_array, tuple(shape) + + +def _should_construct_dynamic_list(list_construct_node): + # if this list is element-accessed or modified at runtime, generate List ADT + def is_used_by_list_add(uses): + for use in uses: + op_name = use.user.kind() + output_type = _get_node_type(use.user) + if op_name in ["aten::add", "aten::add_"] and output_type == "ListType": + return True + return False + + def inplace_add_to_add(op_name): + if op_name == "aten::add_": + return "aten::add" + else: + return op_name + + uses = _get_uses(list_construct_node) + + for loop_use in filter(lambda use: use.user.kind() == "prim::Loop", uses): + block_input_index = loop_use.offset - 1 + block = list(loop_use.user.blocks())[0] + list_loop_var = list(block.inputs())[block_input_index] + uses += _get_uses(list_loop_var.node()) + + op_names = map(inplace_add_to_add, set(use.user.kind() for use in uses)) + + list_ops = set(["aten::add", "aten::__getitem__", "aten::stack"]) + intersect = list_ops.intersection(op_names) + + if len(intersect) > 0 and intersect != set(["aten::add"]): + return True + + if is_used_by_list_add(filter(lambda use: use.user.kind() != "prim::Loop", uses)): + return True + + return False + + # operator implementation def _elemwise(name): def _impl(inputs, input_types): @@ -103,11 +178,27 @@ def _impl(inputs, input_types): return _op.transform.expand_dims(data, int(axis), 1) return _impl -def _concatenate(): + +def _concatenate(prelude): + def tensor_array_concat(lst, axis): + assert axis == 0, "Tensor array concat supported only for axis 0" + tensor_array, shape = _convert_to_tensor_array(lst, prelude) + concat_shape = (Any(),) + shape[1:] + static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) + static_tensor_array_ops.define_tensor_get_data(concat_shape) + + concat = prelude.get_var_static('tensor_array_concat', "float32", shape) + concatenated = concat(tensor_array) + get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape) + return get_tensor(concatenated) + def _impl(inputs, input_types): data = inputs[0] axis = inputs[1] + if not isinstance(data, list): + return tensor_array_concat(data, axis) + if isinstance(data, _expr.Expr): data = [data] @@ -130,7 +221,7 @@ def _impl(inputs, input_types): else: end = data.shape - begin = [0]*len(end) + begin = [0] * len(end) dim = int(inputs[1]) begin[dim] = int(inputs[2]) @@ -371,7 +462,7 @@ def _impl(inputs, input_types): ceil_mode = int(inputs[5]) if dilation != (1, 1): - msg = "MaxPool2d with dilation %s is not implemented" % (str(dilation), ) + msg = "MaxPool2d with dilation %s is not implemented" % (str(dilation)) raise NotImplementedError(msg) return _op.nn.max_pool2d(data, pool_size, strides, padding, "NCHW", ceil_mode) @@ -388,7 +479,7 @@ def _impl(inputs, input_types): ceil_mode = int(inputs[5]) if dilation != (1,): - msg = "MaxPool1d with dilation %s is not implemented" % (str(dilation), ) + msg = "MaxPool1d with dilation %s is not implemented" % (str(dilation)) raise NotImplementedError(msg) return _op.nn.max_pool1d(data, pool_size, strides, padding, "NCW", ceil_mode) @@ -404,7 +495,7 @@ def _impl(inputs, input_types): dilation = _infer_shape(inputs[4]) ceil_mode = int(inputs[5]) if dilation != (1, 1, 1): - msg = "MaxPool3d with dilation %s is not implemented" % (str(dilation), ) + msg = "MaxPool3d with dilation %s is not implemented" % (str(dilation)) raise NotImplementedError(msg) return _op.nn.max_pool3d(data, @@ -618,13 +709,13 @@ def _impl(inputs, input_types): scale=True) return _impl -def _transpose(): +def _transpose(prelude): def _impl(inputs, input_types): data = inputs[0] import torch if isinstance(data, _expr.Expr): - ndims = len(_infer_shape(data)) + ndims = len(_infer_shape(data, prelude.mod)) elif isinstance(data, list): ndims = data elif isinstance(data, (torch.Tensor, np.ndarray)): @@ -693,15 +784,30 @@ def _impl(inputs, input_types): return dense_out return _impl -def _size(): + +def _size(prelude): + def _impl_dynamic(inp, axis): + shape_dynamic = _op.shape_of(inp) + if axis is not None: + return _op.take(shape_dynamic, _expr.const(axis), 0) + return shape_dynamic + def _impl(inputs, input_types): - shape = _infer_shape(inputs[0]) + shape = _infer_shape(inputs[0], prelude.mod) + axis = None if len(inputs) > 1: axis = int(inputs[1]) + + if any(map(lambda s: isinstance(s, tvm.tir.expr.Any), shape)): + if axis is None or isinstance(shape[axis], tvm.tir.expr.Any): + return _impl_dynamic(inputs[0], axis) + + if axis is not None: return shape[axis] return shape return _impl + def _numtotensor(): def _impl(inputs, input_types): val = inputs[0] @@ -862,7 +968,7 @@ def func(x): return _impl -def _chunk(): +def _chunk(prelude): def _impl(inputs, input_types): data = inputs[0] @@ -870,7 +976,7 @@ def _impl(inputs, input_types): axis = int(inputs[2]) if isinstance(data, _expr.Expr): - inferred_shape = _infer_shape(data) + inferred_shape = _infer_shape(data, prelude.mod) shape = [] for infer in inferred_shape: @@ -894,7 +1000,6 @@ def _impl(inputs, input_types): chunk_out = _op.transform.strided_slice(data, begin, end, stride) chunks.append(chunk_out) - if dim % num_chunks: begin = [0] * len(shape) end = shape[:] @@ -1077,6 +1182,49 @@ def _impl(inputs, input_types): return _op.cast(inputs[0], "float32") return _impl + +def _mm(): + def _impl(inputs, input_types): + return _op.nn.dense(inputs[0], inputs[1]) + return _impl + + +def _list_getitem(prelude): + def _impl(inputs, input_types): + return prelude.nth(inputs[0], _wrap_const(inputs[1])) + return _impl + + +def _list_len(prelude): + def _impl(inputs, input_types): + return prelude.length(inputs[0]) + return _impl + + +def _add(prelude): + # add_ is overloaded for tensor add and list concat + def _impl(inputs, input_types): + if input_types[0] == "ListType": + return prelude.concat(inputs[0], inputs[1]) + return _elemwise("add")(inputs, input_types) + return _impl + + +def _tensor_array_stack(prelude): + def _impl(inputs, input_types): + tensor_array, shape = _convert_to_tensor_array(inputs[0], prelude) + stack = prelude.get_var_static('tensor_array_stack', "float32", shape) + stacked = stack(tensor_array) + + stacked_shape = (Any(),) + shape + static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) + static_tensor_array_ops.define_tensor_get_data(stacked_shape) + # passing stacked_shape below gives "'Prelude' object has no attribute" error + get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape) + return get_tensor(stacked) + return _impl + + # Helper functions for operator implementation def _convert_dtype_value(val): convert_torch_dtype_map = {7:"torch.float64", @@ -1148,112 +1296,117 @@ def _convert_elemwise_input(data, input_type): return data def _wrap_const(c): - if not isinstance(c, _expr.Expr) and not isinstance(c, list): + if not isinstance(c, (_expr.Expr, list, tvm.tir.expr.Any)): return _expr.const(c) return c # Operator mappings - -_convert_map = { - "aten::device" : _none(), - "aten::add" : _elemwise("add"), - "aten::add_" : _elemwise("add"), - "aten::sub" : _elemwise("subtract"), - "aten::sub_" : _elemwise("subtract"), - "aten::max" : _elemwise("maximum"), - "aten::min" : _elemwise("minimum"), - "aten::mul" : _elemwise("multiply"), - "aten::mul_" : _elemwise("multiply"), - "aten::pow" : _elemwise("power"), - "aten::div" : _elemwise("divide"), - "aten::div_" : _elemwise("divide"), - "aten::abs" : _abs(), - "aten::arange" : _arange(), - "aten::ones" : _ones(), - "aten::zeros" : _zeros(), - "aten::reciprocal" : _reciprocal(), - "aten::repeat" : _repeat(), - "aten::repeat_interleave" : _repeat_interleave(), - "aten::to" : _to(), - "aten::squeeze" : _squeeze(), - "aten::unsqueeze" : _unsqueeze(), - "aten::cat" : _concatenate(), - "aten::slice" : _slice(), - "aten::split" : _split(), - "aten::split_with_sizes" : _split_with_sizes(), - "aten::select" : _select(), - "aten::relu" : _relu(), - "aten::relu_" : _relu(), - "aten::prelu" : _prelu(), - "aten::leaky_relu" : _leaky_relu(), - "aten::elu" : _elu(), - "aten::celu" : _celu(), - "aten::gelu" : _gelu(), - "aten::selu" : _selu(), - "aten::log_sigmoid" : _log_sigmoid(), - "aten::adaptive_avg_pool2d" : _adaptive_avg_pool_2d(), - "aten::adaptive_max_pool2d" : _adaptive_max_pool_2d(), - "aten::max_pool2d" : _maxpool_2d(), - "aten::max_pool2d_with_indices" : _maxpool_2d(), - "aten::max_pool1d" : _maxpool_1d(), - "aten::max_pool3d" : _maxpool_3d(), - "aten::hardtanh" : _hardtanh(), - "aten::hardtanh_" : _hardtanh(), - "aten::_convolution" : _convolution(), - "aten::softmax" : _softmax(), - "aten::threshold" : _threshold(), - "aten::threshold_" : _threshold(), - "aten::contiguous" : _contiguous(), - "aten::batch_norm" : _batch_norm(), - "aten::instance_norm" : _instance_norm(), - "aten::layer_norm" : _layer_norm(), - "aten::transpose" : _transpose(), - "aten::transpose_" : _transpose(), - "aten::t" : _transpose(), - "aten::flatten" : _flatten(), - "aten::addmm" : _dense(), - "aten::size" : _size(), - "aten::view" : _view(), - "aten::reshape" : _reshape(), - "aten::clone" : _clone(), - "aten::log_softmax" : _log_softmax(), - "aten::sigmoid" : _sigmoid(), - "aten::softplus" : _softplus(), - "aten::avg_pool2d" : _avg_pool2d(), - "aten::avg_pool3d" : _avg_pool3d(), - "aten::dropout" : _dropout(), - "aten::dropout_" : _dropout(), - "aten::feature_dropout" : _dropout(), - "aten::alpha_dropout" : _dropout(), - "aten::mean" : _mean(), - "aten::chunk" : _chunk(), - "aten::matmul" : _matmul(), - "aten::expand" : _expand(), - "aten::Int" : _int(), - "prim::NumToTensor" : _numtotensor(), - "prim::ListUnpack" : _identity(), - "aten::constant_pad_nd" : _pad(), - "aten::permute" : _transpose(), - "aten::sum" : _reduce("sum"), - "aten::prod" : _reduce("prod"), - "aten::sqrt" : _sqrt(), - 'aten::floor' : _floor(), - "aten::detach" : _identity(), - "aten::upsample_bilinear2d" : _upsample("bilinear"), - "aten::upsample_nearest2d" : _upsample("nearest_neighbor"), - "aten::expand_as" : _expand_as(), - "aten::lt" : _elemwise("less"), - "aten::gt" : _elemwise("greater"), - "aten::le" : _elemwise("less_equal"), - "aten::ge" : _elemwise("greater_equal"), - "aten::ne" : _elemwise("not_equal"), - "aten::Bool" : _Bool(), - "aten::Float" : _Float(), - "aten::neg" : _neg(), - "aten::tanh" : _tanh(), - "aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(), - "aten::adaptive_max_pool3d" : _adaptive_max_pool_3d() -} +def _get_convert_map(prelude): + convert_map = { + "aten::device" : _none(), + "aten::sub" : _elemwise("subtract"), + "aten::sub_" : _elemwise("subtract"), + "aten::max" : _elemwise("maximum"), + "aten::min" : _elemwise("minimum"), + "aten::mul" : _elemwise("multiply"), + "aten::mul_" : _elemwise("multiply"), + "aten::pow" : _elemwise("power"), + "aten::abs" : _abs(), + "aten::arange" : _arange(), + "aten::div" : _elemwise("divide"), + "aten::div_" : _elemwise("divide"), + "aten::ones" : _ones(), + "aten::zeros" : _zeros(), + "aten::reciprocal" : _reciprocal(), + "aten::repeat" : _repeat(), + "aten::repeat_interleave" : _repeat_interleave(), + "aten::to" : _to(), + "aten::squeeze" : _squeeze(), + "aten::unsqueeze" : _unsqueeze(), + "aten::cat" : _concatenate(prelude), + "aten::slice" : _slice(), + "aten::split" : _split(), + "aten::split_with_sizes" : _split_with_sizes(), + "aten::select" : _select(), + "aten::relu" : _relu(), + "aten::relu_" : _relu(), + "aten::prelu" : _prelu(), + "aten::leaky_relu" : _leaky_relu(), + "aten::elu" : _elu(), + "aten::celu" : _celu(), + "aten::gelu" : _gelu(), + "aten::selu" : _selu(), + "aten::log_sigmoid" : _log_sigmoid(), + "aten::adaptive_avg_pool2d" : _adaptive_avg_pool_2d(), + "aten::adaptive_max_pool2d" : _adaptive_max_pool_2d(), + "aten::max_pool2d" : _maxpool_2d(), + "aten::max_pool2d_with_indices" : _maxpool_2d(), + "aten::max_pool1d" : _maxpool_1d(), + "aten::max_pool3d" : _maxpool_3d(), + "aten::hardtanh" : _hardtanh(), + "aten::hardtanh_" : _hardtanh(), + "aten::_convolution" : _convolution(), + "aten::softmax" : _softmax(), + "aten::threshold" : _threshold(), + "aten::threshold_" : _threshold(), + "aten::contiguous" : _contiguous(), + "aten::batch_norm" : _batch_norm(), + "aten::instance_norm" : _instance_norm(), + "aten::layer_norm" : _layer_norm(), + "aten::transpose" : _transpose(prelude), + "aten::transpose_" : _transpose(prelude), + "aten::t" : _transpose(prelude), + "aten::flatten" : _flatten(), + "aten::addmm" : _dense(), + "aten::size" : _size(prelude), + "aten::view" : _view(), + "aten::reshape" : _reshape(), + "aten::clone" : _clone(), + "aten::log_softmax" : _log_softmax(), + "aten::sigmoid" : _sigmoid(), + "aten::softplus" : _softplus(), + "aten::avg_pool2d" : _avg_pool2d(), + "aten::avg_pool3d" : _avg_pool3d(), + "aten::dropout" : _dropout(), + "aten::dropout_" : _dropout(), + "aten::feature_dropout" : _dropout(), + "aten::alpha_dropout" : _dropout(), + "aten::mean" : _mean(), + "aten::chunk" : _chunk(prelude), + "aten::matmul" : _matmul(), + "aten::expand" : _expand(), + "aten::Int" : _int(), + "prim::NumToTensor" : _numtotensor(), + "aten::constant_pad_nd" : _pad(), + "aten::permute" : _transpose(prelude), + "aten::sum" : _reduce("sum"), + "aten::prod" : _reduce("prod"), + "aten::sqrt" : _sqrt(), + 'aten::floor' : _floor(), + "aten::detach" : _identity(), + "aten::upsample_bilinear2d" : _upsample("bilinear"), + "aten::upsample_nearest2d" : _upsample("nearest_neighbor"), + "aten::expand_as" : _expand_as(), + "aten::lt" : _elemwise("less"), + "aten::gt" : _elemwise("greater"), + "aten::le" : _elemwise("less_equal"), + "aten::ge" : _elemwise("greater_equal"), + "aten::ne" : _elemwise("not_equal"), + "aten::Bool" : _Bool(), + "aten::Float" : _Float(), + "aten::neg" : _neg(), + "aten::tanh" : _tanh(), + "aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(), + "aten::adaptive_max_pool3d" : _adaptive_max_pool_3d(), + "aten::mm" : _matmul(), + "relay::tensor_array_stack" : _tensor_array_stack(prelude), + "aten::add" : _add(prelude), + "aten::add_" : _add(prelude), + "aten::stack" : _tensor_array_stack(prelude), + "aten::__getitem__" : _list_getitem(prelude), + "aten::len" : _list_len(prelude), + } + return convert_map def _run_jit_passes(graph): @@ -1289,13 +1442,29 @@ def _get_op_inputs(op_node, outputs): return [outputs[name] for name in _get_input_names(op_node)] -def _report_missing_conversion(op_names): +def _get_node_type(node): + assert node.outputsSize() == 1 + return node.output().type().kind() + + +def _get_uses(node): + uses = [] + for output in node.outputs(): + uses += output.uses() + return uses + + +def _get_users(node): + return [use.user for use in _get_uses(node)] + + +def _report_missing_conversion(op_names, convert_map): """ Check if all ops in an input graph are supported by TVM """ known_ops = ["prim::Constant", "prim::GetAttr", "prim::ListConstruct", "prim::ListUnpack", "prim::TupleConstruct", "prim::TupleUnpack", "prim::If", "prim::Loop"] - known_ops += list(_convert_map.keys()) + known_ops += list(convert_map.keys()) known_ops += list(qnn_torch.convert_map.keys()) missing = [op_name for op_name in op_names @@ -1361,7 +1530,7 @@ def _get_input_types(op_node): input_list_types.append(in_ty.scalarType().lower()) elif input_node_kind == 'ListType': - input_list_types.append(str(in_ty.getElementType()).lower()) + input_list_types.append("ListType") elif input_node_kind in ['IntType', 'FloatType', 'BoolType', 'StringType', 'OptionalType']: input_list_types.append(str(in_ty).lower()) @@ -1422,21 +1591,69 @@ def _get_operator_nodes(nodes): return ops -def _get_relay_input_vars(graph, input_shapes): +def _get_graph_input_names(graph): + """ Get the graph input names (use after graph copy and run jit passes) """ + # Variable names could change the first time a copy is made and after + # _run_jit_passes is called, expected that those functions already invoked + ir_inputs = _get_input_names(graph) + return ir_inputs[1:] # remove self at the 0th arg + + +def _get_relay_input_vars(graph, input_shapes, prelude): """ Return Relay vars from input shapes and create entries based on expected graph inputs - to allow translation """ + def get_relay_ty(ishape): + if _is_int_seq(ishape) or len(ishape) == 0: + return TensorType(ishape) + elif isinstance(ishape, tuple): + return TupleType([get_relay_ty(elem) for elem in ishape]) + elif isinstance(ishape, list): + assert len(ishape) > 0 + elem_tys = [get_relay_ty(s) for s in ishape] + msg = "List elements should have identical types" + assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg + return prelude.l(elem_tys[0]) + raise NotImplementedError("unsupported input type") + + input_types = [(tup[0], get_relay_ty(tup[1])) for tup in input_shapes] input_vars = {} ir_inputs = _get_graph_input_names(graph) - for ir_input, (name, shape) in zip(ir_inputs, input_shapes): - inp = _expr.var(name, shape=shape) + for ir_input, (name, itype) in zip(ir_inputs, input_types): + inp = _expr.var(name, type_annotation=itype) # Translate from graph input to user input name input_vars[ir_input] = inp return input_vars +def _unpack_tuple(tup): + def unpack(tup, num_fields): + return [_expr.TupleGetItem(tup, i) for i in range(num_fields)] + + if isinstance(tup, _expr.Tuple): + return unpack(tup, len(tup.fields)) + elif isinstance(tup.type_annotation, TupleType): + return unpack(tup, len(tup.type_annotation.fields)) + # shouldn't happen + assert False + + +def _get_free_vars_from_block(block): + block_inp_names = _get_input_names(block) + bound_names = block_inp_names + free_vars = set() + + for node in block.nodes(): + inp_names = _get_input_names(node) + list_diff = [name for name in inp_names if name not in bound_names] + free_vars.update(list_diff) + bound_names += _get_output_names(node) + + return free_vars + + def get_use_chains(root_node, terminate=lambda _: False): """ Track a chain of users of this node forward, returning a list of chains @@ -1446,9 +1663,7 @@ def concat_lists(lists): return itertools.chain.from_iterable(lists) def inner(current, accum): - users = [] - for output in current.outputs(): - users += [use.user for use in output.uses()] + users = _get_users(current) if not users or terminate(users): return [accum] @@ -1512,24 +1727,24 @@ def convert_params(graph, state_dict): return params, param_tensors, packed_param_map -def convert_block(block, outputs): +def convert_block(block, outputs, convert_map, prelude): """ Translate Torch "Block", used for prim::If and prim::Loop """ ops = _get_operator_nodes(block.nodes()) ret_names = _get_input_names(block.returnNode()) - return convert_operators(ops, outputs, ret_names) + return convert_operators(ops, outputs, ret_names, convert_map, prelude) -def convert_if(if_node, outputs): +def convert_if(if_node, outputs, convert_map, prelude): """ Translate Torch prim::If to Relay If """ cond = outputs[if_node.inputsAt(0).debugName()] blocks = list(if_node.blocks()) - true_branch = convert_block(blocks[0], outputs) - false_branch = convert_block(blocks[1], outputs) + true_branch = convert_block(blocks[0], outputs, convert_map, prelude) + false_branch = convert_block(blocks[1], outputs, convert_map, prelude) assert len(true_branch) == 1 and len(false_branch) == 1 return _expr.If(cond, true_branch[0], false_branch[0]) -def convert_loop(loop_node, outputs): +def convert_loop(loop_node, outputs, convert_map, prelude): """ Translate Torch prim::Loop to Relay while_loop """ def get_input(index): ivalue = loop_node.inputsAt(index) @@ -1555,8 +1770,54 @@ def get_input(index): is_while_loop = (isinstance(max_loop_count, _expr.Constant) and _get_constant(loop_node.inputsAt(0).node()) == sys.maxsize) + if is_while_loop: + loop_iter_dtype = "bool" + # while loop with non input dependent condition such as while i < 10: + # init_cond is int, need to cast to bool to type check + if isinstance(init_cond, _expr.Constant): + init_cond = _op.cast(init_cond, "bool") + init_loop_iter_val = init_cond + else: + loop_iter_dtype = "int32" + # always count from 0 + init_loop_iter_val = _expr.const(0, dtype="int32") + body_block = list(loop_node.blocks())[0] block_input_names = _get_input_names(body_block) + num_block_inputs = len(block_input_names) + name_val_pairs = list(zip(block_input_names, + [init_loop_iter_val] + init_vals)) + outputs.update(name_val_pairs) + + def get_var(name, val): + if val: + checked_type = _infer_type_with_prelude(val, prelude) + return _expr.var(name, type_annotation=checked_type) + return _expr.var(name) + + loop_iter_var = _expr.var(block_input_names[0], shape=(), + dtype=loop_iter_dtype) + loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]] + + # Add non constant free variables to loop variables to prevent code blow up + # Without this, if there are two for loops in a row, which often happens + # if the outer loop is unrolled, the computation corresponding to the first for loop + # is inlined inside loop body, turning O(N) + O(N) computation into O(N^2). + # This issue was found when converting from Stacked LSTM test. Torch does not add the output + # of the eariler loop into loop variables of the next loop. + # So the variable corresponding to the first loop output appears free in the second loop body. + free_vars = [var for var in _get_free_vars_from_block(body_block) + if var in outputs and not isinstance(outputs[var], (_expr.Constant, int, float)) + and outputs[var]] + + prev_outputs = {} + for name in free_vars: + prev_output = outputs[name] + new_loop_var = get_var(name, prev_output) + prev_outputs[name] = prev_output + outputs[name] = new_loop_var + loop_vars.append(new_loop_var) + init_vals.append(prev_output) def cond(*current_vals): i = current_vals[0] @@ -1568,11 +1829,16 @@ def cond(*current_vals): def body(*current_vals): # Update loop variables using the prev iteration outputs - assert len(current_vals) == len(block_input_names) - for (i, iname) in enumerate(block_input_names): - outputs[iname] = current_vals[i] + assert len(current_vals) == num_block_inputs + len(free_vars) - block_outputs = convert_block(body_block, outputs) + for (i, val) in enumerate(current_vals): + if i < num_block_inputs: + outputs[block_input_names[i]] = val + else: + outputs[free_vars[i-num_block_inputs]] = val + + block_outputs = convert_block(body_block, outputs, convert_map, prelude) + block_outputs += [outputs[name] for name in free_vars] if not is_while_loop: # iter var increment implicit in torch, so do it manually @@ -1583,38 +1849,17 @@ def body(*current_vals): return block_outputs - def get_var(name, val): - if isinstance(val, _expr.Constant): - return _expr.var(name, shape=val.data.shape, dtype=val.data.dtype) - return _expr.var(name) - - if is_while_loop: - loop_iter_dtype = "bool" - # while loop with non input dependent condition such as while i < 10: - # init_cond is int, need to cast to bool to type check - if isinstance(init_cond, _expr.Constant): - init_cond = _op.cast(init_cond, "bool") - init_loop_iter_val = init_cond - else: - loop_iter_dtype = "int32" - # always count from 0 - init_loop_iter_val = _expr.const(0, dtype="int32") - - name_val_pairs = list(zip(block_input_names, - [init_loop_iter_val] + init_vals)) - outputs.update(name_val_pairs) - - loop_iter_var = _expr.var(block_input_names[0], shape=(), - dtype=loop_iter_dtype) - loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]] loop = while_loop(cond, [loop_iter_var] + loop_vars, body) loop_val = loop(init_loop_iter_val, *init_vals) + # restore original output values for free vars + outputs.update(prev_outputs) + # The first element is a loop counter or boolean condition, ignore it return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)] -def convert_operators(operators, outputs, ret_names): +def convert_operators(operators, outputs, ret_names, convert_map, prelude): """ Convert each Torch IR operators to Relay equivalent """ for node_name, op_node in operators: operator = op_node.kind() @@ -1622,24 +1867,33 @@ def convert_operators(operators, outputs, ret_names): if operator == "prim::Constant": outputs[node_name] = _get_constant(op_node) - elif operator == 'prim::ListConstruct' and _is_int_seq(inputs): + elif operator == "prim::ListConstruct" and _is_int_seq(inputs): outputs[node_name] = _expr.var(node_name, shape=inputs) - elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']: + elif operator == "prim::ListConstruct" and _should_construct_dynamic_list(op_node): + outputs[node_name] = _convert_to_list_adt(inputs, prelude) + elif operator == "prim::ListConstruct": + # This assumes that no more elements will be appended to this list + # In this case, we keep the Python list outputs[node_name] = inputs - elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']: + elif operator == "prim::TupleConstruct": + outputs[node_name] = _expr.Tuple(inputs) + elif operator in ["prim::ListUnpack", "prim::TupleUnpack"]: assert len(inputs) == 1 - unpacked_names = _get_output_names(op_node) - outputs.update(zip(unpacked_names, inputs[0])) + if isinstance(inputs[0], (list, _expr.TupleWrapper)): + unpacked = inputs[0] + else: + unpacked = _unpack_tuple(inputs[0]) + outputs.update(zip(_get_output_names(op_node), unpacked)) elif operator == "prim::If": - if_out = convert_if(op_node, outputs) + if_out = convert_if(op_node, outputs, convert_map, prelude) outputs[node_name] = if_out elif operator == "prim::Loop": - loop_out = convert_loop(op_node, outputs) + loop_out = convert_loop(op_node, outputs, convert_map, prelude) unpacked_names = _get_output_names(op_node) assert len(loop_out) == len(unpacked_names) outputs.update(zip(unpacked_names, loop_out)) else: - relay_op = _convert_map[operator] + relay_op = convert_map[operator] relay_out = relay_op(inputs, _get_input_types(op_node)) if isinstance(relay_out, tuple): @@ -1666,14 +1920,6 @@ def get_all_op_names(graph): return set(node.kind() for node in nodes) -def _get_graph_input_names(graph): - """ Get the graph input names (use after graph copy and run jit passes) """ - # Variable names could change the first time a copy is made and after - # _run_jit_passes is called, expected that those functions already invoked - ir_inputs = _get_input_names(graph) - return ir_inputs[1:] # remove self at the 0th arg - - def from_pytorch(script_module, input_shapes, custom_convert_map=None): """ Load PyTorch model in the form of a scripted PyTorch model and convert into relay. The companion parameters will be handled automatically. @@ -1700,18 +1946,23 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): params : dict of str to tvm.runtime.NDArray Dict of converted parameters stored in tvm.runtime.ndarray format """ + mod = tvm.IRModule() + prelude = Prelude(mod) + + convert_map = _get_convert_map(prelude) + graph = script_module.graph.copy() _run_jit_passes(graph) if custom_convert_map: - _convert_map.update(custom_convert_map) + convert_map.update(custom_convert_map) op_names = get_all_op_names(graph) - _report_missing_conversion(op_names) + _report_missing_conversion(op_names, convert_map) _check_inputs(graph, input_shapes) params = script_module.state_dict() - outputs = _get_relay_input_vars(graph, input_shapes) + outputs = _get_relay_input_vars(graph, input_shapes, prelude) param_vars, tensors, packed_param_map = convert_params(graph, params) tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()} @@ -1726,14 +1977,11 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): packed_param_map, weight_quant_params) qnn_torch.add_quant_params(tvm_params, weight_quant_params) - _convert_map.update(qnn_torch.convert_map) + convert_map.update(qnn_torch.convert_map) ret = convert_operators(_get_operator_nodes(graph.nodes()), - outputs, ret_name) - - if isinstance(ret[0], list): - ret[0] = _expr.Tuple(ret[0]) + outputs, ret_name, convert_map, prelude) - func = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0]) + mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0]) - return _module.IRModule.from_expr(func), tvm_params + return mod, tvm_params diff --git a/tests/python/frontend/pytorch/lstm_test.py b/tests/python/frontend/pytorch/lstm_test.py new file mode 100644 index 000000000000..4616698c687f --- /dev/null +++ b/tests/python/frontend/pytorch/lstm_test.py @@ -0,0 +1,335 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Tests on torch lstm model conversion """ +# originally from https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py +# described in https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/ +import numpy as np +import torch +import torch.nn as nn +from torch.nn import Parameter +import torch.jit as jit +from typing import List, Tuple +from torch import Tensor + +import tvm +from tvm import relay +from tvm.relay.frontend.pytorch import from_pytorch +from tvm.relay.prelude import Prelude +from tvm.runtime.container import ADT, tuple_object + + +class LayerNormLSTMCell(jit.ScriptModule): + def __init__(self, input_size, hidden_size): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size)) + self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size)) + + ln = nn.LayerNorm + + self.layernorm_i = ln(4 * hidden_size) + self.layernorm_h = ln(4 * hidden_size) + self.layernorm_c = ln(hidden_size) + + @jit.script_method + def forward(self, input, state): + # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + hx, cx = state + igates = self.layernorm_i(torch.mm(input, self.weight_ih.t())) + hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t())) + gates = igates + hgates + ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) + + ingate = torch.sigmoid(ingate) + forgetgate = torch.sigmoid(forgetgate) + cellgate = torch.tanh(cellgate) + outgate = torch.sigmoid(outgate) + + cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate)) + hy = outgate * torch.tanh(cy) + + return hy, (hy, cy) + + +class LSTMLayer(jit.ScriptModule): + def __init__(self, cell, *cell_args): + super().__init__() + self.cell = cell(*cell_args) + + @jit.script_method + def forward(self, input, state): + # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + outputs = [] + for i in range(input.size(0)): + out, state = self.cell(input[i], state) + outputs += [out] + return torch.stack(outputs), state + + +class ReverseLSTMLayer(jit.ScriptModule): + def __init__(self, cell, *cell_args): + super(ReverseLSTMLayer, self).__init__() + self.cell = cell(*cell_args) + + @jit.script_method + def forward(self, inputs, state): + # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + outputs = jit.annotate(List[Tensor], []) + seq_len = inputs.size(0) + for i in range(seq_len): + out, state = self.cell(inputs[seq_len - i - 1], state) + # workaround for the lack of list rev support + outputs = [out] + outputs + return torch.stack(outputs), state + + +class BidirLSTMLayer(jit.ScriptModule): + __constants__ = ['directions'] + + def __init__(self, cell, *cell_args): + super(BidirLSTMLayer, self).__init__() + self.directions = nn.ModuleList([ + LSTMLayer(cell, *cell_args), + ReverseLSTMLayer(cell, *cell_args), + ]) + + @jit.script_method + def forward(self, input, states): + # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]] + # List[LSTMState]: [forward LSTMState, backward LSTMState] + outputs = jit.annotate(List[Tensor], []) + output_states = jit.annotate(List[Tuple[Tensor, Tensor]], []) + for (i, direction) in enumerate(self.directions): + state = states[i] + out, out_state = direction(input, state) + outputs += [out] + output_states += [out_state] + # tensor array concat assumes axis == 0 for now + # return torch.cat(outputs, -1), output_states + return torch.cat(outputs, 0), output_states + + +def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args): + layers = [layer(*first_layer_args)] + [layer(*other_layer_args) + for _ in range(num_layers - 1)] + return nn.ModuleList(layers) + + +class StackedLSTM(jit.ScriptModule): + __constants__ = ['layers'] # Necessary for iterating through self.layers + + def __init__(self, num_layers, layer, first_layer_args, other_layer_args): + super().__init__() + self.layers = init_stacked_lstm(num_layers, layer, first_layer_args, + other_layer_args) + + @jit.script_method + def forward(self, input, states): + # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]] + # List[LSTMState]: One state per layer + output_states = jit.annotate(List[Tuple[Tensor, Tensor]], []) + output = input + for (i, rnn_layer) in enumerate(self.layers): + state = states[i] + output, out_state = rnn_layer(output, state) + output_states += [out_state] + return output, output_states + + +class StackedBidirLSTM(jit.ScriptModule): + __constants__ = ['layers'] # Necessary for iterating through self.layers + + def __init__(self, num_layers, layer, first_layer_args, other_layer_args): + super(StackedBidirLSTM, self).__init__() + self.layers = init_stacked_lstm(num_layers, layer, first_layer_args, + other_layer_args) + + @jit.script_method + def forward(self, input, states): + # type: (Tensor, List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]] + # List[List[LSTMState]]: The outer list is for layers, + # inner list is for directions. + output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], []) + output = input + for (i, rnn_layer) in enumerate(self.layers): + state = states[i] + output, out_state = rnn_layer(output, state) + output_states += [out_state] + return output, output_states + + +def lstm(input_size, hidden_size): + return LSTMLayer(LayerNormLSTMCell, input_size, hidden_size) + + +def stacked_lstm(input_size, hidden_size, num_layers): + return StackedLSTM(num_layers, LSTMLayer, + first_layer_args=[LayerNormLSTMCell, input_size, hidden_size], + other_layer_args=[LayerNormLSTMCell, hidden_size, hidden_size]) + + +def bidir_lstm(input_size, hidden_size): + return BidirLSTMLayer(LayerNormLSTMCell, input_size, hidden_size) + + +def stacked_bidir_lstm(input_size, hidden_size, num_layers): + return StackedBidirLSTM(num_layers, BidirLSTMLayer, + first_layer_args=[LayerNormLSTMCell, input_size, hidden_size], + other_layer_args=[LayerNormLSTMCell, hidden_size, hidden_size]) + + +def vmobj_to_list(o, dtype="float32"): + if isinstance(o, tvm.nd.NDArray): + return [o] + elif isinstance(o, tvm.runtime.container.ADT): + result = [] + for f in o: + result.extend(vmobj_to_list(f, dtype)) + return result + else: + raise RuntimeError("Unknown object type: %s" % type(o)) + + +def assert_equal(tvm_result, torch_result): + if isinstance(torch_result, (tuple, list)): + assert isinstance(tvm_result, list) + for tvm_res, pt_res in zip(tvm_result, torch_result): + assert_equal(tvm_res, pt_res) + elif isinstance(torch_result, torch.Tensor): + tvm.testing.assert_allclose(tvm_result.asnumpy(), torch_result.numpy(), + rtol=1e-4, atol=1e-4) + + +def run_and_compare(mod, params, pt_result): + executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm") + evaluator = executor.evaluate() + exec_res = evaluator(**params) + + def flatten(nested): + res = [] + for r in nested: + if isinstance(r, torch.Tensor): + res.append(r) + else: + res.extend(flatten(r)) + return res + + if isinstance(exec_res, tvm.runtime.container.ADT): + assert not isinstance(pt_result, torch.Tensor) + tvm_res = vmobj_to_list(exec_res) + torch_res = flatten(pt_result) + else: + tvm_res = exec_res + torch_res = pt_result + + assert_equal(tvm_res, torch_res) + + +def convert_list_to_vmobj(py_lst): + def wrap_nd_array(arr): + return tvm.nd.array(arr, ctx=tvm.cpu(0)) + + mod = tvm.IRModule() + prelude = Prelude(mod) + adt_lst = ADT(prelude.nil.tag, []) + for elem in reversed(py_lst): + if isinstance(elem, np.ndarray): + vmobj = wrap_nd_array(elem) + elif isinstance(elem, tuple): + vmobj = tuple_object([wrap_nd_array(e) for e in elem]) + elif isinstance(elem, list): + vmobj = convert_list_to_vmobj(elem) + adt_lst = ADT(prelude.cons.tag, [vmobj, adt_lst]) + return adt_lst + + +def custom_lstm_test(): + input_name = "input" + states_name = "states" + seq_len = 5 + batch = 2 + input_size = 3 + hidden_size = 4 + num_layers = 3 + state_tensor_shape = (batch, hidden_size) + + inp = torch.randn(seq_len, batch, input_size) + + input_shapes = [(input_name, (seq_len, batch, input_size)), + (states_name, (state_tensor_shape, state_tensor_shape))] + + input_shapes_stacked = [(input_name, (seq_len, batch, input_size)), + (states_name, [(state_tensor_shape, state_tensor_shape), + (state_tensor_shape, state_tensor_shape)])] + + input_shapes_stacked_bidir = [(input_name, (seq_len, batch, input_size)), + (states_name, [[(state_tensor_shape, + state_tensor_shape) + for _ in range(2)] + for _ in range(num_layers)])] + + states = [(torch.randn(state_tensor_shape), + torch.randn(state_tensor_shape)) + for _ in range(num_layers)] + + bidir_states = [(torch.randn(state_tensor_shape), + torch.randn(state_tensor_shape)) + for _ in range(2)] + + stacked_bidir_states = [[(torch.randn(state_tensor_shape), + torch.randn(state_tensor_shape)) + for _ in range(2)] + for _ in range(num_layers)] + + models = [ + (lstm(input_size, hidden_size).eval(), states[0], input_shapes), + (stacked_lstm(input_size, hidden_size, num_layers).eval(), states, input_shapes_stacked), + (bidir_lstm(input_size, hidden_size).eval(), bidir_states, input_shapes_stacked), + (stacked_bidir_lstm(input_size, hidden_size, num_layers).eval(), + stacked_bidir_states, input_shapes_stacked_bidir) + ] + + for (raw_model, states, input_shapes) in models: + script_module = torch.jit.script(raw_model) + mod, params = from_pytorch(script_module, input_shapes) + + with torch.no_grad(): + pt_result = raw_model(inp.clone(), states) + + params[input_name] = inp.numpy() + + if isinstance(states, tuple): + states_np = tuple(st.numpy() for st in states) + elif isinstance(states, list) and isinstance(states[0], torch.Tensor): + states_np = [st.numpy() for st in states] + elif isinstance(states, list) and isinstance(states[0], tuple): + states_np = [tuple(st.numpy() for st in states[i]) + for i in range(len(states))] + elif isinstance(states, list) and isinstance(states[0], list): + states_np = [[tuple(st.numpy() for st in states) + for states in states[layer]] + for layer in range(num_layers)] + else: + assert False + + if isinstance(states_np, list): + params[states_name] = convert_list_to_vmobj(states_np) + else: + params[states_name] = states_np + + run_and_compare(mod, params, pt_result) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index d60ab9eeec5f..8e9928510220 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -526,13 +526,13 @@ def test_forward_maxpool2d(): input_data = torch.rand(input_shape).float() verify_model(torch.nn.MaxPool2d(kernel_size=[1, 1]).eval(), - input_data) + input_data) verify_model(torch.nn.MaxPool2d(kernel_size=[10, 10]).eval(), - input_data) + input_data) verify_model(torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2).eval(), - input_data) + input_data) def test_forward_maxpool1d(): torch.set_grad_enabled(False) @@ -540,13 +540,13 @@ def test_forward_maxpool1d(): input_data = torch.rand(input_shape).float() verify_model(torch.nn.MaxPool1d(kernel_size=1).eval(), - input_data) + input_data) verify_model(torch.nn.MaxPool1d(kernel_size=10).eval(), - input_data) - verify_model( torch.nn.MaxPool1d(kernel_size=4, + input_data) + verify_model(torch.nn.MaxPool1d(kernel_size=4, padding=2, stride=2).eval(), - input_data) + input_data) def test_forward_maxpool3d(): torch.set_grad_enabled(False) @@ -554,13 +554,13 @@ def test_forward_maxpool3d(): input_data = torch.rand(input_shape).float() verify_model(torch.nn.MaxPool3d(kernel_size=[1, 1, 1]).eval(), - input_data) + input_data) verify_model(torch.nn.MaxPool3d(kernel_size=[10, 10, 10]).eval(), - input_data) + input_data) verify_model(torch.nn.MaxPool3d(kernel_size=[4, 4, 4], padding=2, stride=2).eval(), - input_data) + input_data) def test_forward_split(): torch.set_grad_enabled(False) @@ -577,13 +577,13 @@ def forward(self, *args): input_data = torch.rand(input_shape).float() verify_model(Split(2, 0).float().eval(), - input_data=input_data) + input_data=input_data) verify_model(Split(3, 1).float().eval(), - input_data=input_data) + input_data=input_data) verify_model(Split(4, 1).float().eval(), - input_data=input_data) + input_data=input_data) verify_model(Split([2, 3, 5], 1).float().eval(), - input_data=input_data) + input_data=input_data) def test_forward_avgpool(): torch.set_grad_enabled(False) @@ -1363,3 +1363,8 @@ def forward(self, xs): # Test simple conditionals and loop test_control_flow() test_simple_rnn() + + # More complex recurrent models + from lstm_test import custom_lstm_test + + custom_lstm_test()