diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index e26cee26584b..832934417484 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -562,12 +562,14 @@ struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
 /*! \brief Attributes used for the padding operator */
 struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
   Array<Integer> pad_width;
+  runtime::Float pad_value = 0.0;
   tvm::String pad_mode;
 
-  TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") {
+  TVM_DECLARE_ATTRS(PadAttrs, "relax.attrs.PadAttrs") {
     TVM_ATTR_FIELD(pad_width).describe(
         "Number of values padded to the edges of each axis, "
         "in the format of (before_1, after_1, ..., before_N, after_N)");
+    TVM_ATTR_FIELD(pad_value).set_default(0.0).describe("The value to fill in padded area with");
     TVM_ATTR_FIELD(pad_mode)
         .set_default("constant")
         .describe(
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 4664ec549388..4c6d921db79b 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1438,7 +1438,7 @@ def pad(
     x: Tensor,
     pad: List[int],
     mode: str = "constant",
-    value: int = 0,
+    value: float = 0.0,
     name: str = "pad",
 ) -> Tensor:
     """
@@ -1454,7 +1454,7 @@ def pad(
     mod : str
         Padding mode to use, constant implies padded elements will use value argument.
-    value : int
+    value : float
         What to pad with in constant mode.
     name : str
         Name hint for this operator.
 
@@ -1464,7 +1464,7 @@ def pad(
     result : Tensor
         Padded output tensor.
     """
-    return wrap_nested(_op.nn.pad(x._expr, pad_width=pad, pad_value=value, pad_mode=mode), name)
+    return wrap_nested(_op.nn.pad(x._expr, pad_width=pad, pad_mode=mode, pad_value=value), name)
 
 
 def square(x: Tensor, name: str = "square") -> Tensor:
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 62d8b84321ce..5a1895cbc14f 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -20,7 +20,7 @@
 from tvm import DataType
 from tvm.tir import FloatImm
 
-from ...expr import Expr, const
+from ...expr import Expr
 from . import _ffi_api
 
 
@@ -513,7 +513,12 @@ def conv2d_transpose(
     )
 
 
-def pad(data, pad_width, pad_value=0, pad_mode="constant"):
+def pad(
+    data: Expr,
+    pad_width: Tuple[Tuple[int, int], ...],
+    pad_mode: Optional[str] = "constant",
+    pad_value: Optional[Union[float, Expr]] = 0.0,
+):
     r"""Padding
 
     This operator takes in a tensor and pads each axis by the specified
@@ -523,23 +528,24 @@ def pad(
     ----------
     data: relax.Expr
         The input data to the operator
-    pad_width: tuple of <tuple of two ints>, required
+    pad_width: Tuple[Tuple[int, int], ...], required
         Number of values padded to the edges of each axis, in the format
         of ((before_1, after_1), ..., (before_N, after_N))
-    pad_value: float
-        The value used for padding
-    pad_mode: 'constant', 'edge', 'reflect'
+    pad_mode: Optional[str]
+        'constant', 'edge', or 'reflect'
         'constant' pads with constant_value pad_value
         'edge' pads using the edge values of the input array
         'reflect' pads by reflecting values with respect to the edge
+        Default is 'constant'
+    pad_value: Optional[Union[float, Expr]]
+        The value used for padding. Default is 0.
+
     Returns
     -------
     result : relax.Expr
         The computed result.
     """
-    if not isinstance(pad_value, Expr):
-        pad_value = const(pad_value)
-    return _ffi_api.pad(data, pad_width, pad_value, pad_mode)
+    return _ffi_api.pad(data, pad_width, pad_mode, pad_value)
 
 
 def max_pool1d(
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 8317d4504e1e..d9fb4701f7e9 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -231,7 +231,7 @@ def _nn_pad(bb: BlockBuilder, call: Call) -> Expr:
         call.args[0],
         pad_before=pad_before,
         pad_after=pad_after,
-        pad_value=float(call.args[1].data.numpy()),
+        pad_value=call.attrs.pad_value,
         primfunc_name_hint="pad",
     )
 
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index 7eccf47e4b06..526b816d0945 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -136,12 +136,13 @@ TVM_REGISTER_OP("relax.nn.log_softmax")
 /* relax.nn.pad */
 TVM_REGISTER_NODE_TYPE(PadAttrs);
 
-Expr pad(Expr data, Array<Integer> pad_width, Expr pad_value, String pad_mode) {
+Expr pad(Expr data, Array<Integer> pad_width, String pad_mode, runtime::Float pad_value) {
   auto attrs = make_object<PadAttrs>();
   attrs->pad_width = std::move(pad_width);
   attrs->pad_mode = std::move(pad_mode);
+  attrs->pad_value = pad_value;
   static const Op& op = Op::Get("relax.nn.pad");
-  return Call(op, {data, pad_value}, Attrs(attrs), {});
+  return Call(op, {data}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relax.op.nn.pad").set_body_typed(pad);
@@ -171,9 +172,8 @@ StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) {
 }
 
 TVM_REGISTER_OP("relax.nn.pad")
-    .set_num_inputs(2)
+    .set_num_inputs(1)
     .add_argument("data", "Tensor", "The input tensor.")
-    .add_argument("pad_value", "Tensor", "The value to fill in padded area with.")
     .set_attrs_type<PadAttrs>()
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPad)
     .set_attr<Bool>("FPurity", Bool(true));
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 12436cf8023f..d83d0567e482 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -3621,7 +3621,7 @@ def main(x: R.Tensor((2, 128, 28), "float32")) -> R.Tensor((2, 130, 30), "float32"):
     class Expected:
         @R.function
         def main(
-            x: R.Tensor((2, 128, 28), dtype="float32")
+            x: R.Tensor((2, 128, 28), dtype="float32"),
         ) -> R.Tensor((2, 130, 30), dtype="float32"):
             gv = R.call_tir(Expected.pad, (x), out_sinfo=R.Tensor((2, 130, 30), dtype="float32"))
             return gv