Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 99 additions & 1 deletion python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,50 @@ def _linear(self, node: fx.Node) -> relax.Var:
bias = args[2] if len(args) > 2 else None
return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32"))

def _max_pool1d_impl(
self,
x: relax.Expr,
kernel_size: Union[int, Tuple[int]] = 1,
stride: Optional[Union[int, Tuple[int]]] = None,
padding: Optional[int] = 0,
dilation: Optional[int] = 1,
ceil_mode: Optional[bool] = False,
) -> relax.Var:
# Expand to 3D by adding batch dim if input is 2D
x_ndim = x.struct_info.ndim
if x_ndim == 2:
x = relax.op.expand_dims(x, axis=0)

stride = kernel_size if stride is None else stride

result = self.block_builder.emit(
relax.op.nn.max_pool1d(
x,
pool_size=kernel_size,
strides=stride,
padding=padding,
dilation=dilation,
ceil_mode=ceil_mode,
layout="NCW",
)
)

# Remove added batch dim from result
if x_ndim == 2:
result = relax.op.squeeze(result, axis=[0])
return result

def _max_pool1d(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
kernel_size = args[1]
stride = args[2] if len(args) > 2 else None
padding = args[3] if len(args) > 3 else 0
dilation = args[4] if len(args) > 4 else 1
ceil_mode = args[5] if len(args) > 5 else False

return self._max_pool1d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)

def _max_pool2d_impl(
self,
x: relax.Expr,
Expand All @@ -871,8 +915,14 @@ def _max_pool2d_impl(
dilation: Optional[int] = 1,
ceil_mode: Optional[bool] = False,
) -> relax.Var:
# Expand to 4D by adding batch dim if input is 3D
x_ndim = x.struct_info.ndim
if x_ndim == 3:
x = relax.op.expand_dims(x, axis=0)

stride = kernel_size if stride is None else stride
return self.block_builder.emit(

result = self.block_builder.emit(
relax.op.nn.max_pool2d(
x,
pool_size=kernel_size,
Expand All @@ -884,6 +934,11 @@ def _max_pool2d_impl(
)
)

# Remove added batch dim from result
if x_ndim == 3:
result = relax.op.squeeze(result, axis=[0])
return result

def _max_pool2d(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
Expand All @@ -895,6 +950,49 @@ def _max_pool2d(self, node: fx.Node) -> relax.Var:

return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)

def _max_pool3d_impl(
self,
x: relax.Expr,
kernel_size: Union[int, Tuple[int, int, int]] = (1, 1, 1),
stride: Optional[Union[int, Tuple[int, int, int]]] = None,
padding: Optional[int] = 0,
dilation: Optional[int] = 1,
ceil_mode: Optional[bool] = False,
) -> relax.Var:
# Expand to 5D by adding batch dim if input is 4D
x_ndim = x.struct_info.ndim
if x_ndim == 4:
x = relax.op.expand_dims(x, axis=0)

stride = kernel_size if stride is None else stride

result = self.block_builder.emit(
relax.op.nn.max_pool3d(
x,
pool_size=kernel_size,
strides=stride,
padding=padding,
dilation=dilation,
ceil_mode=ceil_mode,
layout="NCDHW",
)
)

# Remove added batch dim from result
if x_ndim == 4:
result = relax.op.squeeze(result, axis=[0])
return result

def _max_pool3d(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
kernel_size = args[1]
stride = args[2] if len(args) > 2 else None
padding = args[3] if len(args) > 3 else 0
dilation = args[4] if len(args) > 4 else 1
ceil_mode = args[5] if len(args) > 5 else False
return self._max_pool3d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)

def _pad(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
pad = node.args[1]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,9 @@ def create_convert_map(
"group_norm.default": self._group_norm,
"layer_norm.default": self._layer_norm,
"linear.default": self._linear,
"max_pool1d.default": self._max_pool1d,
"max_pool2d.default": self._max_pool2d,
"max_pool3d.default": self._max_pool3d,
"scaled_dot_product_attention.default": self._scaled_dot_product_attention,
"unbind.int": self._unbind,
"upsample_bilinear2d.vec": self._upsample_bilinear2d,
Expand Down
26 changes: 26 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,17 @@ def _linear_module(self, node: fx.Node) -> relax.Var:
bias = self.params.get(module.bias, None)
return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32"))

def _max_pool1d_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
kernel_size = module.kernel_size
stride = module.stride
padding = module.padding
dilation = module.dilation
ceil_mode = module.ceil_mode

return self._max_pool1d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)

def _max_pool2d_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
Expand All @@ -460,6 +471,17 @@ def _max_pool2d_module(self, node: fx.Node) -> relax.Var:

return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)

def _max_pool3d_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
kernel_size = module.kernel_size
stride = module.stride
padding = module.padding
dilation = module.dilation
ceil_mode = module.ceil_mode

return self._max_pool3d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)

def _pixel_shuffle_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
Expand Down Expand Up @@ -661,7 +683,9 @@ def create_convert_map(
nn.GroupNorm: self._group_norm_module,
nn.LayerNorm: self._layer_norm_module,
nn.Linear: self._linear_module,
nn.MaxPool1d: self._max_pool1d_module,
nn.MaxPool2d: self._max_pool2d_module,
nn.MaxPool3d: self._max_pool3d_module,
nn.modules.sparse.Embedding: self._embedding_module,
nn.PixelShuffle: self._pixel_shuffle_module,
# tensor manipulation
Expand Down Expand Up @@ -772,7 +796,9 @@ def create_convert_map(
"interpolate": self._interpolate,
"layer_norm": self._layer_norm,
"linear": self._linear,
"max_pool1d": self._max_pool1d,
"max_pool2d": self._max_pool2d,
"max_pool3d": self._max_pool3d,
"scaled_dot_product_attention": self._scaled_dot_product_attention,
"stochastic_depth": lambda node: self.env[node.args[0]],
"unbind": self._unbind,
Expand Down
2 changes: 1 addition & 1 deletion src/relax/op/nn/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) {

PrimExpr numerator_w = input_w + padding_w - attrs->dilation[0] * (kernel_w - 1) - 1;
if (attrs->ceil_mode) {
numerator_w += attrs->strides[1] - 1;
numerator_w += attrs->strides[0] - 1;
}
out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[0]) + 1);

Expand Down
Loading