diff --git a/test/test_einsum_autocast.py b/test/test_einsum_autocast.py index 72520dc62aa..0d474bca485 100644 --- a/test/test_einsum_autocast.py +++ b/test/test_einsum_autocast.py @@ -1,4 +1,3 @@ -import os import re import torch import torch_xla @@ -7,16 +6,17 @@ device = xm.xla_device() -class TestAutocastXla(unittest.TestCase): +class TestEinsumAutocastXla(unittest.TestCase): def test_einsum(self): data = torch.randn(16, 10).to(torch.bfloat16).to(device) target = torch.randn(5, 10).to(torch.bfloat16).to(device) + with torch.autocast("xla"): product = torch.einsum("...n,mn->...m", data, target) + # test the HLO to see if autocast works for einsum op, which would show up as a dot op in the HLO hlo = torch_xla._XLAC._get_xla_tensors_hlo([product]) - + # Verify that dot op has bf16 output and not f32, i.e. the computation is performed in bfloat16 precision by autocast self.assertTrue(re.search(r".*dot.*bf16", hlo) is not None) - self.assertTrue(re.search(r".*dot.*f32", hlo) is None)