From ccdf4a1b619fa868d9efe7c45302b03fca981b23 Mon Sep 17 00:00:00 2001
From: Nachiket Mokashi
Date: Wed, 27 Nov 2024 18:51:14 +0000
Subject: [PATCH] Added and addressed comments

---
 test/test_einsum_autocast.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/test/test_einsum_autocast.py b/test/test_einsum_autocast.py
index 72520dc62aa..0d474bca485 100644
--- a/test/test_einsum_autocast.py
+++ b/test/test_einsum_autocast.py
@@ -1,4 +1,3 @@
-import os
 import re
 import torch
 import torch_xla
@@ -7,16 +6,17 @@
 device = xm.xla_device()
 
 
-class TestAutocastXla(unittest.TestCase):
+class TestEinsumAutocastXla(unittest.TestCase):
 
   def test_einsum(self):
     data = torch.randn(16, 10).to(torch.bfloat16).to(device)
     target = torch.randn(5, 10).to(torch.bfloat16).to(device)
+
     with torch.autocast("xla"):
       product = torch.einsum("...n,mn->...m", data, target)
+      # test the HLO to see if autocast works for einsum op, which would show up as a dot op in the HLO
       hlo = torch_xla._XLAC._get_xla_tensors_hlo([product])
-
+      # Verify that dot op has bf16 output and not f32, i.e. the computation is performed in bfloat16 precision by autocast
       self.assertTrue(re.search(r".*dot.*bf16", hlo) is not None)
-
       self.assertTrue(re.search(r".*dot.*f32", hlo) is None)
 
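
For reference, a minimal standalone sketch of the check this patch adds, assuming an XLA
device is available. It reuses only calls already present in the test (torch.autocast("xla"),
torch.einsum, torch_xla._XLAC._get_xla_tensors_hlo, xm.xla_device()); variable names here
are illustrative, not part of the patch.

    import re
    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    a = torch.randn(16, 10).to(torch.bfloat16).to(device)
    b = torch.randn(5, 10).to(torch.bfloat16).to(device)

    with torch.autocast("xla"):
      out = torch.einsum("...n,mn->...m", a, b)
      # Dump the HLO for the lazy tensor; under autocast the einsum lowers to a
      # dot op whose output dtype should be bf16 rather than f32.
      hlo = torch_xla._XLAC._get_xla_tensors_hlo([out])
      print(re.search(r".*dot.*bf16", hlo) is not None)  # expected: True
      print(re.search(r".*dot.*f32", hlo) is None)       # expected: True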