+import sys
+import os
+import re
+import unittest
+from pathlib import Path
+
+import torch
+import torch_xla
+from torch_xla.experimental.scan import scan
+from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, Mesh
+import torch_xla.runtime as xr
+
+
+class ScanSpmdTest(unittest.TestCase):
+
+  def setUp(self):
+    # Set up a simple SPMD mesh for these tests.
+    num_devices = xr.global_runtime_device_count()
+    mesh_shape = (num_devices,)
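+    # A 1-D logical mesh with a single axis named 'model' spanning all
+    # devices; tensors sharded along 'model' are split evenly across them.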
+    self.spmd_mesh = Mesh(list(range(num_devices)), mesh_shape, ('model',))
+    set_global_mesh(self.spmd_mesh)
+    xr.use_spmd()
+    self.device = torch_xla.device()
+
+  @unittest.skipUnless(xr.global_runtime_device_count() >= 4,
+                       "Multiple devices required")
+  def test_scan_cumsum(self):
+    """This test uses `scan` to implement `torch.cumsum`."""
+
+    save_file = os.getenv('XLA_SAVE_TENSORS_FILE')
+    save_format = os.getenv('XLA_SAVE_TENSORS_FMT')
+    if not save_file:
+      assert False, "This test should be run with XLA_SAVE_TENSORS_FILE"
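+    # XLA_SAVE_TENSORS_FILE dumps are written with a per-process suffix
+    # ('.0' for the first process), hence the adjusted path below.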
+    save_file += '.0'
+    assert save_format == 'hlo', "This test should be run with XLA_SAVE_TENSORS_FMT=hlo"
+
+    # Remove the save file (if it exists) to start from a clean slate.
+    try:
+      os.remove(save_file)
+    except FileNotFoundError:
+      pass
+
+    def fn(carry, x):
+      new_carry = carry + x
+      y = new_carry
+      return new_carry, y
+
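+    # Scanning `fn` over the leading dim of `xs` keeps a running sum, so `ys`
+    # matches `torch.cumsum(xs, dim=0)` and `final_carry` is the total sum
+    # (since `init` is zeros).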
+    init = torch.zeros(1024, requires_grad=True, device=self.device)
+    mark_sharding(init, self.spmd_mesh, ('model',))
+    xs = torch.randn([8, 1024], requires_grad=True, device=self.device)
+    mark_sharding(xs, self.spmd_mesh, (None, 'model'))
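+    # `init` is sharded along 'model'; `xs` is replicated on its scan (leading)
+    # dim and sharded along 'model' on its feature dim, so each device holds a
+    # 1024 // num_devices slice.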
+    final_carry, ys = scan(fn, init, xs)
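+    # Materialize the traced computation; with XLA_SAVE_TENSORS_FILE set, the
+    # executed graph (and its output sharding annotations) is dumped to the
+    # save file inspected below.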
+    torch_xla.sync()
+
+    # Check the HLO.
+    hlo_content = Path(save_file).read_text()
+    lines = hlo_content.splitlines()
+
+    # There should be only one graph.
+    assert len(re.findall('END_GRAPH', hlo_content)) == 1
+
+    # The graph should have output sharding.
+    begin_magic = '#OUTPUT_SHARDING_BEGIN'
+    end_magic = '#OUTPUT_SHARDING_END'
+    self.assertIn(end_magic, lines[-2])
+
+    # Extract the output sharding descriptions between the two markers.
+    start = hlo_content.find(begin_magic)
+    assert start != -1
+    start += len(begin_magic)
+    end = hlo_content.find(end_magic, start)
+    assert end != -1
+    output_sharding = hlo_content[start:end].strip().splitlines()
+
+    # There should be 4 tensors in output sharding: init, xs, final_carry, ys.
+    self.assertEqual(len(output_sharding), 4)
+    for sharding in output_sharding:
+      self.assertIn('devices=[', sharding)
+
+    # Remove the save file again to avoid cluttering other tests.
+    os.remove(save_file)
+
+
+if __name__ == '__main__':
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)