Merged
4 changes: 3 additions & 1 deletion include/tvm/relax/attrs/nn.h
@@ -562,12 +562,14 @@ struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
/*! \brief Attributes used for the padding operator */
struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
Array<Integer> pad_width;
runtime::Float pad_value = 0.0;
tvm::String pad_mode;

TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") {
TVM_DECLARE_ATTRS(PadAttrs, "relax.attrs.PadAttrs") {
TVM_ATTR_FIELD(pad_width).describe(
"Number of values padded to the edges of each axis, "
"in the format of (before_1, after_1, ..., before_N, after_N)");
TVM_ATTR_FIELD(pad_value).set_default(0.0).describe("The value to fill in padded area with");
TVM_ATTR_FIELD(pad_mode)
.set_default("constant")
.describe(
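Note: with this header change, pad_value becomes a field on PadAttrs instead of a second tensor input to the call. A minimal sketch of what that looks like from Python, assuming a TVM build that includes this patch (variable names and shapes are illustrative):

    from tvm import relax

    # Hypothetical 2-D input; pad_width uses the flat
    # (before_1, after_1, before_2, after_2) layout described in the attrs.
    x = relax.Var("x", relax.TensorStructInfo((2, 4), "float32"))
    call = relax.op.nn.pad(x, pad_width=(0, 0, 1, 1), pad_mode="constant", pad_value=1.5)

    print(call.attrs.pad_value)  # 1.5, now carried on PadAttrs
    print(len(call.args))        # 1, only the data tensor remains an input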
6 changes: 3 additions & 3 deletions python/tvm/relax/frontend/nn/op.py
@@ -1438,7 +1438,7 @@ def pad(
x: Tensor,
pad: List[int],
mode: str = "constant",
value: int = 0,
value: float = 0.0,
name: str = "pad",
) -> Tensor:
"""
@@ -1454,7 +1454,7 @@ def pad(
mod : str
Padding mode to use, constant implies padded elements will use
value argument.
value : int
value : float
What to pad with in constant mode.
name : str
Name hint for this operator.
@@ -1464,7 +1464,7 @@
result : Tensor
Padded output tensor.
"""
return wrap_nested(_op.nn.pad(x._expr, pad_width=pad, pad_value=value, pad_mode=mode), name)
return wrap_nested(_op.nn.pad(x._expr, pad_width=pad, pad_mode=mode, pad_value=value), name)


def square(x: Tensor, name: str = "square") -> Tensor:
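For the nn frontend, the visible change is that value is typed as a float (defaulting to 0.0) and forwarded to pad_value by keyword. A hedged usage sketch; the module and shapes below are illustrative, not from the patch:

    from tvm.relax.frontend import nn

    class PadBlock(nn.Module):
        def forward(self, x: nn.Tensor) -> nn.Tensor:
            # x assumed 2-D; `pad` lists (before, after) widths for every axis
            return nn.op.pad(x, pad=[0, 0, 1, 1], mode="constant", value=0.5)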
24 changes: 15 additions & 9 deletions python/tvm/relax/op/nn/nn.py
@@ -20,7 +20,7 @@
from tvm import DataType
from tvm.tir import FloatImm

from ...expr import Expr, const
from ...expr import Expr
from . import _ffi_api


@@ -513,7 +513,12 @@ def conv2d_transpose(
)


def pad(data, pad_width, pad_value=0, pad_mode="constant"):
def pad(
data: Expr,
pad_width: Tuple[Tuple[int, int], ...],
pad_mode: Optional[str] = "constant",
pad_value: Optional[Union[float, Expr]] = 0.0,
):
r"""Padding

This operator takes in a tensor and pads each axis by the specified
@@ -523,23 +528,24 @@ def pad(data, pad_width, pad_value=0, pad_mode="constant"):
----------
data: relax.Expr
The input data to the operator
pad_width: tuple of <tuple of <int>>, required
pad_width: Tuple[Tuple[int, int], ...], required
Number of values padded to the edges of each axis, in the format
of ((before_1, after_1), ..., (before_N, after_N))
pad_value: float
The value used for padding
pad_mode: 'constant', 'edge', 'reflect'
pad_mode: Optional[str]
'constant', 'edge', or 'reflect'
'constant' pads with constant_value pad_value
'edge' pads using the edge values of the input array
'reflect' pads by reflecting values with respect to the edge
Default is 'constant'
pad_value: Optional[Union[float, Expr]]
The value used for padding. Default is 0.

Returns
-------
result : relax.Expr
The computed result.
"""
if not isinstance(pad_value, Expr):
pad_value = const(pad_value)
return _ffi_api.pad(data, pad_width, pad_value, pad_mode)
return _ffi_api.pad(data, pad_width, pad_mode, pad_value)


def max_pool1d(
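One consequence of the new signature worth noting: the positional order of the trailing parameters changed (pad_mode now precedes pad_value), so positional call sites need updating; keyword arguments sidestep the reordering entirely. A hedged sketch (names and shapes are illustrative):

    from tvm import relax

    data = relax.Var("data", relax.TensorStructInfo((1, 3, 8, 8), "float32"))

    # Positional, new order: data, pad_width, pad_mode, pad_value
    y1 = relax.op.nn.pad(data, (0, 0, 0, 0, 2, 2, 2, 2), "constant", -1.0)

    # Keywords make the order irrelevant; pad_mode keeps its "constant" default
    y2 = relax.op.nn.pad(data, pad_width=(0, 0, 0, 0, 2, 2, 2, 2), pad_value=-1.0)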
2 changes: 1 addition & 1 deletion python/tvm/relax/transform/legalize_ops/nn.py
@@ -231,7 +231,7 @@ def _nn_pad(bb: BlockBuilder, call: Call) -> Expr:
call.args[0],
pad_before=pad_before,
pad_after=pad_after,
pad_value=float(call.args[1].data.numpy()),
pad_value=call.attrs.pad_value,
primfunc_name_hint="pad",
)

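Because the legalizer now reads the constant from call.attrs.pad_value, a padded graph lowers without a scalar constant argument. A hedged end-to-end sketch (the module below is illustrative, not taken from the patch):

    from tvm import relax
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Before:
        @R.function
        def main(x: R.Tensor((2, 4), "float32")) -> R.Tensor((2, 6), "float32"):
            return R.nn.pad(x, pad_width=(0, 0, 1, 1), pad_mode="constant", pad_value=1.0)

    After = relax.transform.LegalizeOps()(Before)
    After.show()  # pad is lowered to a call_tir into a generated "pad" PrimFunc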
8 changes: 4 additions & 4 deletions src/relax/op/nn/nn.cc
@@ -136,12 +136,13 @@ TVM_REGISTER_OP("relax.nn.log_softmax")
/* relax.nn.pad */
TVM_REGISTER_NODE_TYPE(PadAttrs);

Expr pad(Expr data, Array<Integer> pad_width, Expr pad_value, String pad_mode) {
Expr pad(Expr data, Array<Integer> pad_width, String pad_mode, runtime::Float pad_value) {
auto attrs = make_object<PadAttrs>();
attrs->pad_width = std::move(pad_width);
attrs->pad_mode = std::move(pad_mode);
attrs->pad_value = pad_value;
static const Op& op = Op::Get("relax.nn.pad");
return Call(op, {data, pad_value}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.nn.pad").set_body_typed(pad);
@@ -171,9 +172,8 @@ StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) {
}

TVM_REGISTER_OP("relax.nn.pad")
.set_num_inputs(2)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("pad_value", "Tensor", "The value to fill in padded area with.")
.set_attrs_type<PadAttrs>()
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPad)
.set_attr<Bool>("FPurity", Bool(true));
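A quick way to confirm the registration change from Python: the op now declares a single input, and pad_value no longer appears among the call arguments (a hedged check, assuming a build with this patch):

    import tvm

    op = tvm.ir.Op.get("relax.nn.pad")
    print(op.num_inputs)                       # expected: 1
    print([arg.name for arg in op.arguments])  # expected: ['data']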
2 changes: 1 addition & 1 deletion tests/python/relax/test_transform_legalize_ops_nn.py
@@ -3621,7 +3621,7 @@ def main(x: R.Tensor((2, 128, 28), "float32")) -> R.Tensor((2, 130, 30), "float32"):
class Expected:
@R.function
def main(
x: R.Tensor((2, 128, 28), dtype="float32")
x: R.Tensor((2, 128, 28), dtype="float32"),
) -> R.Tensor((2, 130, 30), dtype="float32"):
gv = R.call_tir(Expected.pad, (x), out_sinfo=R.Tensor((2, 130, 30), dtype="float32"))
return gv