Skip to content

Commit

Permalink
Revert "fix process mesh incorrect set in converter (#60504)"
Browse files Browse the repository at this point in the history
This reverts commit ddd29d2.
  • Loading branch information
LiYuRio authored Jan 3, 2024
1 parent 99af9f7 commit 56931e3
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 27 deletions.
6 changes: 3 additions & 3 deletions paddle/fluid/pybind/eager_custom_python_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ static PyObject *eager_api_linear(PyObject *self,
PyObject *kwargs) {
PyThreadState *tstate = nullptr;
try {
auto &x = GetTensorFromArgs("linear", "X", args, 0, false);
auto &weight = GetTensorFromArgs("linear", "weight", args, 1, false);
auto &bias = GetTensorFromArgs("linear", "Bias", args, 2, true);
auto x = GetTensorFromArgs("linear", "X", args, 0, false);
auto weight = GetTensorFromArgs("linear", "weight", args, 1, false);
auto bias = GetTensorFromArgs("linear", "Bias", args, 2, true);

tstate = PyEval_SaveThread();

Expand Down
63 changes: 39 additions & 24 deletions paddle/fluid/pybind/eager_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ void ConvertToDistTensor(Tensor* x, const phi::distributed::ProcessMesh* mesh) {
if (x->is_dist_tensor()) {
PADDLE_ENFORCE_EQ(
std::dynamic_pointer_cast<phi::distributed::DistTensor>(x->impl())
->process_mesh(),
->dist_attr()
.process_mesh(),
*mesh,
platform::errors::InvalidArgument(
"Input %s has different mesh. However all inputs should "
Expand All @@ -135,18 +136,16 @@ void ConvertToDistTensor(Tensor* x, const phi::distributed::ProcessMesh* mesh) {
"Failed to convert input %s impl to phi::distributed::DistTensor "
"as it's not phi::DenseTensor.",
x->name()));
phi::distributed::Placements placements;
for (int64_t i = 0; i < mesh->ndim(); ++i) {
placements.emplace_back(std::make_shared<phi::distributed::Replicate>());
}

phi::distributed::TensorDistAttr dist_attr(
common::vectorize(x->impl()->dims()));
dist_attr.set_process_mesh(*mesh);
auto dense_t = std::static_pointer_cast<phi::DenseTensor>(x->impl());
// auto parallel in dygraph doesn't support strided kernel.
if (!dense_t->meta().is_contiguous()) {
*dense_t = paddle::experimental::Trans2Contiguous(*dense_t);
}
x->set_impl(std::make_shared<phi::distributed::DistTensor>(
dense_t, *mesh, placements));
x->set_impl(
std::make_shared<phi::distributed::DistTensor>(dense_t, dist_attr));
}
}

Expand Down Expand Up @@ -394,7 +393,8 @@ std::vector<paddle::Tensor> CastPyArg2VectorOfTensor(
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl())
->process_mesh());
->dist_attr()
.process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -433,7 +433,8 @@ std::vector<paddle::Tensor> CastPyArg2VectorOfTensor(
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl())
->process_mesh());
->dist_attr()
.process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1430,7 +1431,8 @@ std::vector<paddle::Tensor> GetTensorListFromArgs(
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl())
->process_mesh());
->dist_attr()
.process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1463,7 +1465,8 @@ std::vector<paddle::Tensor> GetTensorListFromArgs(
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl())
->process_mesh());
->dist_attr()
.process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1536,7 +1539,8 @@ paddle::optional<std::vector<paddle::Tensor>> GetOptionalTensorListFromArgs(
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl())
->process_mesh());
->dist_attr()
.process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1569,7 +1573,8 @@ paddle::optional<std::vector<paddle::Tensor>> GetOptionalTensorListFromArgs(
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl())
->process_mesh());
->dist_attr()
.process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1674,7 +1679,8 @@ std::vector<paddle::Tensor*> GetTensorPtrListFromArgs(
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor->impl())
->process_mesh());
->dist_attr()
.process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1706,7 +1712,8 @@ std::vector<paddle::Tensor*> GetTensorPtrListFromArgs(
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor->impl())
->process_mesh());
->dist_attr()
.process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1753,7 +1760,8 @@ std::vector<paddle::Tensor*> GetTensorPtrListFromPyObject(PyObject* obj) {
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor->impl())
->process_mesh());
->dist_attr()
.process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1781,7 +1789,8 @@ std::vector<paddle::Tensor*> GetTensorPtrListFromPyObject(PyObject* obj) {
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor->impl())
->process_mesh());
->dist_attr()
.process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1823,7 +1832,8 @@ std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj,
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl())
->process_mesh());
->dist_attr()
.process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -1861,7 +1871,8 @@ std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj,
local_mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl())
->process_mesh());
->dist_attr()
.process_mesh());
mesh_start_index = i;
}
}
Expand Down Expand Up @@ -2676,7 +2687,8 @@ PyMODINIT_FUNC PyInit__static_op_arg_pre_cast_hook() {
void DistTensorTypeParser::operator()(const Tensor& x) {
if (x.defined() && x.is_dist_tensor()) {
*mesh = &(std::dynamic_pointer_cast<phi::distributed::DistTensor>(x.impl())
->process_mesh());
->dist_attr()
.process_mesh());
result = true;
}
}
Expand All @@ -2686,7 +2698,8 @@ void DistTensorTypeParser::operator()(const paddle::optional<Tensor>& x) {
if (x.get_ptr()->defined() && x.get_ptr()->is_dist_tensor()) {
*mesh = &(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
x.get_ptr()->impl())
->process_mesh());
->dist_attr()
.process_mesh());
result = true;
}
}
Expand All @@ -2698,7 +2711,8 @@ void DistTensorTypeParser::operator()(const std::vector<Tensor>& x) {
if (t.defined() && t.is_dist_tensor()) {
*mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(t.impl())
->process_mesh());
->dist_attr()
.process_mesh());
result = true;
break;
}
Expand All @@ -2714,7 +2728,8 @@ void DistTensorTypeParser::operator()(
if (t.defined() && t.is_dist_tensor()) {
*mesh = &(
std::dynamic_pointer_cast<phi::distributed::DistTensor>(t.impl())
->process_mesh());
->dist_attr()
.process_mesh());
result = true;
break;
}
Expand Down

0 comments on commit 56931e3

Please sign in to comment.