diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 9b4caa1adefa..7ceb272300e2 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -206,10 +206,24 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos): return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op) @reg.register_legalize("nn.conv2d") -def legalize_conv2d(attrs, inputs, arg_dtypes): - """Legalize conv2d""" - from ... import op - return topi.nn.conv2d_legalize(attrs, inputs, arg_dtypes, op) +def legalize_conv2d(attrs, inputs, types): + """Legalize conv2d op. + + Parameters + ---------- + attrs : tvm.attrs.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + return topi.nn.conv2d_legalize(attrs, inputs, types) reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/src/relay/pass/legalize.cc b/src/relay/pass/legalize.cc index c041cb9c668c..0079dabb88b3 100644 --- a/src/relay/pass/legalize.cc +++ b/src/relay/pass/legalize.cc @@ -42,11 +42,17 @@ Expr Legalizer(const Call& ref_call, const Array& new_args, const NodeRef& Expr new_e; bool modified = false; if (fop_legalize.count(op)) { - tvm::Array arg_types; + // Collect input and output dtypes to pass on to Legalize API. + tvm::Array types; for (auto& expr : ref_call->args) { - arg_types.push_back(expr->checked_type()); + types.push_back(expr->checked_type()); } - Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, arg_types); + types.push_back(ref_call->checked_type()); + + // Transform the op by calling the registered legalize function. + Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, types); + + // Check if the transformation succeeded. If not, revert back to the original ref_call->op. if (legalized_value.defined()) { new_e = legalized_value; modified = true; diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 1e5944014934..e42be2af9b78 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -74,12 +74,12 @@ Expr DequantizeLower(const Expr& input_tensor, Expr DequantizeLegalize(const Attrs& attrs, const Array& new_args, - const Array& arg_types) { + const Array& types) { CHECK_EQ(new_args.size(), 1); auto& data = new_args[0]; const auto* dequantize_attrs = attrs.as(); CHECK(dequantize_attrs != nullptr); - CHECK_EQ(arg_types.size(), 1); + CHECK_EQ(types.size(), 2); return DequantizeLower(data, dequantize_attrs); } diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 2f494008cc46..675cd4c5a700 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -85,13 +85,13 @@ Expr QuantizeLower(const Expr& input_tensor, Expr QuantizeLegalize(const Attrs& attrs, const Array& new_args, - const Array& arg_types) { + const Array& types) { CHECK_EQ(new_args.size(), 1); auto& data = new_args[0]; const auto* quantize_attrs = attrs.as(); CHECK(quantize_attrs != nullptr); - CHECK_EQ(arg_types.size(), 1); + CHECK_EQ(types.size(), 2); return QuantizeLower(data, quantize_attrs); } diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index e3052b71c4ca..ebc537e1ac90 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -109,7 +109,7 @@ std::pair GetFixedPointMultiplierShift(double double_multiplie * 7) Cast to the out_dtype. */ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param, - const Array& input_shape) { + const Array& input_shape, const DataType& out_dtype) { double double_multiplier = param->input_scale / param->output_scale; // Choose high precision datatype to be int64. This is for avoiding overflow @@ -173,10 +173,10 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param, auto shifted_int64_t = Add(output_zp, scaled_int64_t); // 7) Clip to the out_dtype min/max. - auto q_min = GetQmin(param->out_dtype); - auto q_max = GetQmax(param->out_dtype); + auto q_min = GetQmin(out_dtype); + auto q_max = GetQmax(out_dtype); auto clipped_t = Clip(shifted_int64_t, q_min, q_max); - return Cast(clipped_t, param->out_dtype); + return Cast(clipped_t, out_dtype); } /* @@ -193,25 +193,32 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param, * Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input) */ Expr RequantizeLegalize(const Attrs& attrs, const Array& new_args, - const Array& arg_types) { + const Array& types) { CHECK_EQ(new_args.size(), 1); auto& quantized_data = new_args[0]; const auto* param = attrs.as(); CHECK(param != nullptr); // Find input shape. - CHECK_EQ(arg_types.size(), 1); - auto input_dtype = arg_types[0]; - auto input_tensor_type = input_dtype.as(); - CHECK(input_tensor_type != nullptr) << "Type information missing." - << " Please run infer_type pass."; - Array input_shape = input_tensor_type->shape; + CHECK_EQ(types.size(), 2); + auto in_type = types[0]; + auto in_tensor_type = in_type.as(); + CHECK(in_tensor_type != nullptr) << "Type information missing." + << " Please run infer_type pass."; + Array input_shape = in_tensor_type->shape; + + // Find the output dtype. + auto out_type = types[1]; + auto out_tensor_type = out_type.as(); + CHECK(out_tensor_type != nullptr) << "Type information missing." + << " Please run infer_type pass."; + auto out_dtype = out_tensor_type->dtype; // Check rounding validity. CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST") << "QNN requantize supports two rounding modes - UPWARD and " << "TONEAREST"; - return RequantizeLower(quantized_data, param, input_shape); + return RequantizeLower(quantized_data, param, input_shape, out_dtype); } /* @@ -261,7 +268,7 @@ The requantize operator converts one quantized tensor to another quantized tensor. For the output tensor, we are provided with output scale and zero point. The computation looks like this -Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input) +Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input) )code" TVM_ADD_FILELINE) .set_attrs_type_key("relay.attrs.RequantizeAttrs") diff --git a/tests/python/relay/test_pass_legalize.py b/tests/python/relay/test_pass_legalize.py index 52deeb58ca35..393c86282be6 100644 --- a/tests/python/relay/test_pass_legalize.py +++ b/tests/python/relay/test_pass_legalize.py @@ -47,7 +47,7 @@ def before(): return y @register_legalize("nn.conv2d", level=100) - def legalize_conv2d(attrs, inputs, arg_types): + def legalize_conv2d(attrs, inputs, types): data, weight = inputs weight = relay.multiply(weight, relay.const(2.0, "float32")) return relay.nn.conv2d(data, weight, **attrs) @@ -80,7 +80,7 @@ def before(): called = [False] @register_legalize("nn.global_max_pool2d", level=101) - def legalize_conv2d(attrs, inputs, arg_types): + def legalize_conv2d(attrs, inputs, types): called[0] = True return None @@ -103,12 +103,13 @@ def before(): return func @register_legalize("concatenate", level=100) - def legalize_concatenate(attrs, inputs, arg_types): + def legalize_concatenate(attrs, inputs, types): # Check that the correct multi-input case is handled. assert len(inputs) == 1 assert isinstance(inputs[0], tvm.relay.expr.Tuple) - assert len(arg_types) == 1 - assert isinstance(arg_types[0], tvm.relay.ty.TupleType) + assert len(types) == 2 + assert isinstance(types[0], tvm.relay.ty.TupleType) + assert isinstance(types[1], tvm.relay.ty.TensorType) return None def expected(): @@ -153,9 +154,9 @@ def before(): return func @register_legalize("nn.conv2d", level=101) - def legalize_conv2d(attrs, inputs, arg_types): + def legalize_conv2d(attrs, inputs, types): from topi.arm_cpu.conv2d import _conv2d_legalize - return _conv2d_legalize(attrs, inputs, arg_types, tvm.relay.op) + return _conv2d_legalize(attrs, inputs, types) a = before() b = run_opt_pass(a, transform.Legalize()) diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index e4b65fd2e718..fbb85d558d89 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -18,10 +18,11 @@ """Conv2D schedule for ARM CPU""" from __future__ import absolute_import as _abs -import warnings +import logging import tvm from tvm import autotvm +from tvm import relay import tvm.contrib.nnpack from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform, \ @@ -35,6 +36,8 @@ from ..nn.util import get_const_int, get_pad_tuple from ..nn.winograd_util import winograd_transform_matrices +logger = logging.getLogger('topi') + @autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct']) def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): """TOPI compute callback for conv2d @@ -679,7 +682,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): if layout != 'NCHW': return None if dilation != (1, 1): - warnings.warn("Does not support weight pre-transform for dilated convolution.") + logger.warning("Does not support weight pre-transform for dilated convolution.") return None data, kernel = tinfos[0:2] @@ -794,31 +797,46 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): return None @conv2d_legalize.register("arm_cpu") -def _conv2d_legalize(attrs, inputs, arg_types, F): - if F.__name__ != 'tvm.relay.op': - return None +def _conv2d_legalize(attrs, inputs, arg_types): + """Legalizes Conv2D op. + + Parameters + ---------- + attrs : tvm.attrs.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + if attrs['data_layout'] == 'NHWC': data, kernel = inputs if attrs['kernel_layout'] == 'HWIO': # Handle HWIO layout. This is common in TF graph. - kernel = F.transpose(kernel, axes=(3, 2, 0, 1)) + kernel = relay.transpose(kernel, axes=(3, 2, 0, 1)) elif attrs['kernel_layout'] == 'HWOI': # Handle HWOI layout. This is common in TF depthwise conv2d graph. - kernel = F.transpose(kernel, axes=(2, 3, 0, 1)) + kernel = relay.transpose(kernel, axes=(2, 3, 0, 1)) elif attrs['kernel_layout'] != 'OIHW': return None - warnings.warn("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to " - + "fallback to NCHW. This can result in performance degradation.") + logger.warning("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to " + + "fallback to NCHW. This can result in performance degradation.") # Set new attrs for the tranposed conv. new_attrs = {k: attrs[k] for k in attrs.keys()} new_attrs['data_layout'] = 'NCHW' new_attrs['kernel_layout'] = 'OIHW' # Convert from NHWC to NCHW. - data = F.transpose(data, axes=(0, 3, 1, 2)) - conv = F.nn.conv2d(data, kernel, **new_attrs) + data = relay.transpose(data, axes=(0, 3, 1, 2)) + conv = relay.nn.conv2d(data, kernel, **new_attrs) # Convert back to original NHWC layout. - out = F.transpose(conv, axes=(0, 2, 3, 1)) + out = relay.transpose(conv, axes=(0, 2, 3, 1)) return out return None diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index e7ab7ba99990..05580c894b53 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -72,22 +72,22 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N @tvm.target.generic_func -def conv2d_legalize(attrs, inputs, arg_dtypes, F): +def conv2d_legalize(attrs, inputs, types): """Legalizes Conv2D op. + Parameters ---------- - attrs : nnvm.top.AttrDict or tvm.attrs.Attrs + attrs : tvm.attrs.Attrs Attributes of current convolution inputs : list of tvm.relay.Expr - The args of the Relay expr to be legalized. - arg_dtypes : list of types - List of types of input arguments - F: symbol - The context, can be either nnvm.sym or relay.op - Note - ---- - Unlike other TOPI functions, this function operates on both graph level and operator level, - so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay. + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr """ # not to change by default return None