@@ -44,7 +44,8 @@ def test_to_local_requires_grad(self):
     tensor = torch.randn(100_000, 88, requires_grad=True)

     # Create XLAShardedTensor
-    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)
+    sharded_tensor = XLAShardedTensor(
+        tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)

     # Verify requires_grad is set
     self.assertTrue(sharded_tensor.requires_grad)
@@ -70,7 +71,8 @@ def test_to_local_grad_independence(self):
     mesh = DeviceMesh("xla", list(range(world_size)))

     tensor = torch.randn(100_000, 88, requires_grad=True)
-    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)
+    sharded_tensor = XLAShardedTensor(
+        tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)

     # Create gradients
     res = sharded_tensor.sum()
@@ -95,7 +97,8 @@ def test_to_local_grad_none_handling(self):
     mesh = DeviceMesh("xla", list(range(world_size)))

     tensor = torch.randn(100_000, 88, requires_grad=True)
-    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)
+    sharded_tensor = XLAShardedTensor(
+        tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)

     # Don't do backward pass, so grad remains None
     self.assertIsNone(sharded_tensor.grad)