[auto parallel] shard tensor stop gradient support (#60699)
FeixLiu authored Jan 10, 2024
1 parent 3611e26 commit f177fa6
Showing 2 changed files with 11 additions and 1 deletion.
python/paddle/distributed/auto_parallel/api.py (5 additions, 1 deletion)
@@ -225,9 +225,13 @@ def _init_func(var, block):
 
             return dist_param
         else:
-            return paddle.Tensor(
+            dist_tensor = paddle.Tensor(
                 tensor, process_mesh=mesh, placements=placements, place=place
             )
+            # InitDistTensorWithTensor won't pass the stop gradient attribute,
+            # have to pass it manually.
+            dist_tensor.stop_gradient = tensor.stop_gradient
+            return dist_tensor
     else:
         # TODO(zhiqiu): we need to refine the static shard_tensor
         sharding_specs = get_shard_spec(mesh, placements, tensor.ndim)
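In effect, the dynamic-mode branch of dist.shard_tensor now mirrors the input tensor's stop_gradient flag instead of leaving the returned dist tensor at the default of True. Below is a minimal usage sketch of the behavior this enables; the 2x2 process mesh, its dim names, and the tensor shape are illustrative choices (not taken from this commit), and the script is assumed to run under a distributed launch with a matching number of ranks.

import paddle
import paddle.distributed as dist
from paddle.distributed import Replicate, Shard

# Illustrative 2x2 mesh; run under a distributed launch with four ranks.
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])

x = paddle.ones([4, 1024, 512])
x.stop_gradient = False  # mark the local tensor as trainable

d_x = dist.shard_tensor(x, mesh, [Shard(0), Replicate()])

# Before this commit the returned dist tensor always had stop_gradient=True;
# with the fix it inherits the flag from the input tensor.
assert not d_x.stop_gradient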
test/auto_parallel/test_shard_tensor_api.py (6 additions, 0 deletions)
@@ -92,6 +92,12 @@ def test_dynamic_mode_property_change(self):
 
         self.assertEqual(d_tensor.process_mesh, self.mesh)
 
+    def test_stop_gradient(self):
+        x = paddle.ones([4, 1024, 512])
+        x.stop_gradient = False
+        x = dist.shard_tensor(x, self.mesh, [Shard(0), Replicate()])
+        assert not x.stop_gradient
+
 
 class TestShardTensorStatic(unittest.TestCase):
     def setUp(self):
