@@ -14,62 +14,63 @@
 
 
 class DTensorXLAFromLocalConversionTest(test_xla_sharding_base.XlaShardingTest):
-  """
+  """
   Test suite for the automatic conversion of regular tensors to XLAShardedTensor
   in DTensor.from_local() when using XLA device mesh.
   """
 
-  @classmethod
-  def setUpClass(cls):
-    super().setUpClass()
-
-  def test_to_local(self):
-    from torch.distributed.tensor import distribute_tensor
-    world_size = xr.global_runtime_device_count()
-    mesh = DeviceMesh("xla", list(range(world_size)))
-
-    big_tensor = torch.randn(100000, 88)
-    sharded_tensor = XLAShardedTensor(big_tensor, mesh, [Shard(0)])
-
-    local_tensor = sharded_tensor.to_local()
-
-    # Verify the shapes are the same
-    self.assertEqual(local_tensor.shape, big_tensor.shape)
-
-    # Check the values of the tensor
-    torch.testing.assert_close(local_tensor, big_tensor, check_device=False)
-
-  def test_to_local_requires_grad(self):
-    """Test that gradients flow correctly through to_local()."""
-    # Create a tensor with requires_grad=True
-    world_size = xr.global_runtime_device_count()
-    mesh = DeviceMesh("xla", list(range(world_size)))
-
-    tensor = torch.randn(100_000, 88, requires_grad=True)
-
-    # Create an XLAShardedTensor
-    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)])
-
-    # Verify requires_grad is set
-    self.assertTrue(sharded_tensor.requires_grad)
-
-    res = sharded_tensor.sum()
-    res.backward()
-
-    # Verify gradients were calculated
-    self.assertTrue(sharded_tensor.grad is not None)
-
-    # Convert back to a local tensor
-    local_tensor = sharded_tensor.to_local()
-
-    # Verify requires_grad is preserved
-    self.assertTrue(local_tensor.requires_grad)
-
-    # All gradients should be 1.0 since we did a sum()
-    self.assertTrue(torch.allclose(local_tensor.grad, torch.ones_like(tensor)))
-
-    print("Gradient flow test successful")
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+
+  def test_to_local(self):
+    from torch.distributed.tensor import distribute_tensor
+    world_size = xr.global_runtime_device_count()
+    mesh = DeviceMesh("xla", list(range(world_size)))
+
+    big_tensor = torch.randn(100000, 88)
+    sharded_tensor = XLAShardedTensor(big_tensor, mesh, [Shard(0)])
+
+    local_tensor = sharded_tensor.to_local()
+
+    # Verify the shapes are the same
+    self.assertEqual(local_tensor.shape, big_tensor.shape)
+
+    # Check the values of the tensor
+    torch.testing.assert_close(local_tensor, big_tensor, check_device=False)
+
+  def test_to_local_requires_grad(self):
+    """Test that gradients flow correctly through to_local()."""
+    # Create a tensor with requires_grad=True
+    world_size = xr.global_runtime_device_count()
+    mesh = DeviceMesh("xla", list(range(world_size)))
+
+    tensor = torch.randn(100_000, 88, requires_grad=True)
+
+    # Create an XLAShardedTensor
+    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)])
+
+    # Verify requires_grad is set
+    self.assertTrue(sharded_tensor.requires_grad)
+
+    res = sharded_tensor.sum()
+    res.backward()
+
+    # Verify gradients were calculated
+    self.assertTrue(sharded_tensor.grad is not None)
+
+    # Convert back to a local tensor
+    local_tensor = sharded_tensor.to_local()
+
+    # Verify requires_grad is preserved
+    self.assertTrue(local_tensor.requires_grad)
+
+    # All gradients should be 1.0 since we did a sum()
+    self.assertTrue(torch.allclose(local_tensor.grad, torch.ones_like(tensor)))
+
+    print("Gradient flow test successful")
+
 
 if __name__ == "__main__":
-  result = unittest.main(exit=False)
-  sys.exit(0 if result.result.wasSuccessful() else 1)
+  result = unittest.main(exit=False)
+  sys.exit(0 if result.result.wasSuccessful() else 1)