diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index b07070ddc99f..16ae43bcd586 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5051,31 +5051,41 @@ def main(
     verify_model(Linspace(), example_args, {}, Expected)
 
 
-def test_bfloat16():
-    # TODO(mshr-h): Add tests for all the dtypes supported in fx frontend
+@pytest.mark.parametrize(
+    "torch_dtype, relax_dtype",
+    [
+        (torch.float32, "float32"),
+        (torch.float16, "float16"),
+        (torch.bfloat16, "bfloat16"),
+        (torch.int64, "int64"),
+        (torch.int32, "int32"),
+        (torch.bool, "bool"),
+    ],
+)
+def test_dtypes(torch_dtype, relax_dtype):
     example_args = (
-        torch.randn(10, 10, dtype=torch.bfloat16),
-        torch.randn(10, 10, dtype=torch.bfloat16),
+        torch.randint(0, 10, (10, 10)).to(torch_dtype),
+        torch.randint(0, 10, (10, 10)).to(torch_dtype),
     )
 
-    class BFloat16Model(Module):
+    class Model(Module):
         def forward(self, lhs: torch.Tensor, rhs: torch.Tensor):
             return torch.ops.aten.add(lhs, rhs)
 
     @tvm.script.ir_module
-    class expected:
+    class Expected:
         @R.function
         def main(
-            lhs: R.Tensor((10, 10), dtype="bfloat16"),
-            rhs: R.Tensor((10, 10), dtype="bfloat16"),
-        ) -> R.Tuple(R.Tensor((10, 10), dtype="bfloat16")):
+            lhs: R.Tensor((10, 10), dtype=relax_dtype),
+            rhs: R.Tensor((10, 10), dtype=relax_dtype),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype=relax_dtype)):
             with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="bfloat16") = relax.op.add(lhs, rhs)
-                gv: R.Tuple(R.Tensor((10, 10), dtype="bfloat16")) = (lv,)
+                lv: R.Tensor((10, 10), dtype=relax_dtype) = relax.op.add(lhs, rhs)
+                gv: R.Tuple(R.Tensor((10, 10), dtype=relax_dtype)) = (lv,)
                 R.output(gv)
             return gv
 
-    verify_model(BFloat16Model(), example_args, {}, expected)
+    verify_model(Model(), example_args, {}, Expected)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 2bb2a8444199..1974d50974cd 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -5416,9 +5416,19 @@ def main(
     verify_model(Norm(p, dim=dim, keepdim=keepdim), input_info, {}, expected)
 
 
-def test_bfloat16():
-    # TODO(mshr-h): Add tests for all the dtypes supported in EP frontend
-    class BFloat16Model(Module):
+@pytest.mark.parametrize(
+    "torch_dtype, relax_dtype",
+    [
+        (torch.float32, "float32"),
+        (torch.float16, "float16"),
+        (torch.bfloat16, "bfloat16"),
+        (torch.int64, "int64"),
+        (torch.int32, "int32"),
+        (torch.bool, "bool"),
+    ],
+)
+def test_dtypes(torch_dtype, relax_dtype):
+    class Model(Module):
         def forward(self, lhs: torch.Tensor, rhs: torch.Tensor):
             return torch.ops.aten.add(lhs, rhs)
 
@@ -5426,16 +5436,16 @@ def forward(self, lhs: torch.Tensor, rhs: torch.Tensor):
     class Expected:
         @R.function
         def main(
-            lhs: R.Tensor((10, 10), dtype="bfloat16"),
-            rhs: R.Tensor((10, 10), dtype="bfloat16"),
-        ) -> R.Tensor((10, 10), dtype="bfloat16"):
+            lhs: R.Tensor((10, 10), dtype=relax_dtype),
+            rhs: R.Tensor((10, 10), dtype=relax_dtype),
+        ) -> R.Tensor((10, 10), dtype=relax_dtype):
             with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="bfloat16") = relax.op.add(lhs, rhs)
-                gv: R.Tensor((10, 10), dtype="bfloat16") = lv
+                lv: R.Tensor((10, 10), dtype=relax_dtype) = relax.op.add(lhs, rhs)
+                gv: R.Tensor((10, 10), dtype=relax_dtype) = lv
                 R.output(gv)
             return gv
 
-    verify_model(BFloat16Model(), [([10, 10], "bfloat16"), ([10, 10], "bfloat16")], {}, Expected)
+    verify_model(Model(), [([10, 10], torch_dtype), ([10, 10], torch_dtype)], {}, Expected)
 
 
 def test_eye():