Commit 60ef219

[scan] Add a test under SPMD
Verifies that the GSPMD sharding annotation propagation pass can propagate sharding through a While op and into its body computation.
1 parent: 3006c8e · commit: 60ef219
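For context, a condensed sketch of the scenario the new test exercises: the scan inputs are sharded over a one-dimensional 'model' mesh, `scan` lowers to an XLA While op, and GSPMD is expected to propagate the input shardings through the loop and its body. Names and shapes follow the new test below; the final query through the private `_get_xla_sharding_spec` API is an illustrative assumption and not part of this commit (the committed test inspects the saved HLO dump instead).

# Condensed sketch (see assumptions above): shard the scan inputs over a
# 1-D 'model' mesh, run `scan`, and look at the sharding on one output.
import torch
import torch_xla
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh, mark_sharding, set_global_mesh
from torch_xla.experimental.scan import scan

num_devices = xr.global_runtime_device_count()
mesh = Mesh(list(range(num_devices)), (num_devices,), ('model',))
set_global_mesh(mesh)
xr.use_spmd()
device = torch_xla.device()

init = torch.zeros(1024, requires_grad=True, device=device)
mark_sharding(init, mesh, ('model',))
xs = torch.randn([8, 1024], requires_grad=True, device=device)
mark_sharding(xs, mesh, (None, 'model'))


def fn(carry, x):
  new_carry = carry + x
  return new_carry, new_carry


# `scan` lowers to a While op; GSPMD should carry the input shardings
# through the loop body so the outputs come back sharded as well.
final_carry, ys = scan(fn, init, xs)
torch_xla.sync()

# Illustrative check only; the committed test parses the HLO dump instead.
print(torch_xla._XLAC._get_xla_sharding_spec(ys))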

File tree (4 files changed: +92 −3 lines)

test/run_tests.sh
test/scan/test_scan.py
test/scan/test_scan_spmd.py
test/tpu/run_tests.sh

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -209,6 +209,7 @@ function run_xla_op_tests2 {
   run_test "$CDIR/pjrt/test_dtypes.py"
   run_test "$CDIR/test_while_loop.py"
   run_test "$CDIR/scan/test_scan.py"
+  run_save_tensor_hlo "$CDIR/scan/test_scan_spmd.py"
   run_test "$CDIR/scan/test_scan_layers.py"
   run_test "$CDIR/test_autocast.py"
   run_test "$CDIR/eager/test_eager.py"

test/scan/test_scan.py

Lines changed: 3 additions & 3 deletions
@@ -57,9 +57,6 @@ def compare_pytree(self, expected_pytree, actual_pytree):
     flat_actual_pytree = [x for x in flat_actual_pytree if x is not None]
     super().compareResults(flat_expected_pytree, flat_actual_pytree)
 
-
-class ScanTest(TestBase):
-
   def run_test(self,
                fn,
                init: PyTree,
@@ -104,6 +101,9 @@ def run_test(self,
 
     return final_carry, ys
 
+
+class ScanTest(TestBase):
+
   def test_scan_simple(self):
     """This test uses `scan` to implement `torch.cumsum`."""
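The two hunks above move the `class ScanTest(TestBase):` declaration below `run_test`, so `run_test` (like `compare_pytree`) now presumably lives on the shared `TestBase` class. A rough layout sketch of the result follows; `TestBase`'s real base class is not visible in this hunk, so `unittest.TestCase` below is an assumption:

import unittest


# Layout sketch only (bodies elided); the real definitions live in
# test/scan/test_scan.py.
class TestBase(unittest.TestCase):  # actual base class is an assumption

  def compare_pytree(self, expected_pytree, actual_pytree):
    ...

  def run_test(self, fn, init, xs):  # moved under TestBase by this change
    ...


class ScanTest(TestBase):  # declaration now follows the shared helpers

  def test_scan_simple(self):
    ...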

test/scan/test_scan_spmd.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+import sys
+import os
+import re
+import unittest
+from pathlib import Path
+
+import torch
+import torch_xla
+from torch_xla.experimental.scan import scan
+from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, Mesh
+import torch_xla.runtime as xr
+
+
+class ScanSpmdTest(unittest.TestCase):
+
+  def setUp(self):
+    # Set up a simple SPMD mesh for these tests.
+    num_devices = xr.global_runtime_device_count()
+    mesh_shape = (num_devices,)
+    self.spmd_mesh = Mesh(list(range(num_devices)), mesh_shape, ('model',))
+    set_global_mesh(self.spmd_mesh)
+    xr.use_spmd()
+    self.device = torch_xla.device()
+
+  @unittest.skipUnless(xr.global_runtime_device_count() >= 4,
+                       "Multiple devices required")
+  def test_scan_cumsum(self):
+    """This test uses `scan` to implement `torch.cumsum`."""
+
+    save_file = os.getenv('XLA_SAVE_TENSORS_FILE')
+    save_format = os.getenv('XLA_SAVE_TENSORS_FMT')
+    if not save_file:
+      assert False, "This test should be run with XLA_SAVE_TENSORS_FILE"
+    save_file += '.0'
+    assert save_format == 'hlo', "This test should be run with XLA_SAVE_TENSORS_FMT=hlo"
+
+    # Remove the save file (if exists) to start from a clean slate
+    try:
+      os.remove(save_file)
+    except:
+      pass
+
+    def fn(carry, x):
+      new_carry = carry + x
+      y = new_carry
+      return new_carry, y
+
+    init = torch.zeros(1024, requires_grad=True, device=self.device)
+    mark_sharding(init, self.spmd_mesh, ('model',))
+    xs = torch.randn([8, 1024], requires_grad=True, device=self.device)
+    mark_sharding(xs, self.spmd_mesh, (None, 'model'))
+    final_carry, ys = scan(fn, init, xs)
+    torch_xla.sync()
+
+    # Check the HLO
+    hlo_content = Path(save_file).read_text()
+    lines = hlo_content.splitlines()
+
+    # There should be only one graph.
+    assert len(re.findall('END_GRAPH', hlo_content)) == 1
+
+    # The graph should have output sharding.
+    begin_magic = '#OUTPUT_SHARDING_BEGIN'
+    end_magic = '#OUTPUT_SHARDING_END'
+    self.assertIn(end_magic, str(lines[-2]))
+
+    # Extract the output sharding descriptions.
+    start = hlo_content.find(begin_magic)
+    assert start != -1
+    start += len(begin_magic)
+    end = hlo_content.find(end_magic, start)
+    assert end != -1
+    end -= len(end_magic)
+    output_sharding = hlo_content[start:end].strip().splitlines()
+
+    # There should be 4 tensors in output sharding: init, xs, final_carry, ys.
+    self.assertEqual(len(output_sharding), 4)
+    for sharding in output_sharding:
+      self.assertIn('devices=[', sharding)
+
+    # Remove the save file again to avoid cluttering other tests.
+    os.remove(save_file)
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ python3 test/pjrt/test_dtypes.py
 python3 test/pjrt/test_dynamic_plugin_tpu.py
 python3 test/test_while_loop.py
 python3 test/scan/test_scan.py
+XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="hlo" python3 test/scan/test_scan_spmd.py
 python3 test/scan/test_scan_layers.py
 python3 test/test_pallas.py -v
 python3 test/test_pallas_spmd.py
