[auto parallel] Lazy init for MP. Add reshard infer shape. (#60563)
FeixLiu authored Jan 5, 2024
1 parent 1874d1c commit a9712d1
Showing 6 changed files with 124 additions and 16 deletions.
31 changes: 20 additions & 11 deletions paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
@@ -174,17 +174,26 @@ DistTensor::DistTensor(const std::shared_ptr<phi::DenseTensor>& global_value,
// uninitialized tensor only with dist_tensor_meta_.
if (IsCurRankInMesh(process_mesh)) {
if (!dist_attr_.is_replicated()) {
value_ = std::make_shared<DenseTensor>();
// 1. create replicated global tensor
TensorDistAttr replicated_dist_attr(
common::vectorize(global_value->dims()));
replicated_dist_attr.set_process_mesh(process_mesh);
DistTensor replicated_tensor(global_value, replicated_dist_attr);

// 2. reshard from replicated to other state
auto* func = ChooseProperReshardFunction(replicated_tensor, dist_attr_);
auto* dev_ctx = DeviceContextPool::Instance().Get(global_value->place());
func->Eval(dev_ctx, replicated_tensor, dist_attr_, this);
if (global_value->initialized()) {
value_ = std::make_shared<DenseTensor>();
// 1. create replicated global tensor
TensorDistAttr replicated_dist_attr(
common::vectorize(global_value->dims()));
replicated_dist_attr.set_process_mesh(process_mesh);
DistTensor replicated_tensor(global_value, replicated_dist_attr);

// 2. reshard from replicated to other state
auto* func = ChooseProperReshardFunction(replicated_tensor, dist_attr_);
auto* dev_ctx =
DeviceContextPool::Instance().Get(global_value->place());
func->Eval(dev_ctx, replicated_tensor, dist_attr_, this);
} else {
// For lazy init, the global value is an uninitialized tensor.
// Just infer the local shape of the dist tensor.
value_ = global_value;
value_->Resize(
InferShapeForReshardFromReplicate(global_value, dist_attr_));
}
} else {
value_ = global_value;
}
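The new else-branch skips the real reshard when the global value carries no data and only resizes the uninitialized local tensor to the inferred shard shape. A minimal sketch of the lazy-init flow this enables, assuming a two-rank launch (e.g. via paddle.distributed.launch) and illustrative shapes:

# Hedged sketch: under LazyGuard the weight holds no data, so shard_tensor hits the
# uninitialized branch above and only the local shape is inferred; nothing is resharded.
import paddle
import paddle.distributed as dist
from paddle import LazyGuard

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
with LazyGuard():
    linear = paddle.nn.Linear(10, 10)
linear.weight = dist.shard_tensor(linear.weight, mesh, [dist.Shard(1)])
assert not linear.weight._is_initialized()
linear.weight.initialize()                      # materializes only the local shard
assert linear.weight._local_shape == [10, 5]    # [10, 10] column-sharded over 2 ranks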
23 changes: 23 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc
@@ -180,5 +180,28 @@ phi::DeviceContext* GetDistTensorDeviceContext(
return phi::DeviceContextPool::Instance().Get(place);
}

phi::DDim InferShapeForReshardFromReplicate(
const std::shared_ptr<phi::DenseTensor>& global_value,
const TensorDistAttr& dist_attr) {
phi::DDim out_dim = global_value->dims();
auto coord_id = GetCurRankCoordInMesh(dist_attr.process_mesh());
for (int tensor_axis = 0; tensor_axis < global_value->dims().size();
++tensor_axis) {
if (dist_attr.is_shard(-1, tensor_axis)) {
for (int mesh_axis = 0; mesh_axis < dist_attr.process_mesh().ndim();
++mesh_axis) {
if (dist_attr.is_shard(mesh_axis, tensor_axis)) {
// handle the shard axis
int64_t global_shape = out_dim[tensor_axis];
int64_t mesh_size = dist_attr.process_mesh().dim_size(mesh_axis);
auto balance_shard = BalancedSplit(global_shape, mesh_size);
out_dim[tensor_axis] = balance_shard[coord_id[mesh_axis]];
}
}
}
}
return out_dim;
}

} // namespace distributed
} // namespace phi
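The helper added here walks every sharded tensor axis and replaces the global extent with this rank's slice size from BalancedSplit. A rough Python sketch of the same computation; the helper names balanced_split and infer_local_shape are made up for illustration, and the assumption that the remainder goes to the lowest mesh coordinates matches the unbalanced MP test further down:

# Hedged re-implementation of the shape inference, not a Paddle API.
def balanced_split(total, pieces):
    base, rem = divmod(total, pieces)
    return [base + 1 if i < rem else base for i in range(pieces)]

def infer_local_shape(global_shape, mesh_shape, dims_mapping, coord):
    # dims_mapping[i] = mesh axis that shards tensor axis i, or -1 if not sharded
    local = list(global_shape)
    for axis, mesh_axis in enumerate(dims_mapping):
        if mesh_axis >= 0:
            pieces = balanced_split(local[axis], mesh_shape[mesh_axis])
            local[axis] = pieces[coord[mesh_axis]]
    return local

# Matches the unbalanced MP test below: an [11, 11] weight with Shard(1) on a 2-rank mesh.
print(infer_local_shape([11, 11], [2], [-1, 0], [0]))  # [11, 6] on rank 0
print(infer_local_shape([11, 11], [2], [-1, 0], [1]))  # [11, 5] on rank 1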
4 changes: 4 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h
@@ -71,6 +71,10 @@ std::vector<int64_t> BalancedSplit(int64_t total_nums, int64_t num_of_pieces);
CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
const std::vector<int64_t>& process_ids);

phi::DDim InferShapeForReshardFromReplicate(
const std::shared_ptr<phi::DenseTensor>& global_value,
const TensorDistAttr& dist_attr);

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#define RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, ...) \
do { \
17 changes: 14 additions & 3 deletions python/paddle/nn/initializer/constant.py
@@ -72,9 +72,20 @@ def forward(self, var, block=None):
if self._force_cpu:
place = core.CPUPlace()
if in_dygraph_mode():
_C_ops.full_(
var, var.shape, float(self._value), var.dtype, place
)
if isinstance(var, framework.EagerParamBase) and var.is_dist():
out_var = _C_ops.full(
var._local_shape, float(self._value), var.dtype, place
)
out_var = (
paddle.distributed.auto_parallel.api.dtensor_from_local(
out_var, var.process_mesh, var.placements
)
)
out_var._share_underline_tensor_to(var)
else:
_C_ops.full_(
var, var.shape, float(self._value), var.dtype, place
)
return None
else:
return _C_ops.full(
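For a sharded EagerParamBase in dygraph mode, Constant.forward now fills only the local shard and wraps it back into a dist tensor with dtensor_from_local before sharing it into the parameter. A hedged sketch of that shape-level behaviour, leaving out the in-place _share_underline_tensor_to step and using illustrative mesh and shape values (run under a two-rank launch):

import paddle
import paddle.distributed as dist
from paddle.distributed.auto_parallel.api import dtensor_from_local

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
placements = [dist.Shard(1)]                 # column-wise shard, as in the MP test below
local_shape = [10, 5]                        # what var._local_shape reports per rank

local = paddle.full(local_shape, 2.0, dtype="float32")     # stands in for _C_ops.full
dist_param = dtensor_from_local(local, mesh, placements)   # global shape becomes [10, 10]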
63 changes: 62 additions & 1 deletion test/auto_parallel/semi_auto_parallel_lazy_init.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

import paddle
@@ -34,6 +34,11 @@ def __init__(self):
self._mesh_bias = dist.ProcessMesh([1], dim_names=["x"])
self._placements_weight = [dist.Replicate()]
self._placements_bias = [dist.Replicate()]
elif self._placements_type == "MP":
self._mesh_weight = dist.ProcessMesh([0, 1], dim_names=["x"])
self._mesh_bias = dist.ProcessMesh([0, 1], dim_names=["x"])
self._placements_weight = [dist.Shard(1)]
self._placements_bias = [dist.Shard(0)]

def test_different_xavier(self):
paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
@@ -53,6 +58,31 @@ def test_different_xavier(self):
linear.bias = dist.shard_tensor(
linear.bias, self._mesh_bias, self._placements_bias
)
for param in linear.parameters():
param.initialize()
logging.info(param)

def test_constant(self):
paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
weight_attr = paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(2.0)
)
bias_attr = paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(1.0)
)
with LazyGuard():
linear = paddle.nn.Linear(
10, 10, weight_attr=weight_attr, bias_attr=bias_attr
)
linear.weight = dist.shard_tensor(
linear.weight, self._mesh_weight, self._placements_weight
)
linear.bias = dist.shard_tensor(
linear.bias, self._mesh_bias, self._placements_bias
)
for param in linear.parameters():
param.initialize()
logging.info(param)

def test_placements(self):
paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
@@ -67,6 +97,7 @@ def test_placements(self):
for param in linear.parameters():
assert not param._is_initialized()
param.initialize()
logging.info(param)

if self._placements_type == "DP":
assert linear.weight._is_initialized()
@@ -93,10 +124,40 @@ def test_placements(self):
else:
assert not linear.weight._is_initialized()
assert linear.bias._is_initialized()
elif self._placements_type == "MP":
assert linear.weight._is_initialized()
assert linear.bias._is_initialized()
assert linear.weight._local_shape == [10, 5]
assert linear.bias._local_shape == [5]

def test_unbalance_mp(self):
paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
with LazyGuard():
linear = paddle.nn.Linear(11, 11)
linear.weight = dist.shard_tensor(
linear.weight, self._mesh_weight, self._placements_weight
)
linear.bias = dist.shard_tensor(
linear.bias, self._mesh_bias, self._placements_bias
)
for param in linear.parameters():
assert not param._is_initialized()
param.initialize()
assert param._is_initialized()

if dist.get_rank() == 0:
assert linear.weight._local_shape == [11, 6]
assert linear.bias._local_shape == [6]
else:
assert linear.weight._local_shape == [11, 5]
assert linear.bias._local_shape == [5]

def run_test_case(self):
self.test_placements()
self.test_different_xavier()
self.test_constant()
if self._placements_type == "MP":
self.test_unbalance_mp()


if __name__ == '__main__':
2 changes: 1 addition & 1 deletion test/auto_parallel/test_semi_auto_parallel_lazy_init.py
@@ -29,7 +29,7 @@ def setUp(self):
}
self._changeable_envs = {
"backend": ["cpu", "gpu"],
"_placements_type": ["DP", "PP"],
"_placements_type": ["DP", "PP", "MP"],
}

def test_lazy_init(self):