diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 42adefa314293..43704e39953c1 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -1044,13 +1044,10 @@ struct UpSampling3DAttrs : public tvm::AttrsNode<UpSampling3DAttrs> {
 
 /*! \brief Attributes used for the padding operator */
 struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
-  double pad_value;
   Array<Array<Integer>> pad_width;
   std::string pad_mode;
 
   TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") {
-    TVM_ATTR_FIELD(pad_value).set_default(0.0).describe(
-        "The value used for padding when mode is 'constant'.");
     TVM_ATTR_FIELD(pad_width).describe(
         "Number of values padded to the edges of each axis, "
         "in the format of ((before_1, after_1), ..., (before_N, after_N))");
diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py
index 6234c944d4e4a..9152b50e76869 100644
--- a/python/tvm/relay/op/contrib/arm_compute_lib.py
+++ b/python/tvm/relay/op/contrib/arm_compute_lib.py
@@ -17,16 +17,15 @@
 # pylint: disable=invalid-name, unused-argument
 """Arm Compute Library supported operators."""
 import tvm
-
 from tvm import relay
 from tvm._ffi import register_func
-from tvm.relay.expr import const
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.expr import const
 
-from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr
-from .register import register_pattern_table
+from ...dataflow_pattern import is_constant, is_expr, is_op, wildcard
 from ..strategy.generic import is_depthwise_conv2d
+from .register import register_pattern_table
 
 
 def is_arm_compute_runtime_enabled():
@@ -140,7 +139,7 @@ def conv_pattern():
         pattern : dataflow_pattern.AltPattern
             Denotes the convolution pattern.
         """
-        pattern = is_op("nn.pad")(wildcard()) | wildcard()
+        pattern = is_op("nn.pad")(wildcard(), wildcard()) | wildcard()
         pattern = is_op("nn.conv2d")(pattern, is_constant())
         pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
         pattern = pattern.optional(is_op("nn.relu"))
@@ -154,7 +153,7 @@ def qnn_conv_pattern():
         pattern : dataflow_pattern.AltPattern
             Denotes the convolution pattern.
         """
-        pattern = is_op("nn.pad")(wildcard()) | wildcard()
+        pattern = is_op("nn.pad")(wildcard(), wildcard()) | wildcard()
         pattern = is_op("qnn.conv2d")(
             pattern, is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
         )
diff --git a/python/tvm/relay/op/contrib/ethosn.py b/python/tvm/relay/op/contrib/ethosn.py
index 2c63d63a36ef7..2cd787a74dc12 100644
--- a/python/tvm/relay/op/contrib/ethosn.py
+++ b/python/tvm/relay/op/contrib/ethosn.py
@@ -22,10 +22,10 @@
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
 
-from ...dataflow_pattern import wildcard, is_op, is_constant
 from ... import qnn as _qnn
-from .register import register_pattern_table
+from ...dataflow_pattern import is_constant, is_op, wildcard
 from . import _ethosn as support
+from .register import register_pattern_table
 
 
 class Available(Enum):
@@ -82,7 +82,7 @@ def pattern_table():
     """Get the Ethos-N compiler pattern table."""
 
     def qnn_conv_pattern():
-        pattern = is_op("nn.pad")(wildcard()) | wildcard()
+        pattern = is_op("nn.pad")(wildcard(), wildcard()) | wildcard()
        pattern = is_op("qnn.conv2d")(
             pattern, is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
         )
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 5acdafef0ed96..e5ca7e4a4717b 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -18,10 +18,10 @@
 """Neural network operations."""
 from tvm.relay import expr
 
-from . import _make
+from ...expr import Constant, Expr, const
 from ..dyn.nn import _make as _dyn_make
+from . import _make
 from .utils import get_pad_tuple1d, get_pad_tuple2d, get_pad_tuple3d
-from ...expr import const, Expr, Constant
 
 
 def conv1d(
@@ -1606,15 +1606,11 @@ def pad(data, pad_width, pad_value=0, pad_mode="constant"):
     result : tvm.relay.Expr
         The computed result.
     """
-    if isinstance(pad_value, Constant):
-        pad_value = pad_value.data.asnumpy().item()
     if isinstance(pad_width, Constant):
         pad_width = [list(i) for i in pad_width.data.asnumpy()]
-    if isinstance(pad_width, Expr) or (isinstance(pad_value, Expr)):
-        if not isinstance(pad_width, Expr):
-            pad_width = const(list(pad_width))
-        if not isinstance(pad_value, Expr):
-            pad_value = const(pad_value)
+    if not isinstance(pad_value, Expr):
+        pad_value = const(pad_value)
+    if isinstance(pad_width, Expr):
         return _dyn_make.pad(data, pad_width, pad_value, pad_mode)
     return _make.pad(data, pad_width, pad_value, pad_mode)
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index e5a20abd76242..a3a182515b6d3 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -58,7 +58,7 @@ Expr MakeAutoSchedulerLayoutTransform(Expr data, String src_layout, String dst_layout);
 
 Expr MakeOnes(Array<Integer> shape, DataType dtype);
 
-Expr MakePad(Expr data, Array<Array<Integer>> pad_width, double pad_value, String pad_mode);
+Expr MakePad(Expr data, Array<Array<Integer>> pad_width, Expr pad_value, String pad_mode);
 
 Expr MakeReduce(Expr data, Array<Integer> axis, bool keepdims, bool exclude, String op_name);
diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc
index c6b987eb42aa7..42a1ea186dbf6 100644
--- a/src/relay/op/nn/pad.cc
+++ b/src/relay/op/nn/pad.cc
@@ -25,6 +25,7 @@
 #include <tvm/relay/op.h>
 #include <tvm/tir/data_layout.h>
 #include <tvm/tir/op.h>
+#include <tvm/topi/elemwise.h>
 #include <tvm/topi/nn.h>
 
 #include <vector>
@@ -44,7 +45,7 @@ Array<Array<Layout>> PadInferCorrectLayout(const Attrs& attrs, const Array<Layout>& new_in_layouts,
   // NOTE: Discard "const" qualifier here.
   PadAttrs* params = const_cast<PadAttrs*>(attrs.as<PadAttrs>());
 
-  Layout ret;
+  Layout ret_data;
   // If new_in_layouts are defined, this code tries to modify the layout.
   bool is_layout_modified = new_in_layouts.defined();
   if (new_in_layouts.defined()) {
@@ -55,8 +56,8 @@ Array<Array<Layout>> PadInferCorrectLayout(const Attrs& attrs, const Array<Layout>& new_in_layouts,
     std::map<std::string, Array<Integer>> axis_pad_width;
     int index_counter = 0;
-    ICHECK_EQ(new_in_layouts.size(), 1);
-    ICHECK_EQ(old_in_layouts.size(), 1);
+    ICHECK_EQ(new_in_layouts.size(), 2);
+    ICHECK_EQ(old_in_layouts.size(), 2);
     for (auto iter_var : old_in_layouts[0]->axes) {
       const auto& old_layout_axis = LayoutAxis::Get(iter_var);
       axis_pad_width.emplace(old_layout_axis.name(), params->pad_width[index_counter]);
@@ -95,26 +96,29 @@ Array<Array<Layout>> PadInferCorrectLayout(const Attrs& attrs, const Array<Layout>& new_in_layouts,
       params->pad_width = new_pad_width;
     }
   }
 
   if (!is_layout_modified) {
     if (old_in_layouts.defined()) {
-      ICHECK_EQ(old_in_layouts.size(), 1);
-      ret = old_in_layouts[0];
+      ICHECK_EQ(old_in_layouts.size(), 2);
+      ret_data = old_in_layouts[0];
     } else {
-      ret = Layout::Undef();
+      ret_data = Layout::Undef();
     }
   }
 
-  return Array<Array<Layout>>{{ret}, {ret}};
+  // The pad value is always a scalar
+  Layout ret_pad_value = Layout("1");
+  return Array<Array<Layout>>{{ret_data, ret_pad_value}, {ret_data}};
 }
 
 bool PadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
             const TypeReporter& reporter) {
-  ICHECK_EQ(types.size(), 2);
+  // types = [pad_data_type, pad_value_type, ret_type]
+  ICHECK_EQ(types.size(), 3);
   const auto* data = types[0].as<TensorTypeNode>();
   if (data == nullptr) return false;
@@ -151,7 +155,7 @@ bool PadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     }
   }
 
-  reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), data->dtype));
+  reporter->Assign(types[2], TensorType(Array<IndexExpr>(oshape), data->dtype));
   return true;
 }
 
@@ -170,20 +174,19 @@ Array<te::Tensor> PadCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
   for (size_t i = 0; i < pad_width.size(); ++i) {
     pad_after.push_back(pad_width[i][1]);
   }
-  const auto* out_ttype = out_type.as<TensorTypeNode>();
-  return Array<te::Tensor>{topi::pad(inputs[0], pad_before, pad_after,
-                                     tvm::tir::make_const(out_ttype->dtype, param->pad_value),
-                                     "T_pad", topi::kElementWise, param->pad_mode)};
+  te::Tensor cast_pad_value = topi::cast(inputs[1], inputs[0]->dtype);
+  const PrimExpr& pad_value = cast_pad_value(Array<PrimExpr>());
+  return Array<te::Tensor>{topi::pad(inputs[0], pad_before, pad_after, pad_value, "T_pad",
+                                     topi::kElementWise, param->pad_mode)};
 }
 
 // Handler to create a call to the padding op used by front-end FFI
-Expr MakePad(Expr data, Array<Array<Integer>> pad_width, double pad_value, String pad_mode) {
+Expr MakePad(Expr data, Array<Array<Integer>> pad_width, Expr pad_value, String pad_mode) {
   auto attrs = make_object<PadAttrs>();
-  attrs->pad_value = pad_value;
   attrs->pad_width = std::move(pad_width);
   attrs->pad_mode = std::move(pad_mode);
   static const Op& op = Op::Get("nn.pad");
-  return Call(op, {data}, Attrs(attrs), {});
+  return Call(op, {data, pad_value}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op.nn._make.pad").set_body_typed(MakePad);
@@ -193,8 +196,9 @@ RELAY_REGISTER_OP("nn.pad")
 
 )code" TVM_ADD_FILELINE)
     .set_attrs_type<PadAttrs>()
-    .set_num_inputs(1)
+    .set_num_inputs(2)
     .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("pad_val", "Tensor", "The value to fill the padded area with")
     .set_support_level(2)
     .add_type_rel("Pad", PadRel)
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", PadInferCorrectLayout)
diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc
index 21335ec2fb341..1a81d10c25832 100644
--- a/src/relay/qnn/op/convolution.cc
+++ b/src/relay/qnn/op/convolution.cc
@@ -234,8 +234,7 @@ Expr Conv2DPadInput(const Expr& data, const Expr& input_zero_point, const Conv2DAttrs* param) {
     } else {
       LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
     }
-    auto pad_value = GetScalarFromConstant<int>(input_zero_point);
-    padded_data = Pad(data, pad_width, pad_value, "constant");
+    padded_data = Pad(data, pad_width, input_zero_point, "constant");
   }
   return padded_data;
 }
diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc
index 0590b41550ce5..8bfe1d83bd9ef 100644
--- a/src/relay/transforms/dynamic_to_static.cc
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -179,7 +179,9 @@ class DynamicToStaticMutator : public MixedModeMutator {
       const PadAttrs* param = call_node->attrs.as<PadAttrs>();
       ICHECK(param);
-      return MakePad(call_node->args[0], ToMatrix(pad_width->data), ToScalar(pad_fill->data),
+
+      Expr pad_value = args[2];
+      return MakePad(call_node->args[0], ToMatrix(pad_width->data), pad_value,
                      param->pad_mode);
     }
     return Expr(nullptr);
diff --git a/src/relay/transforms/fold_explicit_padding.cc b/src/relay/transforms/fold_explicit_padding.cc
index d959e5b75e40d..b600953e0765d 100644
--- a/src/relay/transforms/fold_explicit_padding.cc
+++ b/src/relay/transforms/fold_explicit_padding.cc
@@ -45,7 +45,7 @@ class SimplifyConvPad {
   SimplifyConvPad() {
     x_ = IsWildcard();
     w_ = IsWildcard();
-    pad_ = IsOp("nn.pad")({x_});
+    pad_ = IsOp("nn.pad")({x_, IsWildcard()});
     conv1d_ = IsOp("nn.conv1d");
     conv2d_ = IsOp("nn.conv2d");
     conv3d_ = IsOp("nn.conv3d");
@@ -119,7 +119,11 @@ class SimplifyConvPad {
     ICHECK(pad_node);
     const PadAttrs* param = pad_node->attrs.as<PadAttrs>();
     ICHECK(param);
-    if (param->pad_mode == "constant" && param->pad_value == 0.0) {
+    Array<Expr> args = pad_node->args;
+
+    // Possibly perform more optimizations if the pad_value is 0
+    const ConstantNode* pad_value = args[1].as<ConstantNode>();
+    if (param->pad_mode == "constant" && pad_value && ToScalar(pad_value->data) == 0.0) {
       Attrs attrs;
       if (node_map.count(conv1d_)) {
         attrs = GetAttrs(param, call_node->attrs.as<Conv1DAttrs>());
diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index 8d9f723dffea1..975c07e563be7 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -676,7 +676,7 @@ static inline Expr AvgPool2D(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
 }
 
-static inline Expr Pad(Expr data, Array<Array<IndexExpr>> pad_width, double pad_value,
+static inline Expr Pad(Expr data, Array<Array<IndexExpr>> pad_width, Expr pad_value,
                        std::string pad_mode) {
   Array<Array<Integer>> pad_width_int;
   for (size_t i = 0; i < pad_width.size(); ++i) {
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index c5843758c3d2c..9640868458370 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -18,15 +18,13 @@
 """
 import numpy as np
 import tvm
-from tvm import te
-from tvm import autotvm
-from tvm import relay
+import tvm.testing
+import tvm.topi.testing
+from tvm import autotvm, relay, te
+from tvm.contrib import utils
 from tvm.relay import transform
 from tvm.relay.testing import run_infer_type
-from tvm.contrib import utils
-import tvm.topi.testing
 from tvm.topi.cuda.conv3d_winograd import _infer_tile_size
-import tvm.testing
 
 
 @tvm.testing.uses_gpu
@@ -1197,6 +1195,30 @@ def test_pad_infer_type():
     yy = run_infer_type(y)
     assert yy.checked_type == relay.TensorType((n + (-2), c + (-4), h + (-2), w + 8), "float32")
 
+    # dealing with dynamic vals
+    n, c, h, w = te.size_var("n"), 2, 3, te.size_var("w")
+    t = relay.var("t", relay.TensorType((n, c, h, w), "float32"))
+    y = relay.nn.pad(
+        t, ((1, 1), (2, 2), (3, 3), (4, 4)), pad_value=relay.var("pad_value", "float32")
+    )
+    yy = run_infer_type(y)
+    assert yy.checked_type == relay.TensorType((n + 2, 6, 9, w + 8), "float32")
+
+
+def _get_numpy_pad(dshape, data, pad, pad_value=0):
+    mod_pad = []
+    for axis, (pad_x, pad_y) in enumerate(pad):
+        indices = range(dshape[axis])
+        if pad_x < 0:
+            indices = indices[abs(pad_x) :]
+            pad_x = 0
+        if pad_y < 0:
+            indices = indices[:pad_y]
+            pad_y = 0
+        data = np.take(data, indices, axis)
+        mod_pad.append((pad_x, pad_y))
+    return np.pad(data, tuple(mod_pad), "constant", constant_values=pad_value)
+
 
 @tvm.testing.uses_gpu
 def test_pad_run():
@@ -1209,20 +1231,7 @@ def _test_run(dtype):
         y = relay.nn.pad(x, pad)
         func = relay.Function([x], y)
         data = np.random.uniform(size=dshape).astype(dtype)
-        mod_pad = []
-        mod_data = data
-        for axis, (pad_x, pad_y) in enumerate(pad):
-            indices = range(dshape[axis])
-            if pad_x < 0:
-                indices = indices[abs(pad_x) :]
-                pad_x = 0
-            if pad_y < 0:
-                indices = indices[:pad_y]
-                pad_y = 0
-            mod_data = np.take(mod_data, indices, axis)
-            mod_pad.append((pad_x, pad_y))
-
-        ref_res = np.pad(mod_data, tuple(mod_pad), "constant")
+        ref_res = _get_numpy_pad(dshape, data, pad)
         for target, dev in tvm.testing.enabled_targets():
             intrp1 = relay.create_executor("graph", device=dev, target=target)
             op_res1 = intrp1.evaluate(func)(data)
@@ -1232,6 +1241,30 @@ def _test_run(dtype):
     _test_run("float32")
     _test_run("int32")
 
 
+@tvm.testing.uses_gpu
+def test_pad_run_dynamic_pad_value():
+    def _test_run(dtype):
+        dshape = (4, 6, 3, 5)
+        pad = ((-1, -1), (2, -2), (0, -2), (4, 4))
+
+        data = relay.var("data", shape=dshape, dtype=dtype)
+        pad_value = relay.var("pad_value", dtype)
+        pad_data = relay.nn.pad(data, pad, pad_value=pad_value)
+        f = relay.Function([data, pad_value], pad_data)
+
+        data_arr = np.random.uniform(-10, 10, size=dshape).astype(dtype)
+        pad_value_arr = 2.0
+        ref_res = _get_numpy_pad(dshape, data_arr, pad, pad_value=pad_value_arr)
+
+        for target, dev in tvm.testing.enabled_targets():
+            intrp = relay.create_executor(kind="graph", device=dev, target=target)
+            result = intrp.evaluate(f)(data_arr, pad_value_arr)
+            tvm.testing.assert_allclose(result.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+
+    _test_run("float32")
+    _test_run("int32")
+
+
 @tvm.testing.uses_gpu
 def test_lrn():
     n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
@@ -1766,6 +1799,7 @@ def _test_correlation(
     test_flatten_infer_type()
     test_pad_infer_type()
    test_pad_run()
+    test_pad_run_dynamic_pad_value()
     test_conv3d_transpose_infer_type()
     test_conv3d_transpose_ncdhw_run()
     test_conv2d_transpose_infer_type()
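
A minimal usage sketch (not part of the patch) of the reworked operator, assuming the changes above are applied: relay.nn.pad now accepts a relay expression for pad_value, so the fill value is a runtime input instead of a compile-time attribute. The tensor shape, variable names, and the fill value 1.5 below are illustrative only.

    import numpy as np
    from tvm import relay

    # Pad a (2, 3) float32 tensor by one row/column on each side; the fill
    # value is supplied as a scalar input when the function is executed.
    data = relay.var("data", shape=(2, 3), dtype="float32")
    pad_value = relay.var("pad_value", shape=(), dtype="float32")
    out = relay.nn.pad(data, ((1, 1), (1, 1)), pad_value=pad_value)
    func = relay.Function([data, pad_value], out)

    data_np = np.arange(6, dtype="float32").reshape(2, 3)
    result = relay.create_executor("graph").evaluate(func)(data_np, np.float32(1.5))
    print(result.asnumpy())  # border entries hold 1.5, the interior holds data_np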