From 6efd698eccd85ad9d4bdefd11ff5089b6772f894 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Thu, 10 Apr 2025 05:58:01 +0000 Subject: [PATCH 1/3] add pad op support into frontend pipelines fixing end of files formatting issue fixing trailing space issues update the docstring for pad mode in nn file fixing lint issues remove trailing whitespaces fix lint format issues in test script fix lint issue in pad file import statement modify docstring of pad function fixing dtype error in unity check fixing lint issues in pad.py file resolve arg mismatch error resolved error while handling pad value attr fix dtype of pad value attribute add helper function for different pad mode test script enhanced to check different pad mode remove trailing whitespaces in test script add docstring for helper function update test script --- .../torch/base_fx_graph_translator.py | 17 ++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 1 + python/tvm/relax/op/nn/nn.py | 15 +- python/tvm/relax/transform/legalize_ops/nn.py | 31 +++- python/tvm/topi/nn/pad.py | 174 +++++++++++++++++- .../test_frontend_from_exported_program.py | 89 +++++++++ tests/python/relax/test_frontend_from_fx.py | 89 +++++++++ 8 files changed, 400 insertions(+), 17 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 7a971c00cd2f..e4fd793e3f05 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -901,6 +901,23 @@ def _max_pool2d(self, node: fx.Node) -> relax.Var: return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _pad(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + pad = node.args[1] + mode = node.args[2] if len(node.args) > 2 else node.kwargs.get("mode", "constant") + value = node.args[3] if len(node.args) > 3 else node.kwargs.get("value", 0.0) + + # Calculate symmetric padding width for each dimension + # and applying them in reverse order to match the input dimensions. 
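+        # Example (values taken from the tests added in this patch): with a
+        # 4-D input and torch-style pad = [1, 1, 2, 2] (pairs ordered from the
+        # last dimension: left, right, top, bottom), the pairs [[1, 1], [2, 2]]
+        # are reversed to [[2, 2], [1, 1]] and placed at the tail, giving
+        # pad_width = [0, 0, 0, 0, 2, 2, 1, 1].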
+ input_ndim = x.struct_info.ndim + pad_width = [0] * (input_ndim * 2) + pad_pairs = [pad[i : i + 2] for i in range(0, len(pad), 2)] + reversed_pairs = list(reversed(pad_pairs)) + flattened = [value for pair in reversed_pairs for value in pair] + pad_width[-len(flattened) :] = flattened + + return self.block_builder.emit(relax.op.nn.pad(x, pad_width, mode, value)) + def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) query = transpose_S_H(self.env[node.args[0]]) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 4084e35de5db..9064de37f0b3 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -299,6 +299,7 @@ def create_convert_map( "log1p.default": self._log1p, "log_softmax.int": self._log_softmax, "neg.default": self._unary_op(relax.op.negative), + "pad.default": self._pad, "prelu.default": self._prelu, "reciprocal.default": self._reciprocal, "relu.default": self._unary_op(relax.op.nn.relu), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 4ef0b05acabc..e6b1fdd223ea 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -649,6 +649,7 @@ def create_convert_map( "logical_not": self._unary_op(relax.op.logical_not), "log_softmax": self._log_softmax, "neg": self._unary_op(relax.op.negative), + "pad": self._pad, "prelu": self._prelu, "reciprocal": self._reciprocal, "relu": self._unary_op(relax.op.nn.relu), diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 9d9eb3ef4820..e201b596f936 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -515,9 +515,9 @@ def conv2d_transpose( def pad( data: Expr, - pad_width: Tuple[Tuple[int, int], ...], + pad_width: Union[List[int], Tuple[int, ...]], pad_mode: Optional[str] = "constant", - pad_value: Optional[Union[float, Expr]] = 0.0, + pad_value: Optional[float] = 0.0, ): r"""Padding @@ -528,14 +528,15 @@ def pad( ---------- data: relax.Expr The input data to the operator - pad_width: Tuple[Tuple[int, int], ...], required + pad_width: Union[List[int], Tuple[int, ...]], required Number of values padded to the edges of each axis, in the format of ((before_1, after_1), ..., (before_N, after_N)) pad_mode: Optional[str] - 'constant', 'edge', or 'reflect' - 'constant' pads with constant_value pad_value - 'edge' pads using the edge values of the input array - 'reflect' pads by reflecting values with respect to the edge + 'constant', 'reflect', 'replicate', 'circular' + 'constant' pads with constant value pad_value + 'reflect' pads by mirroring values excluding the edge + 'replicate' pads by repeating the edge values. + 'circular' pads by looping values from the other side Default is 'constant' pad_value: Optional[Union[float, Expr]] The value used for padding. Default is 0. 
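As a quick reference (not part of the patch itself), the following minimal sketch shows how a torch-level F.pad call is expected to map onto the flat pad_width layout documented above; the shapes and pad values mirror the test cases added later in this series:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 10, 10)
    # torch pad = [left, right, top, bottom], applied to the last two dims
    y = F.pad(x, [1, 1, 2, 2], mode="reflect")   # result shape: (1, 3, 14, 12)

    # The frontend translator above is expected to emit roughly:
    #   relax.op.nn.pad(x, pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
    #                   pad_mode="reflect", pad_value=0.0)
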
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 5d942e5f645d..6a6f0ed6cb93 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -222,18 +222,31 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.nn.pad") def _nn_pad(bb: BlockBuilder, call: Call) -> Expr: - # Unpack pad_width into two separate lists for topi. + pad_mode = call.attrs.pad_mode pad_widths = call.attrs.pad_width pad_before = pad_widths[::2] pad_after = pad_widths[1::2] - return bb.call_te( - topi.nn.pad, - call.args[0], - pad_before=pad_before, - pad_after=pad_after, - pad_value=call.attrs.pad_value, - primfunc_name_hint="pad", - ) + if pad_mode == "reflect": + return bb.call_te( + topi.nn.reflect_pad, call.args[0], pad_before=pad_before, pad_after=pad_after + ) + elif pad_mode == "replicate": + return bb.call_te( + topi.nn.replicate_pad, call.args[0], pad_before=pad_before, pad_after=pad_after + ) + elif pad_mode == "circular": + return bb.call_te( + topi.nn.circular_pad, call.args[0], pad_before=pad_before, pad_after=pad_after + ) + else: + return bb.call_te( + topi.nn.pad, + call.args[0], + pad_before=pad_before, + pad_after=pad_after, + pad_value=call.attrs.pad_value, + primfunc_name_hint="pad", + ) @register_legalize("relax.nn.max_pool1d") diff --git a/python/tvm/topi/nn/pad.py b/python/tvm/topi/nn/pad.py index 8833ef38d694..a3a7379d8da9 100644 --- a/python/tvm/topi/nn/pad.py +++ b/python/tvm/topi/nn/pad.py @@ -19,14 +19,46 @@ import tvm from tvm import te +from tvm.tir import if_then_else from .. import tag from ..utils import equal_const_int +def get_padded_shape(data, pad_before, pad_after=None): + """ + Calculates the output shape of a tensor after applying padding. + + Args: + data (tvm.te.Tensor): The input tensor to which padding is applied. + pad_before : list / tuple of n ints + Pad width on each dimension to pad the before the axis begin. + pad_after : list / tuple of n ints, optional + Pad width each dimension to pad the after the axis end. + + Raises: + ValueError: If `pad_before` or `pad_after` lengths mismatch with `data` dimensions. + + Returns: + tuple: A tuple representing the padded shape of the tensor. + """ + n = data.ndim + pad_after = pad_after if pad_after else pad_before + + if len(pad_before) != n: + raise ValueError(f"pad_before length {len(pad_before)} != input dims {n}") + if len(pad_after) != n: + raise ValueError(f"pad_after length {len(pad_after)} != input dims {n}") + + ana = tvm.arith.Analyzer() + out_shape = tuple(ana.simplify(data.shape[i] + pad_before[i] + pad_after[i]) for i in range(n)) + + return out_shape + + @tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad") def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput", attrs=None): - """Pad Input with zeros. + """Pad Input with using pad values. Parameters ---------- @@ -145,3 +177,143 @@ def _pad(*indices): return data(*mapped_tuple) return te.compute(out_shape, _pad, name=name) + + +@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad") +def reflect_pad(data, pad_before, pad_after=None, name="ReflectPadInput"): + """ + Apply reflect padding to the input tensor. + + Parameters + ---------- + data : tvm.te.Tensor + Input tensor. + + pad_before : List[int] + Amount to pad before each dimension. + + pad_after : List[int], optional + Amount to pad after each dimension. If None, defaults to pad_before. + + name : str + Name of the resulting tensor. 
+ + Returns + ------- + out : tvm.te.Tensor + Reflect-padded tensor. + """ + out_shape = get_padded_shape(data, pad_before, pad_after) + + def _pad(*indices): + index_tuple = [] + for i in range(data.ndim): + idx = indices[i] + size = data.shape[i] + before = pad_before[i] + + orig_idx = idx - before + + reflected_idx = if_then_else( + orig_idx < 0, + -orig_idx, # reflect from start (no repeat) + if_then_else( + orig_idx >= size, + (2 * size - 2) - orig_idx, # reflect from end + orig_idx, + ), + ) + index_tuple.append(reflected_idx) + return data(*index_tuple) + + return te.compute(out_shape, _pad, name=name) + + +@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad") +def replicate_pad(data, pad_before, pad_after=None, name="ReplicatePadInput"): + """ + Apply replicate padding (edge padding) to the input tensor. + + Parameters + ---------- + data : tvm.te.Tensor + Input tensor. + + pad_before : List[int] + Amount to pad before each dimension. + + pad_after : List[int], optional + Amount to pad after each dimension. If None, defaults to pad_before. + + name : str + Name of the resulting tensor. + + Returns + ------- + out : tvm.te.Tensor + Replicate-padded tensor. + """ + out_shape = get_padded_shape(data, pad_before, pad_after) + + def _pad(*indices): + index_tuple = [] + for i in range(data.ndim): + idx = indices[i] + size = data.shape[i] + before = pad_before[i] + + orig_idx = idx - before + clamped_idx = if_then_else( + orig_idx < 0, + tvm.tir.const(0, "int32"), # replicate first element + if_then_else( + orig_idx >= size, + size - 1, # replicate last element + orig_idx, + ), + ) + index_tuple.append(clamped_idx) + return data(*index_tuple) + + return te.compute(out_shape, _pad, name=name) + + +@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad") +def circular_pad(data, pad_before, pad_after=None, name="CircularPadInput"): + """ + Apply circular padding (wrap around) to the input tensor. + + Parameters + ---------- + data : tvm.te.Tensor + Input tensor. + + pad_before : List[int] + Amount to pad before each dimension. + + pad_after : List[int], optional + Amount to pad after each dimension. If None, defaults to pad_before. + + name : str + Name of the resulting tensor. + + Returns + ------- + out : tvm.te.Tensor + Circular-padded tensor. 
+ """ + out_shape = get_padded_shape(data, pad_before, pad_after) + + def _pad(*indices): + index_tuple = [] + for i in range(data.ndim): + idx = indices[i] + size = data.shape[i] + before = pad_before[i] + + orig_idx = idx - before + wrapped_idx = tvm.tir.indexmod(orig_idx + size, size) + index_tuple.append(wrapped_idx) + return data(*index_tuple) + + return te.compute(out_shape, _pad, name=name) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 8db96849993a..4c60fcd6512c 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1850,6 +1850,95 @@ def main( verify_model(model, example_args, binding, expected2) +def test_pad(): + class PadModel(torch.nn.Module): + def __init__(self, pad, mode="constant", value=0.0): + super().__init__() + self.pad = pad + self.mode = mode + self.value = value + + def forward(self, x): + if self.mode == "constant": + return torch.nn.functional.pad(x, self.pad, mode=self.mode, value=self.value) + else: + return torch.nn.functional.pad(x, self.pad, mode=self.mode) + + @tvm.script.ir_module + class expected_constant: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( + x, + pad_width=[0, 0, 0, 0, 2, 2, 1, 1], + pad_mode="constant", + pad_value=0.0, + ) + gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_reflect: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( + x, + pad_width=[0, 0, 0, 0, 2, 2, 1, 1], + pad_mode="reflect", + pad_value=0.0, + ) + gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_replicate: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( + x, + pad_width=[0, 0, 0, 0, 2, 2, 1, 1], + pad_mode="replicate", + pad_value=0.0, + ) + gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_circular: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( + x, + pad_width=[0, 0, 0, 0, 2, 2, 1, 1], + pad_mode="circular", + pad_value=0.0, + ) + gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(PadModel(pad=[1, 1, 2, 2]), example_args, {}, expected_constant) + verify_model(PadModel(pad=[1, 1, 2, 2], mode="reflect"), example_args, {}, expected_reflect) + verify_model(PadModel(pad=[1, 1, 2, 2], mode="replicate"), example_args, {}, expected_replicate) + verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), example_args, {}, expected_circular) + + def test_einsum(): class Einsum1(Module): def __init__(self): diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py index 4a2ca336e1e8..53c925e14ee6 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -503,6 +503,95 @@ def main( verify_model(model, input_info, binding, expected2) +def test_pad(): + class PadModel(torch.nn.Module): + def __init__(self, pad, mode="constant", value=0.0): + super().__init__() + self.pad = pad + self.mode = mode + self.value = value + + def forward(self, x): + if self.mode == "constant": + return torch.nn.functional.pad(x, self.pad, mode=self.mode, value=self.value) + else: + return torch.nn.functional.pad(x, self.pad, mode=self.mode) + + @tvm.script.ir_module + class expected_constant: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 14, 12), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( + x, + pad_width=[0, 0, 0, 0, 2, 2, 1, 1], + pad_mode="constant", + pad_value=0.0, + ) + gv: R.Tensor((1, 3, 14, 12), dtype="float32") = lv + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_reflect: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 14, 12), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( + x, + pad_width=[0, 0, 0, 0, 2, 2, 1, 1], + pad_mode="reflect", + pad_value=0.0, + ) + gv: R.Tensor((1, 3, 14, 12), dtype="float32") = lv + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_replicate: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 14, 12), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( + x, + pad_width=[0, 0, 0, 0, 2, 2, 1, 1], + pad_mode="replicate", + pad_value=0.0, + ) + gv: R.Tensor((1, 3, 14, 12), dtype="float32") = lv + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_circular: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 14, 12), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( + x, + pad_width=[0, 0, 0, 0, 2, 2, 1, 1], + pad_mode="circular", + pad_value=0.0, + ) + gv: R.Tensor((1, 3, 14, 12), dtype="float32") = lv + R.output(gv) + return gv + + input_infos = [([1, 3, 10, 10], "float32")] + verify_model(PadModel(pad=[1, 1, 2, 2]), input_infos, {}, expected_constant) + verify_model(PadModel(pad=[1, 1, 2, 2], mode="reflect"), input_infos, {}, expected_reflect) + verify_model(PadModel(pad=[1, 1, 2, 2], mode="replicate"), input_infos, {}, expected_replicate) + verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), input_infos, {}, expected_circular) + + def test_linear(): # nn.Linear class Dense1(Module): From abb37b0243b190dcbb75f27e168c446419473829 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Tue, 15 Apr 2025 09:05:25 +0000 Subject: [PATCH 2/3] fix pad op arg handling in fx graph --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index e4fd793e3f05..b656368ddc87 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -905,7 +905,7 @@ def _pad(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] pad = 
node.args[1] mode = node.args[2] if len(node.args) > 2 else node.kwargs.get("mode", "constant") - value = node.args[3] if len(node.args) > 3 else node.kwargs.get("value", 0.0) + value = node.args[3] if len(node.args) > 3 else 0.0 # Calculate symmetric padding width for each dimension # and applying them in reverse order to match the input dimensions. From 986ea7d67f6fd1365bfe21a451de8a1ad4b87322 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Thu, 17 Apr 2025 05:12:58 +0000 Subject: [PATCH 3/3] fix issue by updated the retrieval of value arg --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index b656368ddc87..3ea70df9a13f 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -905,7 +905,8 @@ def _pad(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] pad = node.args[1] mode = node.args[2] if len(node.args) > 2 else node.kwargs.get("mode", "constant") - value = node.args[3] if len(node.args) > 3 else 0.0 + value = node.args[3] if len(node.args) > 3 else node.kwargs.get("value", 0.0) + value = 0.0 if value is None else value # Calculate symmetric padding width for each dimension # and applying them in reverse order to match the input dimensions.
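
For readers who want to sanity-check the index arithmetic in the new topi.nn.reflect_pad / replicate_pad / circular_pad kernels, here is a small standalone NumPy sketch (not part of the patch) that applies the same three output-to-input index mapping rules to a 1-D example:

    import numpy as np

    def pad_index(out_idx, before, size, mode):
        """Map an output index back to an input index, mirroring the TOPI pad rules."""
        i = out_idx - before
        if mode == "reflect":        # mirror across the edge without repeating it
            return -i if i < 0 else (2 * size - 2 - i if i >= size else i)
        if mode == "replicate":      # clamp to the nearest edge value
            return min(max(i, 0), size - 1)
        if mode == "circular":       # wrap around; assumes pad width <= size
            return (i + size) % size
        raise ValueError(f"unknown mode: {mode}")

    data = np.arange(4)              # [0, 1, 2, 3], padded by 2 on each side
    for mode in ("reflect", "replicate", "circular"):
        padded = [int(data[pad_index(j, 2, data.size, mode)]) for j in range(data.size + 4)]
        print(mode, padded)
    # reflect   [2, 1, 0, 1, 2, 3, 2, 1]
    # replicate [0, 0, 0, 1, 2, 3, 3, 3]
    # circular  [2, 3, 0, 1, 2, 3, 0, 1]

Note that the single `+ size` in the circular rule, as in the TOPI kernel, assumes the pad width on a dimension does not exceed that dimension's size, which matches the usual restriction on circular padding in PyTorch.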