From f364c04107e853feb12d9c87cf34e0fda6017ab6 Mon Sep 17 00:00:00 2001 From: LiYuRio Date: Thu, 16 Nov 2023 16:25:21 +0800 Subject: [PATCH] remove value for rank not in mesh --- .../distributed/auto_parallel/dist_tensor.cc | 27 +++++++------------ 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc index 0661ef17d2140c..3e127aaf709ca1 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc +++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc @@ -37,13 +37,12 @@ DistTensor::DistTensor() : value_(std::make_shared()) {} DistTensor::DistTensor(const std::shared_ptr& global_value, const TensorDistAttr& dist_attr) - : dims_(global_value->dims()), - dist_attr_(dist_attr), - value_(std::make_shared()) { + : dims_(global_value->dims()), dist_attr_(dist_attr) { // If the current rank doesn't in process_mesh, we should create an // uninitialized tensor only with tensor_meta. if (IsCurRankInMesh(dist_attr.process_mesh())) { if (!dist_attr.is_replicated()) { + value_ = std::make_shared(); // 1. create replicated global tensor TensorDistAttr replicated_dist_attr(vectorize(global_value->dims())); replicated_dist_attr.set_process_mesh(dist_attr.process_mesh()); @@ -57,21 +56,16 @@ DistTensor::DistTensor(const std::shared_ptr& global_value, value_ = global_value; } } else { - // TODO(liyurui): The following logic is illegal, and should be removed - // later. It exist temporary because the basic execution procedure is not - // ready, even sometimes we try to construct a DistTensor with empty - // DistAttr. Here we warning when the DistAttr is empty for debug use. - if (dist_attr.empty()) { - LOG(WARNING) << "Try to construct a dist tensor with empty dist attr."; - } - value_ = global_value; + value_ = std::make_shared( + std::make_shared(nullptr, 0, global_value->place()), + phi::DenseTensorMeta(global_value->meta())); } } DistTensor::DistTensor(const std::shared_ptr& global_value, const ProcessMesh& process_mesh, const Placements& placements) - : dims_(global_value->dims()), value_(std::make_shared()) { + : dims_(global_value->dims()) { dist_tensor_meta_ = DistTensorMeta( process_mesh, placements, @@ -88,6 +82,7 @@ DistTensor::DistTensor(const std::shared_ptr& global_value, // uninitialized tensor only with dist_tensor_meta_. if (IsCurRankInMesh(process_mesh)) { if (!dist_tensor_meta_.is_replicated()) { + value_ = std::make_shared(); // 1. create replicated global tensor TensorDistAttr replicated_dist_attr(vectorize(global_value->dims())); replicated_dist_attr.set_process_mesh(process_mesh); @@ -101,11 +96,9 @@ DistTensor::DistTensor(const std::shared_ptr& global_value, value_ = global_value; } } else { - // The following logic is illegal, and should be removed - // later. It exist temporary because the basic execution procedure is not - // ready, even sometimes we try to construct a DistTensor with empty - // DistAttr. Here we warning when the DistAttr is empty for debug use. - value_ = global_value; + value_ = std::make_shared( + std::make_shared(nullptr, 0, global_value->place()), + phi::DenseTensorMeta(global_value->meta())); } }