18 changes: 18 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -901,6 +901,24 @@ def _max_pool2d(self, node: fx.Node) -> relax.Var:

return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)

def _pad(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
pad = node.args[1]
mode = node.args[2] if len(node.args) > 2 else node.kwargs.get("mode", "constant")
value = node.args[3] if len(node.args) > 3 else node.kwargs.get("value", 0.0)
value = 0.0 if value is None else value

# Torch's flat pad list gives (before, after) pairs starting from the last
# dimension, so group it into pairs, reverse the pairs to match the input
# dimension order, and leave leading (unpadded) dimensions at zero.
input_ndim = x.struct_info.ndim
pad_width = [0] * (input_ndim * 2)
pad_pairs = [pad[i : i + 2] for i in range(0, len(pad), 2)]
reversed_pairs = list(reversed(pad_pairs))
flattened = [p for pair in reversed_pairs for p in pair]
pad_width[-len(flattened) :] = flattened

return self.block_builder.emit(relax.op.nn.pad(x, pad_width, mode, value))
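
For illustration, a minimal standalone sketch of the same reordering with hypothetical values (a 4-D NCHW input padded only on H and W); the names mirror the converter above:

# Hypothetical torch-style flat pad list: (W_before, W_after, H_before, H_after).
pad = [1, 2, 3, 4]
input_ndim = 4  # e.g. an NCHW tensor
pad_width = [0] * (input_ndim * 2)
pad_pairs = [pad[i : i + 2] for i in range(0, len(pad), 2)]    # [[1, 2], [3, 4]]
flattened = [p for pair in reversed(pad_pairs) for p in pair]  # [3, 4, 1, 2]
pad_width[-len(flattened):] = flattened
print(pad_width)  # [0, 0, 0, 0, 3, 4, 1, 2] -> (before, after) per axis, N C H W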

def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
query = transpose_S_H(self.env[node.args[0]])
@@ -299,6 +299,7 @@ def create_convert_map(
"log1p.default": self._log1p,
"log_softmax.int": self._log_softmax,
"neg.default": self._unary_op(relax.op.negative),
"pad.default": self._pad,
"prelu.default": self._prelu,
"reciprocal.default": self._reciprocal,
"relu.default": self._unary_op(relax.op.nn.relu),
1 change: 1 addition & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -649,6 +649,7 @@ def create_convert_map(
"logical_not": self._unary_op(relax.op.logical_not),
"log_softmax": self._log_softmax,
"neg": self._unary_op(relax.op.negative),
"pad": self._pad,
"prelu": self._prelu,
"reciprocal": self._reciprocal,
"relu": self._unary_op(relax.op.nn.relu),
15 changes: 8 additions & 7 deletions python/tvm/relax/op/nn/nn.py
@@ -515,9 +515,9 @@ def conv2d_transpose(

def pad(
data: Expr,
pad_width: Tuple[Tuple[int, int], ...],
pad_width: Union[List[int], Tuple[int, ...]],
pad_mode: Optional[str] = "constant",
pad_value: Optional[Union[float, Expr]] = 0.0,
pad_value: Optional[float] = 0.0,
):
r"""Padding

@@ -528,14 +528,15 @@ def pad(
----------
data: relax.Expr
The input data to the operator
pad_width: Tuple[Tuple[int, int], ...], required
pad_width: Union[List[int], Tuple[int, ...]], required
Number of values padded to the edges of each axis, in the flat format
[before_1, after_1, ..., before_N, after_N]
pad_mode: Optional[str]
'constant', 'edge', or 'reflect'
'constant' pads with constant_value pad_value
'edge' pads using the edge values of the input array
'reflect' pads by reflecting values with respect to the edge
'constant', 'reflect', 'replicate', or 'circular'
'constant' pads with the constant value pad_value
'reflect' pads by mirroring values, excluding the edge
'replicate' pads by repeating the edge values
'circular' pads by wrapping values around from the opposite side
Default is 'constant'
pad_value: Optional[float]
The value used for padding. Default is 0.
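
For orientation, a minimal illustration of the four modes on a hypothetical 1-D input [1, 2, 3] padded by one element on each side:

constant (pad_value=0): [0, 1, 2, 3, 0]
reflect:                [2, 1, 2, 3, 2]
replicate:              [1, 1, 2, 3, 3]
circular:               [3, 1, 2, 3, 1]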
31 changes: 22 additions & 9 deletions python/tvm/relax/transform/legalize_ops/nn.py
@@ -222,18 +222,31 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> Expr:

@register_legalize("relax.nn.pad")
def _nn_pad(bb: BlockBuilder, call: Call) -> Expr:
# Unpack pad_width into two separate lists for topi.
pad_mode = call.attrs.pad_mode
pad_widths = call.attrs.pad_width
pad_before = pad_widths[::2]
pad_after = pad_widths[1::2]
return bb.call_te(
topi.nn.pad,
call.args[0],
pad_before=pad_before,
pad_after=pad_after,
pad_value=call.attrs.pad_value,
primfunc_name_hint="pad",
)
if pad_mode == "reflect":
return bb.call_te(
topi.nn.reflect_pad, call.args[0], pad_before=pad_before, pad_after=pad_after
)
elif pad_mode == "replicate":
return bb.call_te(
topi.nn.replicate_pad, call.args[0], pad_before=pad_before, pad_after=pad_after
)
elif pad_mode == "circular":
return bb.call_te(
topi.nn.circular_pad, call.args[0], pad_before=pad_before, pad_after=pad_after
)
else:
return bb.call_te(
topi.nn.pad,
call.args[0],
pad_before=pad_before,
pad_after=pad_after,
pad_value=call.attrs.pad_value,
primfunc_name_hint="pad",
)
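
A quick sketch of the unpacking step in plain Python, with a hypothetical 4-D flat pad_width:

pad_widths = [0, 0, 0, 0, 3, 4, 1, 2]  # (before, after) pairs per axis
pad_before = pad_widths[::2]           # [0, 0, 3, 1]
pad_after = pad_widths[1::2]           # [0, 0, 4, 2]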


@register_legalize("relax.nn.max_pool1d")
174 changes: 173 additions & 1 deletion python/tvm/topi/nn/pad.py
@@ -19,14 +19,46 @@

import tvm
from tvm import te
from tvm.tir import if_then_else

from .. import tag
from ..utils import equal_const_int


def get_padded_shape(data, pad_before, pad_after=None):
"""
Calculates the output shape of a tensor after applying padding.

Args:
data (tvm.te.Tensor): The input tensor to which padding is applied.
pad_before : list / tuple of n ints
Pad width on each dimension to pad the before the axis begin.
pad_after : list / tuple of n ints, optional
Pad width each dimension to pad the after the axis end.

Raises:
ValueError: If `pad_before` or `pad_after` lengths mismatch with `data` dimensions.

Returns:
tuple: A tuple representing the padded shape of the tensor.
"""
n = data.ndim
pad_after = pad_after if pad_after else pad_before

if len(pad_before) != n:
raise ValueError(f"pad_before length {len(pad_before)} != input dims {n}")
if len(pad_after) != n:
raise ValueError(f"pad_after length {len(pad_after)} != input dims {n}")

ana = tvm.arith.Analyzer()
out_shape = tuple(ana.simplify(data.shape[i] + pad_before[i] + pad_after[i]) for i in range(n))

return out_shape
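
A minimal usage sketch (the placeholder shape and pad amounts are hypothetical):

import tvm
from tvm import te

data = te.placeholder((1, 3, 32, 32), name="data")
print(get_padded_shape(data, pad_before=[0, 0, 1, 1], pad_after=[0, 0, 2, 2]))
# -> (1, 3, 35, 35), since out_shape[i] = shape[i] + pad_before[i] + pad_after[i]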


@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput", attrs=None):
"""Pad Input with zeros.
"""Pad Input with using pad values.

Parameters
----------
@@ -145,3 +177,143 @@ def _pad(*indices):
return data(*mapped_tuple)

return te.compute(out_shape, _pad, name=name)


@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
def reflect_pad(data, pad_before, pad_after=None, name="ReflectPadInput"):
"""
Apply reflect padding to the input tensor.

Parameters
----------
data : tvm.te.Tensor
Input tensor.

pad_before : List[int]
Amount to pad before each dimension.

pad_after : List[int], optional
Amount to pad after each dimension. If None, defaults to pad_before.

name : str
Name of the resulting tensor.

Returns
-------
out : tvm.te.Tensor
Reflect-padded tensor.
"""
out_shape = get_padded_shape(data, pad_before, pad_after)

def _pad(*indices):
index_tuple = []
for i in range(data.ndim):
idx = indices[i]
size = data.shape[i]
before = pad_before[i]

orig_idx = idx - before

reflected_idx = if_then_else(
orig_idx < 0,
-orig_idx, # reflect from start (no repeat)
if_then_else(
orig_idx >= size,
(2 * size - 2) - orig_idx, # reflect from end
orig_idx,
),
)
index_tuple.append(reflected_idx)
return data(*index_tuple)

return te.compute(out_shape, _pad, name=name)
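
To make the index arithmetic concrete, a pure-Python trace of the same reflection rule on a hypothetical 1-D case (size 4, two elements of padding per side; like torch's 'reflect' mode, this assumes the pad width is smaller than the input size, so a single reflection suffices):

size, before = 4, 2

def reflect(orig_idx):
    if orig_idx < 0:
        return -orig_idx                   # reflect from the start, edge not repeated
    if orig_idx >= size:
        return (2 * size - 2) - orig_idx   # reflect from the end, edge not repeated
    return orig_idx

print([reflect(i - before) for i in range(size + 2 * before)])
# -> [2, 1, 0, 1, 2, 3, 2, 1], i.e. [c, b, a, b, c, d, c, b] for input [a, b, c, d]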


@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
def replicate_pad(data, pad_before, pad_after=None, name="ReplicatePadInput"):
"""
Apply replicate padding (edge padding) to the input tensor.

Parameters
----------
data : tvm.te.Tensor
Input tensor.

pad_before : List[int]
Amount to pad before each dimension.

pad_after : List[int], optional
Amount to pad after each dimension. If None, defaults to pad_before.

name : str
Name of the resulting tensor.

Returns
-------
out : tvm.te.Tensor
Replicate-padded tensor.
"""
out_shape = get_padded_shape(data, pad_before, pad_after)

def _pad(*indices):
index_tuple = []
for i in range(data.ndim):
idx = indices[i]
size = data.shape[i]
before = pad_before[i]

orig_idx = idx - before
clamped_idx = if_then_else(
orig_idx < 0,
tvm.tir.const(0, "int32"), # replicate first element
if_then_else(
orig_idx >= size,
size - 1, # replicate last element
orig_idx,
),
)
index_tuple.append(clamped_idx)
return data(*index_tuple)

return te.compute(out_shape, _pad, name=name)
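
The same kind of trace for the clamping rule (hypothetical 1-D case, size 4, padding 2 on each side):

size, before = 4, 2
clamp = lambda orig_idx: min(max(orig_idx, 0), size - 1)  # clamp into [0, size - 1]
print([clamp(i - before) for i in range(size + 2 * before)])
# -> [0, 0, 0, 1, 2, 3, 3, 3], i.e. [a, a, a, b, c, d, d, d] for input [a, b, c, d]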


@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
def circular_pad(data, pad_before, pad_after=None, name="CircularPadInput"):
"""
Apply circular padding (wrap around) to the input tensor.

Parameters
----------
data : tvm.te.Tensor
Input tensor.

pad_before : List[int]
Amount to pad before each dimension.

pad_after : List[int], optional
Amount to pad after each dimension. If None, defaults to pad_before.

name : str
Name of the resulting tensor.

Returns
-------
out : tvm.te.Tensor
Circular-padded tensor.
"""
out_shape = get_padded_shape(data, pad_before, pad_after)

def _pad(*indices):
index_tuple = []
for i in range(data.ndim):
idx = indices[i]
size = data.shape[i]
before = pad_before[i]

orig_idx = idx - before
wrapped_idx = tvm.tir.indexmod(orig_idx + size, size)
index_tuple.append(wrapped_idx)
return data(*index_tuple)

return te.compute(out_shape, _pad, name=name)
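
And a trace of the wrap-around rule (hypothetical 1-D case, size 4, padding 2 on each side; the `+ size` keeps the modulo operand non-negative for pads up to the input size):

size, before = 4, 2
print([(i - before + size) % size for i in range(size + 2 * before)])
# -> [2, 3, 0, 1, 2, 3, 0, 1], i.e. [c, d, a, b, c, d, a, b] for input [a, b, c, d]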