diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py
index 84e49b2aa390c..ef7539094f5db 100644
--- a/python/paddle/distributed/auto_parallel/api.py
+++ b/python/paddle/distributed/auto_parallel/api.py
@@ -225,9 +225,13 @@ def _init_func(var, block):
 
             return dist_param
         else:
-            return paddle.Tensor(
+            dist_tensor = paddle.Tensor(
                 tensor, process_mesh=mesh, placements=placements, place=place
             )
+            # InitDistTensorWithTensor won't pass the stop_gradient attribute,
+            # so it has to be set manually.
+            dist_tensor.stop_gradient = tensor.stop_gradient
+            return dist_tensor
     else:
         # TODO(zhiqiu): we need to refine the static shard_tensor
         sharding_specs = get_shard_spec(mesh, placements, tensor.ndim)
diff --git a/test/auto_parallel/test_shard_tensor_api.py b/test/auto_parallel/test_shard_tensor_api.py
index a726f8f595c2c..5efa0c25031f1 100644
--- a/test/auto_parallel/test_shard_tensor_api.py
+++ b/test/auto_parallel/test_shard_tensor_api.py
@@ -92,6 +92,12 @@ def test_dynamic_mode_property_change(self):
 
         self.assertEqual(d_tensor.process_mesh, self.mesh)
 
+    def test_stop_gradient(self):
+        x = paddle.ones([4, 1024, 512])
+        x.stop_gradient = False
+        x = dist.shard_tensor(x, self.mesh, [Shard(0), Replicate()])
+        assert not x.stop_gradient
+
 
 class TestShardTensorStatic(unittest.TestCase):
     def setUp(self):