Commit 983b66d: Add autocast support for einsum

aws-nm9 committed Nov 27, 2024
1 parent 39e67b5
Showing 2 changed files with 25 additions and 0 deletions.
test/test_einsum_autocast.py (24 additions, 0 deletions)
@@ -0,0 +1,24 @@
import re
import unittest

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()


class TestAutocastXla(unittest.TestCase):

  def test_einsum(self):
    data = torch.randn(16, 10).to(torch.bfloat16).to(device)
    target = torch.randn(5, 10).to(torch.bfloat16).to(device)
    with torch.autocast("xla"):
      product = torch.einsum("...n,mn->...m", data, target)
      # The pending HLO graph should contain a bf16 dot for the einsum,
      # and no leftover f32 dot.
      hlo = torch_xla._XLAC._get_xla_tensors_hlo([product])
      self.assertTrue(re.search(r".*dot.*bf16", hlo) is not None)
      self.assertTrue(re.search(r".*dot.*f32", hlo) is None)


if __name__ == "__main__":
  unittest.main()
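
For readers new to the notation, the equation "...n,mn->...m" contracts the shared n dimension, so the product is simply data @ target.T (a bias-free linear layer). A quick CPU-only check, with illustrative tensor names not taken from the commit:

import torch

data = torch.randn(16, 10)
target = torch.randn(5, 10)
out = torch.einsum("...n,mn->...m", data, target)
# Same result as a matmul against the transposed second operand.
assert out.shape == (16, 5)
assert torch.allclose(out, data @ target.T, atol=1e-5)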
torch_xla/csrc/autocast_mode.cpp (1 addition, 0 deletions)
@@ -48,6 +48,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) {
  KERNEL_XLA(prelu, lower_precision_fp)
  KERNEL_XLA(relu, lower_precision_fp)
  KERNEL_XLA(max_pool2d, lower_precision_fp)
  KERNEL_XLA(einsum, lower_precision_fp)  // added by this commit
  // Disable `scaled_dot_product_attention` for now since it causes
  // undefined symbol with official torch whl.
  // KERNEL_XLA(scaled_dot_product_attention, lower_precision_fp)
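
The KERNEL_XLA(einsum, lower_precision_fp) registration opts einsum into autocast's lower-precision policy: under torch.autocast("xla"), floating-point operands are cast down to the autocast dtype (bf16 by default on XLA) before the op runs. A minimal sketch of the resulting behavior, assuming a working torch_xla install (tensor names are illustrative, not from the commit):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(4, 8, device=device)                        # float32
w = torch.randn(6, 8, device=device, dtype=torch.bfloat16)  # bfloat16
with torch.autocast("xla"):
    # Both operands are cast to bf16, so mixed-dtype inputs compose
    # and the result comes out in bf16.
    y = torch.einsum("an,mn->am", x, w)
assert y.dtype == torch.bfloat16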
