diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 48869767ad66..12c85b522dec 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -449,12 +449,53 @@ def _isin(self, node: fx.Node) -> relax.Var: ########## Neural Network ########## + def _adaptive_avg_pool1d(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + output_size = node.args[1] if len(node.args) > 1 else node.kwargs["output_size"] + # Expand to 3D by adding batch dim if input is 2D + x_ndim = x.struct_info.ndim + if x_ndim == 2: + x = relax.op.expand_dims(x, axis=0) + + result = self.block_builder.emit( + relax.op.nn.adaptive_avg_pool1d(x, output_size, layout="NCW") + ) + # Remove added batch dim from result + if x_ndim == 2: + result = relax.op.squeeze(result, axis=[0]) + return result + def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] output_size = node.args[1] - return self.block_builder.emit( + # Expand to 4D by adding batch dim if input is 3D + x_ndim = x.struct_info.ndim + if x_ndim == 3: + x = relax.op.expand_dims(x, axis=0) + + result = self.block_builder.emit( relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") ) + # Remove added batch dim from result + if x_ndim == 3: + result = relax.op.squeeze(result, axis=[0]) + return result + + def _adaptive_avg_pool3d(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + output_size = node.args[1] + # Expand to 5D by adding batch dim if input is 4D + x_ndim = x.struct_info.ndim + if x_ndim == 4: + x = relax.op.expand_dims(x, axis=0) + + result = self.block_builder.emit( + relax.op.nn.adaptive_avg_pool3d(x, output_size, layout="NCDHW") + ) + # Remove added batch dim from result + if x_ndim == 4: + result = relax.op.squeeze(result, axis=[0]) + return result def _addmm(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index df532fd1ea04..0b218cbdd1b2 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -387,7 +387,9 @@ def create_convert_map( "_native_batch_norm_legit_functional.default": self._batch_norm_legit_functional, "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training, "batch_norm.default": self._batch_norm_legit_no_training, + "adaptive_avg_pool1d.default": self._adaptive_avg_pool1d, "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, + "adaptive_avg_pool3d.default": self._adaptive_avg_pool3d, "addmm.default": self._addmm, "avg_pool2d.default": self._avg_pool2d, "baddbmm.default": self._baddbmm, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 5f65f86a4303..bc17706e130f 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -182,13 +182,53 @@ def call_binary_op(op, lhs, rhs): ########## Neural Network ########## + def _adaptive_avg_pool1d_module(self, node: fx.Node) -> relax.Var: + module = self.named_modules[node.target] + x = self.env[node.args[0]] + output_size = module.output_size + # Expand to 3D by adding batch dim if input is 2D + x_ndim = x.struct_info.ndim + if x_ndim == 2: + x = relax.op.expand_dims(x, axis=0) + result = self.block_builder.emit( + relax.op.nn.adaptive_avg_pool1d(x, output_size, layout="NCW") # (N, C, L) + ) + # Remove added batch dim from result + if x_ndim == 2: + result = relax.op.squeeze(result, axis=[0]) + return result + def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var: module = self.named_modules[node.target] x = self.env[node.args[0]] output_size = module.output_size - return self.block_builder.emit( + # Expand to 4D by adding batch dim if input is 3D + x_ndim = x.struct_info.ndim + if x_ndim == 3: + x = relax.op.expand_dims(x, axis=0) + result = self.block_builder.emit( relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") ) + # Remove added batch dim from result + if x_ndim == 3: + result = relax.op.squeeze(result, axis=[0]) + return result + + def _adaptive_avg_pool3d_module(self, node: fx.Node) -> relax.Var: + module = self.named_modules[node.target] + x = self.env[node.args[0]] + output_size = module.output_size + # Expand to 5D by adding batch dim if input is 4D + x_ndim = x.struct_info.ndim + if x_ndim == 4: + x = relax.op.expand_dims(x, axis=0) + result = self.block_builder.emit( + relax.op.nn.adaptive_avg_pool3d(x, output_size, layout="NCDHW") # (N, C, D, H, W) + ) + # Remove added batch dim from result + if x_ndim == 4: + result = relax.op.squeeze(result, axis=[0]) + return result def _avg_pool2d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] @@ -649,7 +689,9 @@ def create_convert_map( nn.Softplus: self._softplus_module, nn.Tanh: self._unary_op(relax.op.tanh), # neural network + nn.AdaptiveAvgPool1d: self._adaptive_avg_pool1d_module, nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module, + nn.AdaptiveAvgPool3d: self._adaptive_avg_pool3d_module, nn.AvgPool2d: self._avg_pool2d_module, nn.BatchNorm2d: self._batch_norm_2d_module, nn.Conv1d: self._conv1d_module, @@ -755,7 +797,9 @@ def create_convert_map( "truediv": self._binary_op(relax.op.divide, operator.truediv), "xor": self._binary_op(relax.op.bitwise_xor, operator.xor), # neural network + "adaptive_avg_pool1d": self._adaptive_avg_pool1d, "adaptive_avg_pool2d": self._adaptive_avg_pool2d, + "adaptive_avg_pool3d": self._adaptive_avg_pool3d, "addmm": self._addmm, "avg_pool2d": self._avg_pool2d, "baddbmm": self._baddbmm, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index f0bb33964ef2..f0597dae93c9 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1145,6 +1145,38 @@ def main( verify_model(model, example_args, binding, expected1) +def test_adaptive_avgpool1d(): + class AdaptiveAvgPool1d0(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool1d(output_size=5) + + def forward(self, input): + return self.pool(input) + + class AdaptiveAvgPool1d1(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.adaptive_avg_pool1d(input, output_size=5) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 5), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 5), dtype="float32") = R.nn.adaptive_avg_pool1d( + input_1, output_size=[5], layout="NCW" + ) + gv: R.Tuple(R.Tensor((1, 3, 5), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, dtype=torch.float32),) + verify_model(AdaptiveAvgPool1d0(), example_args, {}, expected1) + verify_model(AdaptiveAvgPool1d1(), example_args, {}, expected1) + + def test_adaptive_avgpool2d(): class AdaptiveAvgPool2d0(Module): def __init__(self): @@ -1178,6 +1210,38 @@ def main( verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1) +def test_adaptive_avgpool3d(): + class AdaptiveAvgPool3d0(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool3d([4, 4, 4]) + + def forward(self, input): + return self.pool(input) + + class AdaptiveAvgPool3d1(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.adaptive_avg_pool3d(input, [4, 4, 4]) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = R.nn.adaptive_avg_pool3d( + input_1, output_size=[4, 4, 4], layout="NCDHW", out_layout="NCDHW" + ) + gv: R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),) + verify_model(AdaptiveAvgPool3d0(), example_args, {}, expected1) + verify_model(AdaptiveAvgPool3d1(), example_args, {}, expected1) + + def test_addmm(): class Addmm1(Module): def __init__(self): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 490a2309aa37..2a95e6f1ea57 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -1181,6 +1181,39 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): verify_model(AvgPool2d4(), input_info, {}, expected3) +def test_adaptive_avgpool1d(): + input_info = [([1, 3, 16], "float32")] + + class AdaptiveAvgPool1d0(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool1d(8) + + def forward(self, input): + return self.pool(input) + + class AdaptiveAvgPool1d1(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.adaptive_avg_pool1d(input, 8) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 16), dtype="float32") + ) -> R.Tensor((1, 3, 8), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 8), dtype="float32") = R.nn.adaptive_avg_pool1d( + input_1, output_size=[8], layout="NCW", out_layout="NCW" + ) + gv: R.Tensor((1, 3, 8), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(AdaptiveAvgPool1d0(), input_info, {}, expected1) + verify_model(AdaptiveAvgPool1d1(), input_info, {}, expected1) + + def test_adaptive_avgpool2d(): input_info = [([1, 3, 10, 10], "float32")] @@ -1215,6 +1248,39 @@ def main( verify_model(AdaptiveAvgPool2d1(), input_info, {}, expected1) +def test_adaptive_avgpool3d(): + input_info = [([1, 3, 16, 16, 16], "float32")] + + class AdaptiveAvgPool3d0(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool3d((8, 8, 8)) + + def forward(self, input): + return self.pool(input) + + class AdaptiveAvgPool3d1(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.adaptive_avg_pool3d(input, (8, 8, 8)) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 16, 16, 16), dtype="float32") + ) -> R.Tensor((1, 3, 8, 8, 8), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 8, 8, 8), dtype="float32") = R.nn.adaptive_avg_pool3d( + input_1, output_size=[8, 8, 8], layout="NCDHW", out_layout="NCDHW" + ) + gv: R.Tensor((1, 3, 8, 8, 8), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(AdaptiveAvgPool3d0(), input_info, {}, expected1) + verify_model(AdaptiveAvgPool3d1(), input_info, {}, expected1) + + def test_flatten(): input_info = [([1, 3, 10, 10], "float32")] diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py index 2533a2fcadcb..f55ce1ec60c2 100644 --- a/tests/python/relax/test_op_nn_pooling.py +++ b/tests/python/relax/test_op_nn_pooling.py @@ -27,7 +27,9 @@ def test_op_correctness(): x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) assert relax.op.nn.max_pool2d(x).op == Op.get("relax.nn.max_pool2d") assert relax.op.nn.avg_pool2d(x).op == Op.get("relax.nn.avg_pool2d") + assert relax.op.nn.adaptive_avg_pool1d(x).op == Op.get("relax.nn.adaptive_avg_pool1d") assert relax.op.nn.adaptive_avg_pool2d(x).op == Op.get("relax.nn.adaptive_avg_pool2d") + assert relax.op.nn.adaptive_avg_pool3d(x).op == Op.get("relax.nn.adaptive_avg_pool3d") def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): @@ -495,6 +497,154 @@ def test_avg_pool2d_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.avg_pool2d(x1)) +def test_adaptive_avg_pool1d_infer_struct_info(): + bb = relax.BlockBuilder() + vdev0 = VDevice("llvm") + + x0 = relax.Var("x", R.Tensor((2, 3, 32), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor(ndim=3)) + x4 = relax.Var("x", R.Tensor()) + + x5 = relax.Var("x", R.Tensor((2, 3, 32), "float32", vdev0)) + + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x0), + relax.TensorStructInfo((2, 3, 32), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x5), + relax.TensorStructInfo((2, 3, 32), "float32", vdev0), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x0, output_size=16), + relax.TensorStructInfo((2, 3, 16), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x1), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x2), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x3), + relax.TensorStructInfo(dtype="", ndim=3), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x4), + relax.TensorStructInfo(dtype="", ndim=3), + ) + + +def test_adaptive_avg_pool1d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + l = tir.Var("l", "int64") + + x0 = relax.Var("x", R.Tensor((n, c, l), "float32")) + + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x0), + relax.TensorStructInfo((n, c, l), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x0, output_size=64), + relax.TensorStructInfo((n, c, 64), "float32"), + ) + + +def test_adaptive_avg_pool1d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x0), + relax.TensorStructInfo(s0, "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x0, output_size=20), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x1), + relax.TensorStructInfo(s1, dtype="float32"), + ) + + +def test_adaptive_avg_pool1d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 64), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 64), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 64), "int64")) + + _check_inference( + bb, relax.op.nn.adaptive_avg_pool1d(x0), relax.TensorStructInfo((2, 3, 64), "float16") + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool1d(x1), relax.TensorStructInfo((2, 3, 64), "int8") + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool1d(x2), relax.TensorStructInfo((2, 3, 64), "int64") + ) + + +def test_adaptive_avg_pool1d_wrong_output_size_ndim(): + x = relax.Var("x", R.Tensor((2, 3, 64), "float32")) + with pytest.raises(TVMError): + relax.op.nn.adaptive_avg_pool1d(x, output_size=(32, 32)) + + +def test_adaptive_avg_pool1d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 64), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool1d(x, layout="OIW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool1d(x, out_layout="OWI")) + + +def test_adaptive_avg_pool1d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool1d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool1d(x1)) + + +def test_adaptive_avg_pool1d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 64))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 64), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool1d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool1d(x1)) + + def test_adaptive_avg_pool2d_infer_struct_info(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") @@ -668,5 +818,197 @@ def test_adaptive_avg_pool2d_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.adaptive_avg_pool2d(x1)) +def test_adaptive_avg_pool3d_infer_struct_info(): + bb = relax.BlockBuilder() + vdev0 = VDevice("llvm") + + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 32, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=5)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor(ndim=5)) + x5 = relax.Var("x", R.Tensor()) + x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 32, 16), "float32")) + x7 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float32", vdev0)) + + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0), + relax.TensorStructInfo((2, 3, 32, 32, 32), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x7), + relax.TensorStructInfo((2, 3, 32, 32, 32), "float32", vdev0), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0, output_size=30), + relax.TensorStructInfo((2, 3, 30, 30, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0, output_size=(28, 30, 32)), + relax.TensorStructInfo((2, 3, 28, 30, 32), "float32"), + ) + + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x1, layout="NCDHW"), + relax.TensorStructInfo((2, 32, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0, out_layout="NCDHW"), + relax.TensorStructInfo((2, 3, 32, 32, 32), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x6, layout="NCDHW16c", out_layout="NDHWC16c"), + relax.TensorStructInfo((2, 32, 32, 32, 4, 16), "float32"), + ) + + _check_inference( + bb, relax.op.nn.adaptive_avg_pool3d(x2), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool3d(x3), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool3d(x4), relax.TensorStructInfo(dtype="", ndim=5) + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool3d(x5), relax.TensorStructInfo(dtype="", ndim=5) + ) + + +def test_adaptive_avg_pool3d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + c16 = tir.Var("c16", "int64") + d = tir.Var("d", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + + x0 = relax.Var("x", R.Tensor((n, c, d, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, d, ih, iw, c16), "float32")) + + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0), + relax.TensorStructInfo((n, c, d, ih, iw), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0, output_size=256), + relax.TensorStructInfo((n, c, 256, 256, 256), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0, output_size=(256, 128, 64)), + relax.TensorStructInfo((n, c, 256, 128, 64), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x1, layout="NCDHW16c", out_layout="NDHWC"), + relax.TensorStructInfo((n, d, ih, iw, c * 16), "float32"), + ) + + +def test_adaptive_avg_pool3d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference(bb, relax.op.nn.adaptive_avg_pool3d(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0, output_size=32), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x1, layout="NCDHW16c"), + relax.TensorStructInfo(s1, "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0, out_layout="NCDHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=6), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x2, out_layout="NCDHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=6), + ) + + +def test_adaptive_avg_pool3d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "int64")) + + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0), + relax.TensorStructInfo((2, 3, 32, 32, 32), "float16"), + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool3d(x1), relax.TensorStructInfo((2, 3, 32, 32, 32), "int8") + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool3d(x2), relax.TensorStructInfo((2, 3, 32, 32, 32), "int64") + ) + + +def test_adaptive_avg_pool3d_wrong_output_size_ndim(): + x = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float32")) + + with pytest.raises(TVMError): + relax.op.nn.adaptive_avg_pool3d(x, (32, 32, 32, 32)) + + +def test_adaptive_avg_pool3d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool3d(x, layout="OIDHW")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool3d(x, out_layout="OHIDW")) + + +def test_adaptive_avg_pool3d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 28, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool3d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool3d(x1)) + + +def test_adaptive_avg_pool3d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28, 28))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28, 28), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool3d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool3d(x1)) + + if __name__ == "__main__": tvm.testing.main()