diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index e26cee26584b..9ecf5f3aaed9 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -562,12 +562,14 @@ struct AttentionAttrs : public tvm::AttrsNode { /*! \brief Attributes used for the padding operator */ struct PadAttrs : public tvm::AttrsNode { Array pad_width; + FloatImm pad_value; tvm::String pad_mode; TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") { TVM_ATTR_FIELD(pad_width).describe( "Number of values padded to the edges of each axis, " "in the format of (before_1, after_1, ..., before_N, after_N)"); + TVM_ATTR_FIELD(pad_value).describe("The value to fill in padded area with"); TVM_ATTR_FIELD(pad_mode) .set_default("constant") .describe( diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 8317d4504e1e..d9fb4701f7e9 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -231,7 +231,7 @@ def _nn_pad(bb: BlockBuilder, call: Call) -> Expr: call.args[0], pad_before=pad_before, pad_after=pad_after, - pad_value=float(call.args[1].data.numpy()), + pad_value=call.attrs.pad_value, primfunc_name_hint="pad", ) diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index 14171afb3d44..009e0f87e253 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -2359,7 +2359,7 @@ def _impl(inputs, attr, params, mod): else: paddings = tuple(tuple(l) for l in padlist) attr["pad_width"] = paddings - attr["pad_value"] = 0 + attr["pad_value"] = _expr.const(0) new_inputs = [inputs[0]] if name == "PadV2": try: diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 7eccf47e4b06..6009a24e87b0 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -136,12 +136,13 @@ TVM_REGISTER_OP("relax.nn.log_softmax") /* relax.nn.pad */ TVM_REGISTER_NODE_TYPE(PadAttrs); -Expr pad(Expr data, Array pad_width, Expr pad_value, String pad_mode) { +Expr pad(Expr data, Array pad_width, FloatImm pad_value, String pad_mode) { auto attrs = make_object(); attrs->pad_width = std::move(pad_width); + attrs->pad_value = pad_value; attrs->pad_mode = std::move(pad_mode); static const Op& op = Op::Get("relax.nn.pad"); - return Call(op, {data, pad_value}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relax.op.nn.pad").set_body_typed(pad); @@ -171,9 +172,8 @@ StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) { } TVM_REGISTER_OP("relax.nn.pad") - .set_num_inputs(2) + .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") - .add_argument("pad_value", "Tensor", "The value to fill in padded area with.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPad) .set_attr("FPurity", Bool(true));