diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 7a971c00cd2f..3ea70df9a13f 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -901,6 +901,24 @@ def _max_pool2d(self, node: fx.Node) -> relax.Var:
         return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
 
+    def _pad(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        pad = node.args[1]
+        mode = node.args[2] if len(node.args) > 2 else node.kwargs.get("mode", "constant")
+        value = node.args[3] if len(node.args) > 3 else node.kwargs.get("value", 0.0)
+        value = 0.0 if value is None else value
+
+        # Build the (before, after) pad widths for every input dimension, applying the
+        # torch pad pairs in reverse order since they start from the last dimension.
+        input_ndim = x.struct_info.ndim
+        pad_width = [0] * (input_ndim * 2)
+        pad_pairs = [pad[i : i + 2] for i in range(0, len(pad), 2)]
+        reversed_pairs = list(reversed(pad_pairs))
+        flattened = [value for pair in reversed_pairs for value in pair]
+        pad_width[-len(flattened) :] = flattened
+
+        return self.block_builder.emit(relax.op.nn.pad(x, pad_width, mode, value))
+
     def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
         transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
         query = transpose_S_H(self.env[node.args[0]])
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 4084e35de5db..9064de37f0b3 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -299,6 +299,7 @@ def create_convert_map(
             "log1p.default": self._log1p,
             "log_softmax.int": self._log_softmax,
             "neg.default": self._unary_op(relax.op.negative),
+            "pad.default": self._pad,
             "prelu.default": self._prelu,
             "reciprocal.default": self._reciprocal,
             "relu.default": self._unary_op(relax.op.nn.relu),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 4ef0b05acabc..e6b1fdd223ea 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -649,6 +649,7 @@ def create_convert_map(
             "logical_not": self._unary_op(relax.op.logical_not),
             "log_softmax": self._log_softmax,
             "neg": self._unary_op(relax.op.negative),
+            "pad": self._pad,
             "prelu": self._prelu,
             "reciprocal": self._reciprocal,
             "relu": self._unary_op(relax.op.nn.relu),
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 9d9eb3ef4820..e201b596f936 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -515,9 +515,9 @@ def conv2d_transpose(
 
 def pad(
     data: Expr,
-    pad_width: Tuple[Tuple[int, int], ...],
+    pad_width: Union[List[int], Tuple[int, ...]],
     pad_mode: Optional[str] = "constant",
-    pad_value: Optional[Union[float, Expr]] = 0.0,
+    pad_value: Optional[float] = 0.0,
 ):
     r"""Padding
 
@@ -528,14 +528,15 @@ def pad(
     ----------
     data: relax.Expr
         The input data to the operator
-    pad_width: Tuple[Tuple[int, int], ...], required
-        Number of values padded to the edges of each axis, in the format
-        of ((before_1, after_1), ..., (before_N, after_N))
+    pad_width: Union[List[int], Tuple[int, ...]], required
+        Number of values padded to the edges of each axis, in the flat format
+        (before_1, after_1, ..., before_N, after_N)
     pad_mode: Optional[str]
-        'constant', 'edge', or 'reflect'
-        'constant' pads with constant_value pad_value
-        'edge' pads using the edge values of the input array
-        'reflect' pads by reflecting values with respect to the edge
+        'constant', 'reflect', 'replicate', or 'circular'
+        'constant' pads with the constant value pad_value
+        'reflect' pads by mirroring values, excluding the edge
+        'replicate' pads by repeating the edge values
+        'circular' pads by wrapping values around from the other side
         Default is 'constant'
-    pad_value: Optional[Union[float, Expr]]
+    pad_value: Optional[float]
         The value used for padding. Default is 0.
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 5d942e5f645d..6a6f0ed6cb93 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -222,18 +222,31 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> Expr:
 
 @register_legalize("relax.nn.pad")
 def _nn_pad(bb: BlockBuilder, call: Call) -> Expr:
-    # Unpack pad_width into two separate lists for topi.
+    pad_mode = call.attrs.pad_mode
     pad_widths = call.attrs.pad_width
     pad_before = pad_widths[::2]
     pad_after = pad_widths[1::2]
-    return bb.call_te(
-        topi.nn.pad,
-        call.args[0],
-        pad_before=pad_before,
-        pad_after=pad_after,
-        pad_value=call.attrs.pad_value,
-        primfunc_name_hint="pad",
-    )
+    if pad_mode == "reflect":
+        return bb.call_te(
+            topi.nn.reflect_pad, call.args[0], pad_before=pad_before, pad_after=pad_after
+        )
+    elif pad_mode == "replicate":
+        return bb.call_te(
+            topi.nn.replicate_pad, call.args[0], pad_before=pad_before, pad_after=pad_after
+        )
+    elif pad_mode == "circular":
+        return bb.call_te(
+            topi.nn.circular_pad, call.args[0], pad_before=pad_before, pad_after=pad_after
+        )
+    else:
+        return bb.call_te(
+            topi.nn.pad,
+            call.args[0],
+            pad_before=pad_before,
+            pad_after=pad_after,
+            pad_value=call.attrs.pad_value,
+            primfunc_name_hint="pad",
+        )
 
 
 @register_legalize("relax.nn.max_pool1d")
diff --git a/python/tvm/topi/nn/pad.py b/python/tvm/topi/nn/pad.py
index 8833ef38d694..a3a7379d8da9 100644
--- a/python/tvm/topi/nn/pad.py
+++ b/python/tvm/topi/nn/pad.py
@@ -19,14 +19,46 @@
 import tvm
 from tvm import te
+from tvm.tir import if_then_else
 
 from .. import tag
 from ..utils import equal_const_int
 
 
+def get_padded_shape(data, pad_before, pad_after=None):
+    """
+    Calculates the output shape of a tensor after applying padding.
+
+    Args:
+        data (tvm.te.Tensor): The input tensor to which padding is applied.
+        pad_before : list / tuple of n ints
+            Pad width on each dimension, applied before the axis begins.
+        pad_after : list / tuple of n ints, optional
+            Pad width on each dimension, applied after the axis ends.
+
+    Raises:
+        ValueError: If `pad_before` or `pad_after` lengths mismatch with `data` dimensions.
+
+    Returns:
+        tuple: A tuple representing the padded shape of the tensor.
+    """
+    n = data.ndim
+    pad_after = pad_after if pad_after else pad_before
+
+    if len(pad_before) != n:
+        raise ValueError(f"pad_before length {len(pad_before)} != input dims {n}")
+    if len(pad_after) != n:
+        raise ValueError(f"pad_after length {len(pad_after)} != input dims {n}")
+
+    ana = tvm.arith.Analyzer()
+    out_shape = tuple(ana.simplify(data.shape[i] + pad_before[i] + pad_after[i]) for i in range(n))
+
+    return out_shape
+
+
 @tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
 def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput", attrs=None):
-    """Pad Input with zeros.
+    """Pad input with the specified pad value.
 
     Parameters
     ----------
@@ -145,3 +177,143 @@ def _pad(*indices):
             return data(*mapped_tuple)
 
     return te.compute(out_shape, _pad, name=name)
+
+
+@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
+def reflect_pad(data, pad_before, pad_after=None, name="ReflectPadInput"):
+    """
+    Apply reflect padding to the input tensor.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input tensor.
+
+    pad_before : List[int]
+        Amount to pad before each dimension.
+
+    pad_after : List[int], optional
+        Amount to pad after each dimension. If None, defaults to pad_before.
+
+    name : str
+        Name of the resulting tensor.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Reflect-padded tensor.
+    """
+    out_shape = get_padded_shape(data, pad_before, pad_after)
+
+    def _pad(*indices):
+        index_tuple = []
+        for i in range(data.ndim):
+            idx = indices[i]
+            size = data.shape[i]
+            before = pad_before[i]
+
+            orig_idx = idx - before
+
+            reflected_idx = if_then_else(
+                orig_idx < 0,
+                -orig_idx,  # reflect from start (no repeat)
+                if_then_else(
+                    orig_idx >= size,
+                    (2 * size - 2) - orig_idx,  # reflect from end
+                    orig_idx,
+                ),
+            )
+            index_tuple.append(reflected_idx)
+        return data(*index_tuple)
+
+    return te.compute(out_shape, _pad, name=name)
+
+
+@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
+def replicate_pad(data, pad_before, pad_after=None, name="ReplicatePadInput"):
+    """
+    Apply replicate padding (edge padding) to the input tensor.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input tensor.
+
+    pad_before : List[int]
+        Amount to pad before each dimension.
+
+    pad_after : List[int], optional
+        Amount to pad after each dimension. If None, defaults to pad_before.
+
+    name : str
+        Name of the resulting tensor.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Replicate-padded tensor.
+    """
+    out_shape = get_padded_shape(data, pad_before, pad_after)
+
+    def _pad(*indices):
+        index_tuple = []
+        for i in range(data.ndim):
+            idx = indices[i]
+            size = data.shape[i]
+            before = pad_before[i]
+
+            orig_idx = idx - before
+            clamped_idx = if_then_else(
+                orig_idx < 0,
+                tvm.tir.const(0, "int32"),  # replicate first element
+                if_then_else(
+                    orig_idx >= size,
+                    size - 1,  # replicate last element
+                    orig_idx,
+                ),
+            )
+            index_tuple.append(clamped_idx)
+        return data(*index_tuple)
+
+    return te.compute(out_shape, _pad, name=name)
+
+
+@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
+def circular_pad(data, pad_before, pad_after=None, name="CircularPadInput"):
+    """
+    Apply circular padding (wrap around) to the input tensor.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input tensor.
+
+    pad_before : List[int]
+        Amount to pad before each dimension.
+
+    pad_after : List[int], optional
+        Amount to pad after each dimension. If None, defaults to pad_before.
+
+    name : str
+        Name of the resulting tensor.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Circular-padded tensor.
+ """ + out_shape = get_padded_shape(data, pad_before, pad_after) + + def _pad(*indices): + index_tuple = [] + for i in range(data.ndim): + idx = indices[i] + size = data.shape[i] + before = pad_before[i] + + orig_idx = idx - before + wrapped_idx = tvm.tir.indexmod(orig_idx + size, size) + index_tuple.append(wrapped_idx) + return data(*index_tuple) + + return te.compute(out_shape, _pad, name=name) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 8db96849993a..4c60fcd6512c 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1850,6 +1850,95 @@ def main( verify_model(model, example_args, binding, expected2) +def test_pad(): + class PadModel(torch.nn.Module): + def __init__(self, pad, mode="constant", value=0.0): + super().__init__() + self.pad = pad + self.mode = mode + self.value = value + + def forward(self, x): + if self.mode == "constant": + return torch.nn.functional.pad(x, self.pad, mode=self.mode, value=self.value) + else: + return torch.nn.functional.pad(x, self.pad, mode=self.mode) + + @tvm.script.ir_module + class expected_constant: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( + x, + pad_width=[0, 0, 0, 0, 2, 2, 1, 1], + pad_mode="constant", + pad_value=0.0, + ) + gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_reflect: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( + x, + pad_width=[0, 0, 0, 0, 2, 2, 1, 1], + pad_mode="reflect", + pad_value=0.0, + ) + gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_replicate: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( + x, + pad_width=[0, 0, 0, 0, 2, 2, 1, 1], + pad_mode="replicate", + pad_value=0.0, + ) + gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_circular: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( + x, + pad_width=[0, 0, 0, 0, 2, 2, 1, 1], + pad_mode="circular", + pad_value=0.0, + ) + gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(PadModel(pad=[1, 1, 2, 2]), example_args, {}, expected_constant) + verify_model(PadModel(pad=[1, 1, 2, 2], mode="reflect"), example_args, {}, expected_reflect) + verify_model(PadModel(pad=[1, 1, 2, 2], mode="replicate"), example_args, {}, expected_replicate) + verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), example_args, {}, expected_circular) + + def test_einsum(): class Einsum1(Module): def __init__(self): diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py index 4a2ca336e1e8..53c925e14ee6 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -503,6 +503,95 @@ def main( verify_model(model, input_info, binding, expected2) +def test_pad(): + class PadModel(torch.nn.Module): + def __init__(self, pad, mode="constant", value=0.0): + super().__init__() + self.pad = pad + self.mode = mode + self.value = value + + def forward(self, x): + if self.mode == "constant": + return torch.nn.functional.pad(x, self.pad, mode=self.mode, value=self.value) + else: + return torch.nn.functional.pad(x, self.pad, mode=self.mode) + + @tvm.script.ir_module + class expected_constant: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 14, 12), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( + x, + pad_width=[0, 0, 0, 0, 2, 2, 1, 1], + pad_mode="constant", + pad_value=0.0, + ) + gv: R.Tensor((1, 3, 14, 12), dtype="float32") = lv + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_reflect: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 14, 12), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( + x, + pad_width=[0, 0, 0, 0, 2, 2, 1, 1], + pad_mode="reflect", + pad_value=0.0, + ) + gv: R.Tensor((1, 3, 14, 12), dtype="float32") = lv + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_replicate: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 14, 12), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( + x, + pad_width=[0, 0, 0, 0, 2, 2, 1, 1], + pad_mode="replicate", + pad_value=0.0, + ) + gv: R.Tensor((1, 3, 14, 12), dtype="float32") = lv + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_circular: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 14, 12), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( + x, + pad_width=[0, 0, 0, 0, 2, 2, 1, 1], + pad_mode="circular", + pad_value=0.0, + ) + gv: R.Tensor((1, 3, 14, 12), dtype="float32") = lv + R.output(gv) + return gv + + input_infos = [([1, 3, 10, 10], "float32")] + verify_model(PadModel(pad=[1, 1, 2, 2]), input_infos, {}, expected_constant) + verify_model(PadModel(pad=[1, 1, 2, 2], mode="reflect"), input_infos, {}, expected_reflect) + verify_model(PadModel(pad=[1, 1, 2, 2], mode="replicate"), input_infos, {}, expected_replicate) + verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), input_infos, {}, expected_circular) + + def test_linear(): # nn.Linear class Dense1(Module):
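Note: the following is an illustrative, standalone sketch, not part of the patch, of the pad-width conversion performed by _pad in base_fx_graph_translator.py above. The helper name torch_pad_to_relax_pad_width is hypothetical; it only shows how a torch-style flat pad spec (which starts from the last dimension) maps to the flat (before_1, after_1, ..., before_N, after_N) layout that relax.op.nn.pad and the expected test modules use.

# Illustrative sketch only; mirrors the logic of _pad in the diff above.
def torch_pad_to_relax_pad_width(pad, input_ndim):
    # Start with zero padding for every dimension of the input.
    pad_width = [0] * (input_ndim * 2)
    # torch gives (before, after) pairs starting from the last dimension,
    # so reverse the pairs before flattening them into relax's layout.
    pairs = [pad[i : i + 2] for i in range(0, len(pad), 2)]
    flattened = [w for pair in reversed(pairs) for w in pair]
    pad_width[-len(flattened):] = flattened
    return pad_width

# pad=[1, 1, 2, 2] on a 4-D input pads the last dim by (1, 1) and the
# second-to-last dim by (2, 2), giving the pad_width used in the tests.
assert torch_pad_to_relax_pad_width([1, 1, 2, 2], 4) == [0, 0, 0, 0, 2, 2, 1, 1]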
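A second illustrative sketch, also not part of the patch: it reproduces the index arithmetic of reflect_pad, replicate_pad and circular_pad from topi/nn/pad.py with plain Python integers, and checks it against numpy.pad as an independent reference. The helper pad_1d is hypothetical and assumes pad widths no larger than the input size, the same constraint PyTorch places on these modes.

# Illustrative sketch only; re-derives the padded-index formulas used above.
import numpy as np

def pad_1d(data, before, after, mode):
    size = len(data)
    out = []
    for idx in range(size + before + after):
        orig = idx - before
        if mode == "reflect":      # mirror without repeating the edge element
            orig = -orig if orig < 0 else (2 * size - 2) - orig if orig >= size else orig
        elif mode == "replicate":  # clamp to the first/last element
            orig = min(max(orig, 0), size - 1)
        elif mode == "circular":   # wrap around (assumes pad <= size)
            orig = (orig + size) % size
        out.append(data[orig])
    return out

x = [1, 2, 3, 4]
assert pad_1d(x, 2, 2, "reflect") == list(np.pad(x, 2, mode="reflect"))
assert pad_1d(x, 2, 2, "replicate") == list(np.pad(x, 2, mode="edge"))
assert pad_1d(x, 2, 2, "circular") == list(np.pad(x, 2, mode="wrap"))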