
Commit

Added and addressed comments
aws-nm9 committed Nov 27, 2024
1 parent 983b66d · commit ccdf4a1
Showing 1 changed file with 4 additions and 4 deletions.
test/test_einsum_autocast.py (8 changes: 4 additions & 4 deletions)
@@ -1,4 +1,3 @@
-import os
 import re
 import torch
 import torch_xla
@@ -7,16 +6,17 @@

 device = xm.xla_device()

-class TestAutocastXla(unittest.TestCase):
+class TestEinsumAutocastXla(unittest.TestCase):
   def test_einsum(self):
     data = torch.randn(16, 10).to(torch.bfloat16).to(device)
     target = torch.randn(5, 10).to(torch.bfloat16).to(device)

     with torch.autocast("xla"):
       product = torch.einsum("...n,mn->...m", data, target)
       # test the HLO to see if autocast works for einsum op, which would show up as a dot op in the HLO
       hlo = torch_xla._XLAC._get_xla_tensors_hlo([product])

       # Verify that dot op has bf16 output and not f32, i.e. the computation is performed in bfloat16 precision by autocast
       self.assertTrue(re.search(r".*dot.*bf16", hlo) is not None)

       self.assertTrue(re.search(r".*dot.*f32", hlo) is None)
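For readers who want to reproduce the check outside the unittest harness, below is a minimal standalone sketch of what the test does: run an einsum on bfloat16 inputs under torch.autocast("xla") and inspect the generated HLO for the dtype of the resulting dot op. This is not part of the commit; it assumes a working torch_xla installation with an XLA device available, and it reuses the same private helper, torch_xla._XLAC._get_xla_tensors_hlo, that the test calls.

# Standalone sketch (not from the commit); assumes torch_xla is installed
# and an XLA device is available.
import re

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Same shapes and dtype as the test above.
data = torch.randn(16, 10).to(torch.bfloat16).to(device)
target = torch.randn(5, 10).to(torch.bfloat16).to(device)

with torch.autocast("xla"):
    product = torch.einsum("...n,mn->...m", data, target)

# The einsum lowers to a dot op in the HLO; its output dtype shows which
# precision autocast chose for the computation.
hlo = torch_xla._XLAC._get_xla_tensors_hlo([product])
print("dot emitted in bf16:", re.search(r".*dot.*bf16", hlo) is not None)
print("dot emitted in f32: ", re.search(r".*dot.*f32", hlo) is not None)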
