Skip to content

Commit

Permalink
fix incorrect process mesh set in converter (PaddlePaddle#60504)
Browse files Browse the repository at this point in the history
  • Loading branch information
LiYuRio authored and Wanglongzhi2001 committed Jan 7, 2024
1 parent 9e8b735 commit b2c4730
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 42 deletions.
6 changes: 3 additions & 3 deletions paddle/fluid/pybind/eager_custom_python_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ static PyObject *eager_api_linear(PyObject *self,
PyObject *kwargs) {
PyThreadState *tstate = nullptr;
try {
auto x = GetTensorFromArgs("linear", "X", args, 0, false);
auto weight = GetTensorFromArgs("linear", "weight", args, 1, false);
auto bias = GetTensorFromArgs("linear", "Bias", args, 2, true);
auto &x = GetTensorFromArgs("linear", "X", args, 0, false);
auto &weight = GetTensorFromArgs("linear", "weight", args, 1, false);
auto &bias = GetTensorFromArgs("linear", "Bias", args, 2, true);

tstate = PyEval_SaveThread();

Expand Down
63 changes: 24 additions & 39 deletions paddle/fluid/pybind/eager_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ void ConvertToDistTensor(Tensor* x, const phi::distributed::ProcessMesh* mesh) {
if (x->is_dist_tensor()) {
PADDLE_ENFORCE_EQ(
std::dynamic_pointer_cast<phi::distributed::DistTensor>(x->impl())
->dist_attr()
.process_mesh(),
->process_mesh(),
*mesh,
platform::errors::InvalidArgument(
"Input %s has different mesh. However all inputs should "
Expand All @@ -136,16 +135,18 @@ void ConvertToDistTensor(Tensor* x, const phi::distributed::ProcessMesh* mesh) {
"Failed to convert input %s impl to phi::distributed::DistTensor "
"as it's not phi::DenseTensor.",
x->name()));
phi::distributed::TensorDistAttr dist_attr(
common::vectorize(x->impl()->dims()));
dist_attr.set_process_mesh(*mesh);
phi::distributed::Placements placements;
for (int64_t i = 0; i < mesh->ndim(); ++i) {
placements.emplace_back(std::make_shared<phi::distributed::Replicate>());
}

auto dense_t = std::static_pointer_cast<phi::DenseTensor>(x->impl());
// auto parallel in dygraph doesn't support strided kernel.
if (!dense_t->meta().is_contiguous()) {
*dense_t = paddle::experimental::Trans2Contiguous(*dense_t);
}
x->set_impl(
std::make_shared<phi::distributed::DistTensor>(dense_t, dist_attr));
x->set_impl(std::make_shared<phi::distributed::DistTensor>(
dense_t, *mesh, placements));
}
}

Expand Down Expand Up @@ -393,8 +394,7 @@ std::vector<paddle::Tensor> CastPyArg2VectorOfTensor(
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl())
->dist_attr()
.process_mesh());
->process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -433,8 +433,7 @@ std::vector<paddle::Tensor> CastPyArg2VectorOfTensor(
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl())
->dist_attr()
.process_mesh());
->process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1431,8 +1430,7 @@ std::vector<paddle::Tensor> GetTensorListFromArgs(
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl())
->dist_attr()
.process_mesh());
->process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1465,8 +1463,7 @@ std::vector<paddle::Tensor> GetTensorListFromArgs(
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl())
->dist_attr()
.process_mesh());
->process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1539,8 +1536,7 @@ paddle::optional<std::vector<paddle::Tensor>> GetOptionalTensorListFromArgs(
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl())
->dist_attr()
.process_mesh());
->process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1573,8 +1569,7 @@ paddle::optional<std::vector<paddle::Tensor>> GetOptionalTensorListFromArgs(
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl())
->dist_attr()
.process_mesh());
->process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1679,8 +1674,7 @@ std::vector<paddle::Tensor*> GetTensorPtrListFromArgs(
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor->impl())
->dist_attr()
.process_mesh());
->process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1712,8 +1706,7 @@ std::vector<paddle::Tensor*> GetTensorPtrListFromArgs(
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor->impl())
->dist_attr()
.process_mesh());
->process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1760,8 +1753,7 @@ std::vector<paddle::Tensor*> GetTensorPtrListFromPyObject(PyObject* obj) {
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor->impl())
->dist_attr()
.process_mesh());
->process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1789,8 +1781,7 @@ std::vector<paddle::Tensor*> GetTensorPtrListFromPyObject(PyObject* obj) {
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor->impl())
->dist_attr()
.process_mesh());
->process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1832,8 +1823,7 @@ std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj,
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl())
->dist_attr()
.process_mesh());
->process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1871,8 +1861,7 @@ std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj,
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl())
->dist_attr()
.process_mesh());
->process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -2687,8 +2676,7 @@ PyMODINIT_FUNC PyInit__static_op_arg_pre_cast_hook() {
void DistTensorTypeParser::operator()(const Tensor& x) {
if (x.defined() && x.is_dist_tensor()) {
*mesh = &(std::dynamic_pointer_cast<phi::distributed::DistTensor>(x.impl())
->dist_attr()
.process_mesh());
->process_mesh());
result = true;
}
}
Expand All @@ -2698,8 +2686,7 @@ void DistTensorTypeParser::operator()(const paddle::optional<Tensor>& x) {
if (x.get_ptr()->defined() && x.get_ptr()->is_dist_tensor()) {
*mesh = &(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
x.get_ptr()->impl())
->dist_attr()
.process_mesh());
->process_mesh());
result = true;
}
}
Expand All @@ -2711,8 +2698,7 @@ void DistTensorTypeParser::operator()(const std::vector<Tensor>& x) {
if (t.defined() && t.is_dist_tensor()) {
*mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(t.impl())
->dist_attr()
.process_mesh());
->process_mesh());
result = true;
break;
}
Expand All @@ -2728,8 +2714,7 @@ void DistTensorTypeParser::operator()(
if (t.defined() && t.is_dist_tensor()) {
*mesh = &(
std::dynamic_pointer_cast<phi::distributed::DistTensor>(t.impl())
->dist_attr()
.process_mesh());
->process_mesh());
result = true;
break;
}
Expand Down

0 comments on commit b2c4730

Please sign in to comment.