11import sys
22import unittest
3- import torch_xla
3+ from functools import reduce
4+
45import torch
6+ from torch .utils ._pytree import tree_map , tree_flatten , tree_iter , tree_leaves , PyTree
7+
8+ import torch_xla
59from torch_xla .experimental .scan import scan
6- from torch .utils ._pytree import tree_map , tree_flatten , tree_iter
710
8- from test_utils import XlaTestCase
11+ from test_utils import XlaTestCase # type:ignore
912
1013
1114def _loopy_scan (fn , init , xs ):
@@ -24,6 +27,8 @@ def _loopy_scan(fn, init, xs):
2427class ScanTest (XlaTestCase ):
2528
def setUp(self):
    """Run the base-class fixture setup, then record the XLA device.

    The device is stored on `self.device` so that test methods can create
    their tensors on it.
    """
    # Let the parent test case perform its own setup first.
    super().setUp()

    xla_device = torch_xla.device()
    self.device = xla_device
2833
2934 def compare_pytree (self , expected_pytree , actual_pytree ):
@@ -32,31 +37,54 @@ def compare_pytree(self, expected_pytree, actual_pytree):
3237 assert expected_spec == actual_spec
3338 super ().compareResults (flat_expected_pytree , flat_actual_pytree )
3439
def run_test(self, fn, init: PyTree, xs: PyTree):
    """Check `scan` against the for-loop reference `_loopy_scan`.

    Both implementations are run on independent copies of `init` and `xs`.
    The sum of every leaf of the resulting `ys` is back-propagated once, and
    then the output values and the input gradients are compared.

    Args:
        fn: Step function forwarded unchanged to `scan` and `_loopy_scan`.
        init: PyTree of tensors used as the initial carry.
        xs: PyTree of tensors to scan over.

    Returns:
        The `(final_carry, ys)` pair produced by `scan`.
    """

    def fresh_leaves(tree):
        # Detached leaves give each implementation its own `.grad` slots, so
        # the two backward passes never accumulate into the same tensors.
        return tree_map(lambda v: v.detach().requires_grad_(), tree)

    def backprop_sum(tree):
        # Add up all leaves in `tree` and `backward()` once.
        total = reduce(lambda a, b: a + b,
                       (v.sum() for v in tree_leaves(tree)),
                       torch.tensor(0.0))
        total.backward()

    def grads(tree):
        return tree_map(lambda v: v.grad, tree)

    # Actual output
    init_scan = fresh_leaves(init)
    xs_scan = fresh_leaves(xs)
    final_carry, ys = scan(fn, init_scan, xs_scan)
    backprop_sum(ys)
    torch_xla.sync()

    # Expected output
    init_loop = fresh_leaves(init)
    xs_loop = fresh_leaves(xs)
    expected_final_carry, expected_ys = _loopy_scan(fn, init_loop, xs_loop)
    backprop_sum(expected_ys)
    torch_xla.sync()

    # Compare values
    self.compare_pytree(expected_final_carry, final_carry)
    self.compare_pytree(expected_ys, ys)

    # Compare gradients. The loop-based result is the expectation, so it goes
    # first, matching `compare_pytree(expected, actual)` and the value checks.
    self.compare_pytree(grads(init_loop), grads(init_scan))
    self.compare_pytree(grads(xs_loop), grads(xs_scan))

    return final_carry, ys
4975
50- def test_scan_forward_simple (self ):
76+ def test_scan_simple (self ):
5177 """This test uses `scan` to implement `torch.cumsum`."""
5278
5379 def step_fn (carry , x ):
5480 new_carry = carry + x
5581 y = new_carry
5682 return new_carry , y
5783
58- init = torch .tensor ([0.0 , 0.0 ], device = self .device )
59- xs = torch .tensor ([[1.0 , 2.0 ], [3.0 , 4.0 ], [5.0 , 6.0 ]], device = self .device )
84+ init = torch .tensor ([0.0 , 0.0 ], requires_grad = True , device = self .device )
85+ xs = torch .tensor ([[1.0 , 2.0 ], [3.0 , 4.0 ], [5.0 , 6.0 ]],
86+ requires_grad = True ,
87+ device = self .device )
6088 final_carry , ys = self .run_test (step_fn , init , xs )
6189
6290 # Also ensure that our loop-based scan is correct, with manual checks
@@ -80,26 +108,30 @@ def test_scan_incompatible_length(self):
80108 with self .assertRaises (ValueError ):
81109 scan (lambda a , b : (a , b ), init , (xs_1 , xs_2 ))
82110
def test_scan_tuples(self):
    """Scan over the leading axis of a tuple of tensors simultaneously.

    A tuple is a simple PyTree, so both the carry and the scanned inputs are
    passed as pairs of tensors.
    """

    def step(carry, x):
        (acc_a, acc_b), (in_a, in_b) = carry, x
        acc_a = acc_a + in_a.sum()
        acc_b = acc_b + in_b.sum()
        # Each output mixes the current slice with the updated carry.
        out_a = in_a * 2 + torch.sum(acc_a)
        out_b = in_b * 2 + torch.sum(acc_b)
        return (acc_a, acc_b), (out_a, out_b)

    def leaf(values):
        # Every input is a gradient-tracking tensor on the test device.
        return torch.tensor(values, requires_grad=True, device=self.device)

    init = (leaf([0.0]), leaf([1.0, 2.0]))
    xs = (
        leaf([[1.0, 2.0], [3.0, 4.0]]),
        leaf([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]]),
    )

    self.run_test(step, init, xs)
103135
104136
105137if __name__ == '__main__' :
0 commit comments