diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 5ae05ab89160..f789eb8af35b 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -580,6 +580,48 @@ def _addmm(self, node: fx.Node) -> relax.Var:
         res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res))
         return res
 
+    def _avg_pool1d_impl(
+        self,
+        x: relax.Expr,
+        kernel_size: Union[int, Tuple[int]] = 1,
+        stride: Optional[Union[int, Tuple[int]]] = None,
+        padding: Optional[int] = 0,
+        ceil_mode: Optional[bool] = False,
+        count_include_pad: Optional[bool] = True,
+    ) -> relax.Var:
+        # Expand to 3D by adding batch dim if input is 2D
+        x_ndim = x.struct_info.ndim
+        if x_ndim == 2:
+            x = relax.op.expand_dims(x, axis=0)
+        stride = kernel_size if stride is None or stride == [] else stride
+
+        result = self.block_builder.emit(
+            relax.op.nn.avg_pool1d(
+                x,
+                pool_size=kernel_size,
+                strides=stride,
+                padding=padding,
+                ceil_mode=ceil_mode,
+                count_include_pad=count_include_pad,
+                layout="NCW",
+            )
+        )
+        # Remove added batch dim from result
+        if x_ndim == 2:
+            result = relax.op.squeeze(result, axis=[0])
+        return result
+
+    def _avg_pool1d(self, node: fx.Node) -> relax.Var:
+        args, kwargs = node.normalized_arguments(node)
+        x = self.env[args[0]]
+        kernel_size = args[1] if len(args) > 1 else kwargs["kernel_size"]
+        stride = args[2] if len(args) > 2 else kwargs.get("stride", None)
+        padding = args[3] if len(args) > 3 else kwargs.get("padding", 0)
+        ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False)
+        count_include_pad = args[5] if len(args) > 5 else kwargs.get("count_include_pad", True)
+
+        return self._avg_pool1d_impl(x, kernel_size, stride, padding, ceil_mode, count_include_pad)
+
     def _avg_pool2d_impl(
         self,
         x: relax.Expr,
@@ -588,8 +630,13 @@ def _avg_pool2d_impl(
         padding: Optional[int] = 0,
         ceil_mode: Optional[bool] = False,
     ) -> relax.Var:
+        # Expand to 4D by adding batch dim if input is 3D
+        x_ndim = x.struct_info.ndim
+        if x_ndim == 3:
+            x = relax.op.expand_dims(x, axis=0)
         stride = kernel_size if stride is None or stride == [] else stride
-        return self.block_builder.emit(
+
+        result = self.block_builder.emit(
             relax.op.nn.avg_pool2d(
                 x,
                 pool_size=kernel_size,
@@ -599,6 +646,10 @@ def _avg_pool2d_impl(
                 layout="NCHW",
             )
         )
+        # Remove added batch dim from result
+        if x_ndim == 3:
+            result = relax.op.squeeze(result, axis=[0])
+        return result
 
     def _avg_pool2d(self, node: fx.Node) -> relax.Var:
         args, kwargs = node.normalized_arguments(node)
@@ -609,6 +660,48 @@ def _avg_pool2d(self, node: fx.Node) -> relax.Var:
         ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False)
         return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode)
 
+    def _avg_pool3d_impl(
+        self,
+        x: relax.Expr,
+        kernel_size: Union[int, Tuple[int, int, int]] = (1, 1, 1),
+        stride: Optional[Union[int, Tuple[int, int, int]]] = None,
+        padding: Optional[int] = 0,
+        ceil_mode: Optional[bool] = False,
+        count_include_pad: Optional[bool] = True,
+    ) -> relax.Var:
+        # Expand to 5D by adding batch dim if input is 4D
+        x_ndim = x.struct_info.ndim
+        if x_ndim == 4:
+            x = relax.op.expand_dims(x, axis=0)
+        stride = kernel_size if stride is None or stride == [] else stride
+
+        result = self.block_builder.emit(
+            relax.op.nn.avg_pool3d(
+                x,
+                pool_size=kernel_size,
+                strides=stride,
+                padding=padding,
+                ceil_mode=ceil_mode,
+                count_include_pad=count_include_pad,
+                layout="NCDHW",
+            )
+        )
+        # Remove added batch dim from result
+        if x_ndim == 4:
+            result = relax.op.squeeze(result, axis=[0])
+        return result
+
+    def _avg_pool3d(self, node: fx.Node) -> relax.Var:
+        args, kwargs = node.normalized_arguments(node)
+        x = self.env[args[0]]
+        kernel_size = args[1] if len(args) > 1 else kwargs["kernel_size"]
+        stride = args[2] if len(args) > 2 else kwargs.get("stride", None)
+        padding = args[3] if len(args) > 3 else kwargs.get("padding", 0)
+        ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False)
+        count_include_pad = args[5] if len(args) > 5 else kwargs.get("count_include_pad", True)
+
+        return self._avg_pool3d_impl(x, kernel_size, stride, padding, ceil_mode, count_include_pad)
+
     def _baddbmm(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         batch1 = self.env[node.args[1]]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index dbe37b886017..fc37fd3fb9a6 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -401,7 +401,9 @@ def create_convert_map(
             "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
             "adaptive_avg_pool3d.default": self._adaptive_avg_pool3d,
             "addmm.default": self._addmm,
+            "avg_pool1d.default": self._avg_pool1d,
             "avg_pool2d.default": self._avg_pool2d,
+            "avg_pool3d.default": self._avg_pool3d,
             "baddbmm.default": self._baddbmm,
             "bmm.default": self._binary_op(
                 partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 8081a98f59ca..0e8814dd974e 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -230,6 +230,15 @@ def _adaptive_avg_pool3d_module(self, node: fx.Node) -> relax.Var:
             result = relax.op.squeeze(result, axis=[0])
         return result
 
+    def _avg_pool1d_module(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        kernel_size = module.kernel_size
+        stride = module.stride
+        padding = module.padding
+        ceil_mode = module.ceil_mode
+        return self._avg_pool1d_impl(x, kernel_size, stride, padding, ceil_mode)
+
     def _avg_pool2d_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -239,6 +248,15 @@ def _avg_pool2d_module(self, node: fx.Node) -> relax.Var:
         ceil_mode = module.ceil_mode
         return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode)
 
+    def _avg_pool3d_module(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        kernel_size = module.kernel_size
+        stride = module.stride
+        padding = module.padding
+        ceil_mode = module.ceil_mode
+        return self._avg_pool3d_impl(x, kernel_size, stride, padding, ceil_mode)
+
     def _batch_norm_2d_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -711,7 +729,9 @@ def create_convert_map(
             nn.AdaptiveAvgPool1d: self._adaptive_avg_pool1d_module,
             nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module,
             nn.AdaptiveAvgPool3d: self._adaptive_avg_pool3d_module,
+            nn.AvgPool1d: self._avg_pool1d_module,
             nn.AvgPool2d: self._avg_pool2d_module,
+            nn.AvgPool3d: self._avg_pool3d_module,
             nn.BatchNorm2d: self._batch_norm_2d_module,
             nn.Conv1d: self._conv1d_module,
             nn.Conv2d: self._conv2d_module,
@@ -824,7 +844,9 @@ def create_convert_map(
             "adaptive_avg_pool2d": self._adaptive_avg_pool2d,
             "adaptive_avg_pool3d": self._adaptive_avg_pool3d,
             "addmm": self._addmm,
+            "avg_pool1d": self._avg_pool1d,
             "avg_pool2d": self._avg_pool2d,
+            "avg_pool3d": self._avg_pool3d,
             "baddbmm": self._baddbmm,
             "bmm": self._binary_op(
                 partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index e234e8ad7b18..b68d488e26df 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -840,7 +840,7 @@ def avg_pool1d(
     padding: Union[int, Tuple[int, ...]] = (0, 0),
     dilation: Union[int, Tuple[int, int]] = (1,),
     ceil_mode: bool = False,
-    count_include_pad: bool = False,
+    count_include_pad: bool = True,
     layout: str = "NCW",
     out_layout: Optional[str] = None,
 ) -> Expr:
@@ -1008,7 +1008,7 @@ def avg_pool3d(
     padding: Union[int, Tuple[int, ...]] = (0, 0, 0),
     dilation: Union[int, Tuple[int, int]] = (1, 1, 1),
     ceil_mode: bool = False,
-    count_include_pad: bool = False,
+    count_include_pad: bool = True,
     layout: str = "NCDHW",
     out_layout: Optional[str] = None,
 ) -> Expr:
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index ef198d2f83f3..b0aebff7049b 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1367,6 +1367,102 @@ def main(
     verify_model(Addmm2(), example_args, {}, expected2)
 
 
+def test_avg_pool1d():
+    class AvgPool1d1(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.AvgPool1d(kernel_size=1)
+
+        def forward(self, input):
+            return self.pool(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10), dtype="float32") = R.nn.avg_pool1d(
+                    input_1,
+                    pool_size=[1],
+                    strides=[1],
+                    dilation=[1],
+                    padding=[0, 0],
+                    ceil_mode=False,
+                    count_include_pad=True,
+                    layout="NCW",
+                    out_layout="NCW",
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class AvgPool1d2(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.AvgPool1d(kernel_size=3, stride=2, padding=1, ceil_mode=True)
+
+        def forward(self, input):
+            return self.pool(input)
+
+    class AvgPool1d3(Module):
+        def forward(self, input):
+            return torch.nn.functional.avg_pool1d(
+                input, kernel_size=3, stride=2, padding=1, ceil_mode=True
+            )
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(input_1: R.Tensor((1, 3, 10), dtype="float32")):
+            with R.dataflow():
+                lv = R.nn.avg_pool1d(
+                    input_1,
+                    pool_size=[3],
+                    strides=[2],
+                    dilation=[1],
+                    padding=[1, 1],
+                    ceil_mode=True,
+                    count_include_pad=True,
+                    layout="NCW",
+                    out_layout="NCW",
+                )
+                gv = (lv,)
+                R.output(gv)
+            return gv
+
+    class AvgPool1d4(Module):
+        def forward(self, input):
+            return torch.nn.functional.avg_pool1d(input, kernel_size=2, stride=2, padding=0)
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(input_1: R.Tensor((1, 3, 10), dtype="float32")):
+            with R.dataflow():
+                lv = R.nn.avg_pool1d(
+                    input_1,
+                    pool_size=[2],
+                    strides=[2],
+                    dilation=[1],
+                    padding=[0, 0],
+                    ceil_mode=False,
+                    count_include_pad=True,
+                    layout="NCW",
+                    out_layout="NCW",
+                )
+                gv = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, dtype=torch.float32),)
+    verify_model(AvgPool1d1(), example_args, {}, expected1)
+    verify_model(AvgPool1d2(), example_args, {}, expected2)
+    verify_model(AvgPool1d3(), example_args, {}, expected2)
+    verify_model(AvgPool1d4(), example_args, {}, expected3)
+
+
 def test_avg_pool2d():
     class AvgPool2d1(Module):
         def __init__(self):
@@ -1460,6 +1556,102 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
     verify_model(AvgPool2d4(), example_args, {}, expected3)
 
 
+def test_avg_pool3d():
+    class AvgPool3d1(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.AvgPool3d(kernel_size=1)
+
+        def forward(self, input):
+            return self.pool(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 8, 8, 8), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 8, 8, 8), dtype="float32") = R.nn.avg_pool3d(
+                    input_1,
+                    pool_size=[1, 1, 1],
+                    strides=[1, 1, 1],
+                    dilation=[1, 1, 1],
+                    padding=[0, 0, 0, 0, 0, 0],
+                    ceil_mode=False,
+                    count_include_pad=True,
+                    layout="NCDHW",
+                    out_layout="NCDHW",
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 8, 8, 8), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class AvgPool3d2(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.AvgPool3d(kernel_size=3, stride=2, padding=1, ceil_mode=True)
+
+        def forward(self, input):
+            return self.pool(input)
+
+    class AvgPool3d3(Module):
+        def forward(self, input):
+            return torch.nn.functional.avg_pool3d(
+                input, kernel_size=3, stride=2, padding=1, ceil_mode=True
+            )
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")):
+            with R.dataflow():
+                lv = R.nn.avg_pool3d(
+                    input_1,
+                    pool_size=[3, 3, 3],
+                    strides=[2, 2, 2],
+                    dilation=[1, 1, 1],
+                    padding=[1, 1, 1, 1, 1, 1],
+                    ceil_mode=True,
+                    count_include_pad=True,
+                    layout="NCDHW",
+                    out_layout="NCDHW",
+                )
+                gv = (lv,)
+                R.output(gv)
+            return gv
+
+    class AvgPool3d4(Module):
+        def forward(self, input):
+            return torch.nn.functional.avg_pool3d(input, kernel_size=[2, 1, 2], stride=[2, 1, 2])
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")):
+            with R.dataflow():
+                lv = R.nn.avg_pool3d(
+                    input_1,
+                    pool_size=[2, 1, 2],
+                    strides=[2, 1, 2],
+                    dilation=[1, 1, 1],
+                    padding=[0, 0, 0, 0, 0, 0],
+                    ceil_mode=False,
+                    count_include_pad=True,
+                    layout="NCDHW",
+                    out_layout="NCDHW",
+                )
+                gv = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),)
+    verify_model(AvgPool3d1(), example_args, {}, expected1)
+    verify_model(AvgPool3d2(), example_args, {}, expected2)
+    verify_model(AvgPool3d3(), example_args, {}, expected2)
+    verify_model(AvgPool3d4(), example_args, {}, expected3)
+
+
 def test_baddbmm():
     class BAddBMM1(Module):
         def __init__(self):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 681474244ae8..53efab4e80cc 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1287,6 +1287,103 @@ def main(
     verify_model(MaxPool3d3(), input_info, {}, expected3)
 
 
+def test_avgpool1d():
+    input_info = [([1, 3, 10], "float32")]
+
+    class AvgPool1d(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.AvgPool1d(kernel_size=1)
+
+        def forward(self, input):
+            return self.pool(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10), dtype="float32"):
+            with R.dataflow():
+                lv = R.nn.avg_pool1d(
+                    input_1,
+                    pool_size=[1],
+                    strides=[1],
+                    dilation=[1],
+                    padding=[0, 0],
+                    ceil_mode=False,
+                    count_include_pad=True,
+                    layout="NCW",
+                    out_layout="NCW",
+                )
+                gv = lv
+                R.output(gv)
+            return gv
+
+    class AvgPool1d2(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.AvgPool1d(kernel_size=4, stride=2, padding=2, ceil_mode=True)
+
+        def forward(self, input):
+            return self.pool(input)
+
+    class AvgPool1d3(Module):
+        def forward(self, input):
+            return torch.nn.functional.avg_pool1d(
+                input, kernel_size=4, stride=2, padding=2, ceil_mode=True
+            )
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(input_1: R.Tensor((1, 3, 10), dtype="float32")):
+            with R.dataflow():
+                lv = R.nn.avg_pool1d(
+                    input_1,
+                    pool_size=[4],
+                    strides=[2],
+                    dilation=[1],
+                    padding=[2, 2],
+                    ceil_mode=True,
+                    count_include_pad=True,
+                    layout="NCW",
+                    out_layout="NCW",
+                )
+                gv = lv
+                R.output(gv)
+            return gv
+
+    class AvgPool1d4(Module):
+        def forward(self, input):
+            return torch.nn.functional.avg_pool1d(input, kernel_size=2)
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(input_1: R.Tensor((1, 3, 10), dtype="float32")):
+            with R.dataflow():
+                lv = R.nn.avg_pool1d(
+                    input_1,
+                    pool_size=[2],
+                    strides=[2],
+                    dilation=[1],
+                    padding=[0, 0],
+                    ceil_mode=False,
+                    count_include_pad=True,
+                    layout="NCW",
+                    out_layout="NCW",
+                )
+                gv = lv
+                R.output(gv)
+            return gv
+
+    verify_model(AvgPool1d(), input_info, {}, expected1)
+    verify_model(AvgPool1d2(), input_info, {}, expected2)
+    verify_model(AvgPool1d3(), input_info, {}, expected2)
+    verify_model(AvgPool1d4(), input_info, {}, expected3)
+
+
 def test_avgpool2d():
     input_info = [([1, 3, 10, 10], "float32")]
 
@@ -1381,6 +1478,105 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
     verify_model(AvgPool2d4(), input_info, {}, expected3)
 
 
+def test_avgpool3d():
+    input_info = [([1, 3, 8, 8, 8], "float32")]
+
+    class AvgPool3d(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.AvgPool3d(kernel_size=[1, 1, 1])
+
+        def forward(self, input):
+            return self.pool(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")
+        ) -> R.Tensor((1, 3, 8, 8, 8), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 8, 8, 8), dtype="float32") = R.nn.avg_pool3d(
+                    input_1,
+                    pool_size=[1, 1, 1],
+                    strides=[1, 1, 1],
+                    dilation=[1, 1, 1],
+                    padding=[0, 0, 0, 0, 0, 0],
+                    ceil_mode=False,
+                    count_include_pad=True,
+                    layout="NCDHW",
+                    out_layout="NCDHW",
+                )
+                gv: R.Tensor((1, 3, 8, 8, 8), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class AvgPool3d2(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.AvgPool3d(
+                kernel_size=[3, 3, 3], stride=2, padding=1, ceil_mode=True
+            )
+
+        def forward(self, input):
+            return self.pool(input)
+
+    class AvgPool3d3(Module):
+        def forward(self, input):
+            return torch.nn.functional.avg_pool3d(
+                input, kernel_size=[3, 3, 3], stride=2, padding=1, ceil_mode=True
+            )
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")):
+            with R.dataflow():
+                lv = R.nn.avg_pool3d(
+                    input_1,
+                    pool_size=[3, 3, 3],
+                    strides=[2, 2, 2],
+                    dilation=[1, 1, 1],
+                    padding=[1, 1, 1, 1, 1, 1],
+                    ceil_mode=True,
+                    count_include_pad=True,
+                    layout="NCDHW",
+                    out_layout="NCDHW",
+                )
+                gv = lv
+                R.output(gv)
+            return gv
+
+    class AvgPool3d4(Module):
+        def forward(self, input):
+            return torch.nn.functional.avg_pool3d(input, kernel_size=[2, 1, 2], stride=[2, 1, 2])
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")):
+            with R.dataflow():
+                lv = R.nn.avg_pool3d(
+                    input_1,
+                    pool_size=[2, 1, 2],
+                    strides=[2, 1, 2],
+                    dilation=[1, 1, 1],
+                    padding=[0, 0, 0, 0, 0, 0],
+                    ceil_mode=False,
+                    count_include_pad=True,
+                    layout="NCDHW",
+                    out_layout="NCDHW",
+                )
+                gv = lv
+                R.output(gv)
+            return gv
+
+    verify_model(AvgPool3d(), input_info, {}, expected1)
+    verify_model(AvgPool3d2(), input_info, {}, expected2)
+    verify_model(AvgPool3d3(), input_info, {}, expected2)
+    verify_model(AvgPool3d4(), input_info, {}, expected3)
+
+
 def test_adaptive_avgpool1d():
     input_info = [([1, 3, 16], "float32")]
 
diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py
index 846338a93781..d4461a122de8 100644
--- a/tests/python/relax/test_op_nn_pooling.py
+++ b/tests/python/relax/test_op_nn_pooling.py
@@ -30,7 +30,9 @@ def test_op_correctness():
     assert relax.op.nn.max_pool1d(x1).op == Op.get("relax.nn.max_pool1d")
     assert relax.op.nn.max_pool2d(x).op == Op.get("relax.nn.max_pool2d")
     assert relax.op.nn.max_pool3d(x2).op == Op.get("relax.nn.max_pool3d")
+    assert relax.op.nn.avg_pool1d(x).op == Op.get("relax.nn.avg_pool1d")
     assert relax.op.nn.avg_pool2d(x).op == Op.get("relax.nn.avg_pool2d")
+    assert relax.op.nn.avg_pool3d(x).op == Op.get("relax.nn.avg_pool3d")
     assert relax.op.nn.adaptive_avg_pool1d(x).op == Op.get("relax.nn.adaptive_avg_pool1d")
     assert relax.op.nn.adaptive_avg_pool2d(x).op == Op.get("relax.nn.adaptive_avg_pool2d")
     assert relax.op.nn.adaptive_avg_pool3d(x).op == Op.get("relax.nn.adaptive_avg_pool3d")
@@ -713,6 +715,214 @@ def test_max_pool3d_infer_struct_info_wrong_input_type():
         bb.normalize(relax.op.nn.max_pool3d(x1))
 
 
+def test_avg_pool1d_infer_struct_info():
+    bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
+    x0 = relax.Var("x", R.Tensor((2, 3, 32), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 32, 3), "float32"))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x3 = relax.Var("x", R.Tensor("float32"))
+    x4 = relax.Var("x", R.Tensor(ndim=3))
+    x5 = relax.Var("x", R.Tensor())
+    x6 = relax.Var("x", R.Tensor((2, 4, 32, 16), "float32"))
+    x7 = relax.Var("x", R.Tensor((2, 3, 32), "float32", vdev0))
+
+    _check_inference(bb, relax.op.nn.avg_pool1d(x0), relax.TensorStructInfo((2, 3, 32), "float32"))
+    _check_inference(
+        bb, relax.op.nn.avg_pool1d(x7), relax.TensorStructInfo((2, 3, 32), "float32", vdev0)
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool1d(x0, pool_size=3),
+        relax.TensorStructInfo((2, 3, 30), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool1d(x0, padding=1),
+        relax.TensorStructInfo((2, 3, 34), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool1d(x0, padding=[1, 2]),
+        relax.TensorStructInfo((2, 3, 35), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool1d(x0, strides=2),
+        relax.TensorStructInfo((2, 3, 16), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool1d(x0, dilation=2),
+        relax.TensorStructInfo((2, 3, 32), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool1d(x1, layout="NWC"),
+        relax.TensorStructInfo((2, 32, 3), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool1d(x0, out_layout="NWC"),
+        relax.TensorStructInfo((2, 32, 3), "float32"),
+    )
+    _check_inference(
+        bb, relax.op.nn.avg_pool1d(x2), relax.TensorStructInfo(dtype="float32", ndim=3)
+    )
+    _check_inference(
+        bb, relax.op.nn.avg_pool1d(x3), relax.TensorStructInfo(dtype="float32", ndim=3)
+    )
+    _check_inference(bb, relax.op.nn.avg_pool1d(x4), relax.TensorStructInfo(dtype="", ndim=3))
+    _check_inference(bb, relax.op.nn.avg_pool1d(x5), relax.TensorStructInfo(dtype="", ndim=3))
+
+
+def test_avg_pool1d_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    c = tir.Var("c", "int64")
+    c16 = tir.Var("c16", "int64")
+    iw = tir.Var("iw", "int64")
+    x0 = relax.Var("x", R.Tensor((n, c, iw), "float32"))
+    x1 = relax.Var("x", R.Tensor((n, c, iw, c16), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool1d(x0, pool_size=3, strides=3, padding=2, dilation=2),
+        relax.TensorStructInfo(
+            (
+                n,
+                c,
+                tvm.tir.floordiv(iw - 1, 3) + 1,
+            ),
+            "float32",
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool1d(x1, layout="NCW16c", out_layout="NWC"),
+        relax.TensorStructInfo((n, iw, c * 16), "float32"),
+    )
+
+
+def test_avg_pool1d_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo(ndim=3))
+    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+    s2 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+
+    _check_inference(
+        bb, relax.op.nn.avg_pool1d(x0), relax.TensorStructInfo(dtype="float32", ndim=3)
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool1d(x1, layout="NCW16c"),
+        relax.TensorStructInfo(dtype="float32", ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool1d(x2),
+        relax.TensorStructInfo(dtype="float32", ndim=3),
+    )
+
+
+def test_avg_pool1d_infer_struct_info_ceil_mode():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 32), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool1d(x, pool_size=3, strides=2, ceil_mode=True),
+        relax.TensorStructInfo((2, 3, 16), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool1d(x, pool_size=5, strides=2, ceil_mode=True),
+        relax.TensorStructInfo((2, 3, 15), "float32"),
+    )
+
+
+def test_avg_pool1d_infer_struct_info_ceil_mode_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    c = tir.Var("c", "int64")
+    iw = tir.Var("iw", "int64")
+    x = relax.Var("x", R.Tensor((n, c, iw), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool1d(x, pool_size=3, strides=2, padding=1, dilation=2, ceil_mode=True),
+        relax.TensorStructInfo(
+            (n, c, tvm.tir.floordiv(iw, 2)),
+            "float32",
+        ),
+    )
+
+
+def test_avg_pool1d_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 32), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 32), "int8"))
+    x2 = relax.Var("x", R.Tensor((2, 3, 32), "int64"))
+    _check_inference(bb, relax.op.nn.avg_pool1d(x0), relax.TensorStructInfo((2, 3, 32), "float16"))
+    _check_inference(bb, relax.op.nn.avg_pool1d(x1), relax.TensorStructInfo((2, 3, 32), "int8"))
+    _check_inference(bb, relax.op.nn.avg_pool1d(x2), relax.TensorStructInfo((2, 3, 32), "int64"))
+
+
+def test_avg_pool1d_stride_padding_dilation_int64():
+    x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+    avg_pool1d = relax.op.nn.avg_pool1d(x, 3, strides=1, padding=1, dilation=1)
+
+    assert avg_pool1d.attrs.strides[0].dtype == "int64"
+    assert avg_pool1d.attrs.padding[0].dtype == "int64"
+    assert avg_pool1d.attrs.padding[1].dtype == "int64"
+    assert avg_pool1d.attrs.dilation[0].dtype == "int64"
+
+
+def test_avg_pool1d_wrong_pool_size_strides_padding_dilation_length():
+    x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+    with pytest.raises(TVMError):
+        relax.op.nn.avg_pool1d(x, pool_size=(1, 2))
+    with pytest.raises(TVMError):
+        relax.op.nn.avg_pool1d(x, strides=(1, 2))
+    with pytest.raises(TVMError):
+        relax.op.nn.avg_pool1d(x, padding=(1, 2, 3))
+    with pytest.raises(TVMError):
+        relax.op.nn.avg_pool1d(x, dilation=(1, 2))
+
+
+def test_avg_pool1d_infer_struct_info_wrong_layout_string():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.avg_pool1d(x, layout="OIW"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.avg_pool1d(x, out_layout="OWI"))
+
+
+def test_avg_pool1d_wrong_input_ndim():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=2))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.avg_pool1d(x0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.avg_pool1d(x1))
+
+
+def test_avg_pool1d_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28)))
+    x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28), "float32")))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.avg_pool1d(x0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.avg_pool1d(x1))
+
+
 def test_avg_pool2d_infer_struct_info():
     bb = relax.BlockBuilder()
     vdev0 = VDevice("llvm")
@@ -943,6 +1153,262 @@ def test_avg_pool2d_infer_struct_info_wrong_input_type():
         bb.normalize(relax.op.nn.avg_pool2d(x1))
 
 
+def test_avg_pool3d_infer_struct_info():
+    bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
+
+    x0 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 32, 32, 32, 3), "float32"))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=5))
+    x3 = relax.Var("x", R.Tensor("float32"))
+    x4 = relax.Var("x", R.Tensor(ndim=5))
+    x5 = relax.Var("x", R.Tensor())
+    x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 32, 16), "float32"))
+    x7 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float32", vdev0))
+
+    _check_inference(
+        bb, relax.op.nn.avg_pool3d(x0), relax.TensorStructInfo((2, 3, 32, 32, 32), "float32")
+    )
+    _check_inference(
+        bb, relax.op.nn.avg_pool3d(x7), relax.TensorStructInfo((2, 3, 32, 32, 32), "float32", vdev0)
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool3d(x0, pool_size=3),
+        relax.TensorStructInfo((2, 3, 30, 30, 30), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool3d(x0, pool_size=(5, 3, 3)),
+        relax.TensorStructInfo((2, 3, 28, 30, 30), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool3d(x0, padding=1),
+        relax.TensorStructInfo((2, 3, 34, 34, 34), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool3d(x0, padding=[1, 2, 3]),
+        relax.TensorStructInfo((2, 3, 34, 36, 38), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool3d(x0, strides=2),
+        relax.TensorStructInfo((2, 3, 16, 16, 16), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool3d(x0, dilation=2),
+        relax.TensorStructInfo((2, 3, 32, 32, 32), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool3d(x1, layout="NCDHW"),
+        relax.TensorStructInfo((2, 32, 32, 32, 3), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool3d(x0, out_layout="NCDHW"),
+        relax.TensorStructInfo((2, 3, 32, 32, 32), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool3d(x6, layout="NCDHW16c", out_layout="NDHWC16c"),
+        relax.TensorStructInfo((2, 32, 32, 32, 4, 16), "float32"),
+    )
+    _check_inference(
+        bb, relax.op.nn.avg_pool3d(x2), relax.TensorStructInfo(dtype="float32", ndim=5)
+    )
+    _check_inference(
+        bb, relax.op.nn.avg_pool3d(x3), relax.TensorStructInfo(dtype="float32", ndim=5)
+    )
+    _check_inference(bb, relax.op.nn.avg_pool3d(x4), relax.TensorStructInfo(dtype="", ndim=5))
+    _check_inference(bb, relax.op.nn.avg_pool3d(x5), relax.TensorStructInfo(dtype="", ndim=5))
+
+
+def test_avg_pool3d_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    c = tir.Var("c", "int64")
+    c16 = tir.Var("c16", "int64")
+    id_ = tir.Var("id", "int64")
+    ih = tir.Var("ih", "int64")
+    iw = tir.Var("iw", "int64")
+    x0 = relax.Var("x", R.Tensor((n, c, id_, ih, iw), "float32"))
+    x1 = relax.Var("x", R.Tensor((n, c, id_, ih, iw, c16), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool3d(
+            x0, pool_size=(3, 3, 3), strides=(3, 3, 3), padding=(2, 2, 2), dilation=(2, 2, 2)
+        ),
+        relax.TensorStructInfo(
+            (
+                n,
+                c,
+                tvm.tir.floordiv(id_ - 1, 3) + 1,
+                tvm.tir.floordiv(ih - 1, 3) + 1,
+                tvm.tir.floordiv(iw - 1, 3) + 1,
+            ),
+            "float32",
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool3d(x1, layout="NCDHW16c", out_layout="NDHWC"),
+        relax.TensorStructInfo((n, id_, ih, iw, c * 16), "float32"),
+    )
+
+
+def test_avg_pool3d_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo(ndim=5))
+    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6))
+    s2 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+
+    _check_inference(
+        bb, relax.op.nn.avg_pool3d(x0), relax.TensorStructInfo(dtype="float32", ndim=5)
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool3d(x1, layout="NCDHW16c"),
+        relax.TensorStructInfo(dtype="float32", ndim=6),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool3d(x2),
+        relax.TensorStructInfo(dtype="float32", ndim=5),
+    )
+
+
+def test_avg_pool3d_infer_struct_info_ceil_mode():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool3d(x, pool_size=3, strides=2, ceil_mode=True),
+        relax.TensorStructInfo((2, 3, 16, 16, 16), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool3d(x, pool_size=(5, 3, 3), strides=2, ceil_mode=True),
+        relax.TensorStructInfo((2, 3, 15, 16, 16), "float32"),
+    )
+
+
+def test_avg_pool3d_infer_struct_info_ceil_mode_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    c = tir.Var("c", "int64")
+    id_ = tir.Var("id", "int64")
+    ih = tir.Var("ih", "int64")
+    iw = tir.Var("iw", "int64")
+    x = relax.Var("x", R.Tensor((n, c, id_, ih, iw), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool3d(
+            x,
+            pool_size=(3, 3, 3),
+            strides=(2, 2, 2),
+            padding=(1, 1, 1),
+            dilation=(2, 2, 2),
+            ceil_mode=True,
+        ),
+        relax.TensorStructInfo(
+            (
+                n,
+                c,
+                tvm.tir.floordiv(id_, 2),
+                tvm.tir.floordiv(ih, 2),
+                tvm.tir.floordiv(iw, 2),
+            ),
+            "float32",
+        ),
+    )
+
+
+def test_avg_pool3d_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "int8"))
+    x2 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "int64"))
+
+    _check_inference(
+        bb, relax.op.nn.avg_pool3d(x0), relax.TensorStructInfo((2, 3, 32, 32, 32), "float16")
+    )
+    _check_inference(
+        bb, relax.op.nn.avg_pool3d(x1), relax.TensorStructInfo((2, 3, 32, 32, 32), "int8")
+    )
+    _check_inference(
+        bb, relax.op.nn.avg_pool3d(x2), relax.TensorStructInfo((2, 3, 32, 32, 32), "int64")
+    )
+
+
+def test_avg_pool3d_stride_padding_dilation_int64():
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
+    avg_pool3d = relax.op.nn.avg_pool3d(
+        x, (3, 3, 3), strides=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)
+    )
+
+    assert avg_pool3d.attrs.strides[0].dtype == "int64"
+    assert avg_pool3d.attrs.strides[1].dtype == "int64"
+    assert avg_pool3d.attrs.strides[2].dtype == "int64"
+    assert avg_pool3d.attrs.padding[0].dtype == "int64"
+    assert avg_pool3d.attrs.padding[1].dtype == "int64"
+    assert avg_pool3d.attrs.padding[2].dtype == "int64"
+    assert avg_pool3d.attrs.dilation[0].dtype == "int64"
+    assert avg_pool3d.attrs.dilation[1].dtype == "int64"
+    assert avg_pool3d.attrs.dilation[2].dtype == "int64"
+
+
+def test_avg_pool3d_wrong_pool_size_strides_padding_dilation_length():
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
+    with pytest.raises(TVMError):
+        relax.op.nn.avg_pool3d(x, pool_size=(1, 2, 3, 4))
+    with pytest.raises(TVMError):
+        relax.op.nn.avg_pool3d(x, strides=(1, 2, 3, 4))
+    with pytest.raises(TVMError):
+        relax.op.nn.avg_pool3d(x, padding=(1, 2, 3, 4))
+    with pytest.raises(TVMError):
+        relax.op.nn.avg_pool3d(x, dilation=(1, 2, 3, 4))
+
+
+def test_avg_pool3d_infer_struct_info_wrong_layout_string():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.avg_pool3d(x, layout="OIHW"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.avg_pool3d(x, out_layout="OHWI"))
+
+
+def test_avg_pool3d_wrong_input_ndim():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 28, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=4))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.avg_pool3d(x0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.avg_pool3d(x1))
+
+
+def test_avg_pool3d_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28, 28)))
+    x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28, 28), "float32")))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.avg_pool3d(x0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.avg_pool3d(x1))
+
+
 def test_adaptive_avg_pool1d_infer_struct_info():
     bb = relax.BlockBuilder()
     vdev0 = VDevice("llvm")
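
For reviewers who want to try the new conversions end to end, here is a minimal sketch (not part of the patch). It assumes a TVM build with this patch applied and a PyTorch 2.x install where torch.export is available; from_exported_program is the public entry point that the exported-program tests above exercise through verify_model.

# Minimal end-to-end sketch: convert an nn.AvgPool3d module to Relax.
import torch
from torch.export import export

from tvm.relax.frontend.torch import from_exported_program


class Pool3d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # With this patch, nn.AvgPool3d maps to R.nn.avg_pool3d
        # (via the new "avg_pool3d.default" entry in the convert map).
        self.pool = torch.nn.AvgPool3d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.pool(x)


example_args = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),)
exported_program = export(Pool3d(), args=example_args)
mod = from_exported_program(exported_program)
# The printed dataflow block should contain an R.nn.avg_pool3d call with
# count_include_pad=True, matching PyTorch's default for this module.
mod.show()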