From cf33518b433da4809543d3b0f6281d96c24f01ee Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 25 Apr 2025 17:34:35 +0900 Subject: [PATCH] support bfloat16 dtype in pytorch frontend --- .../torch/base_fx_graph_translator.py | 2 ++ .../test_frontend_from_exported_program.py | 27 +++++++++++++++++++ tests/python/relax/test_frontend_from_fx.py | 22 +++++++++++++++ 3 files changed, 51 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 3e81ff1f0bfe..1014b39921b9 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -58,6 +58,8 @@ def _convert_data_type(input_type: Union[str, torch.dtype], env: Optional[Dict] return "float32" elif input_type in ["float16", "torch.float16", torch.float16]: return "float16" + elif input_type in ["bfloat16", "torch.bfloat16", torch.bfloat16]: + return "bfloat16" elif input_type in ["int64", "torch.int64", torch.int64]: return "int64" elif input_type in ["int32", "torch.int32", torch.int32]: diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index e3b6f4ad9c17..c0439dd4cd5a 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4838,5 +4838,32 @@ def main( verify_model(Linspace(), example_args, {}, Expected) +def test_bfloat16(): + # TODO(mshr-h): Add tests for all the dtypes supported in fx frontend + example_args = ( + torch.randn(10, 10, dtype=torch.bfloat16), + torch.randn(10, 10, dtype=torch.bfloat16), + ) + + class BFloat16Model(Module): + def forward(self, lhs: torch.Tensor, rhs: torch.Tensor): + return torch.ops.aten.add(lhs, rhs) + + @tvm.script.ir_module + class expected: + @R.function + def main( + lhs: R.Tensor((10, 10), dtype="bfloat16"), + rhs: R.Tensor((10, 10), dtype="bfloat16"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="bfloat16")): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="bfloat16") = relax.op.add(lhs, rhs) + gv: R.Tuple(R.Tensor((10, 10), dtype="bfloat16")) = (lv,) + R.output(gv) + return gv + + verify_model(BFloat16Model(), example_args, {}, expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 4003202d4f55..af4bef956461 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -5220,5 +5220,27 @@ def main( verify_model(Norm(p, dim=dim, keepdim=keepdim), input_info, {}, expected) +def test_bfloat16(): + # TODO(mshr-h): Add tests for all the dtypes supported in EP frontend + class BFloat16Model(Module): + def forward(self, lhs: torch.Tensor, rhs: torch.Tensor): + return torch.ops.aten.add(lhs, rhs) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + lhs: R.Tensor((10, 10), dtype="bfloat16"), + rhs: R.Tensor((10, 10), dtype="bfloat16"), + ) -> R.Tensor((10, 10), dtype="bfloat16"): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="bfloat16") = relax.op.add(lhs, rhs) + gv: R.Tensor((10, 10), dtype="bfloat16") = lv + R.output(gv) + return gv + + verify_model(BFloat16Model(), [([10, 10], "bfloat16"), ([10, 10], "bfloat16")], {}, Expected) + + if __name__ == "__main__": tvm.testing.main()