Skip to content

Commit

Permalink
[AutoParallel] tensor.cc support auto parallel (PaddlePaddle#58656)
Browse files Browse the repository at this point in the history
* tensor.cc support auto parallel
  • Loading branch information
wanghuancoder authored and SecretXV committed Nov 28, 2023
1 parent f09e2d3 commit 84f3532
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions paddle/phi/api/lib/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ const phi::DDim &Tensor::strides() const {
return static_cast<phi::DenseTensor *>(impl_.get())->strides();
} else if (is_dist_tensor()) {
return static_cast<phi::distributed::DistTensor *>(impl_.get())
->unsafe_mutable_value()
->strides();
->value()
.strides();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Only support strides operation on DenseTensor now."));
"Only support strides operation on DenseTensor and DistTensor now."));
}
}

Expand Down Expand Up @@ -437,9 +437,16 @@ void Tensor::bump_inplace_version() {
auto &inplace_version_counter =
static_cast<phi::DenseTensor *>(impl_.get())->InplaceVersionCounter();
inplace_version_counter.Bump();
} else if (is_dist_tensor()) {
auto &inplace_version_counter =
static_cast<phi::distributed::DistTensor *>(impl_.get())
->unsafe_mutable_value()
->InplaceVersionCounter();
inplace_version_counter.Bump();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"bump_inplace_version is only supported on DenseTensor now."));
PADDLE_THROW(
phi::errors::Unimplemented("bump_inplace_version is only supported on "
"DenseTensor and DistTensor now."));
}
}

Expand All @@ -448,9 +455,15 @@ uint32_t Tensor::current_inplace_version() {
auto &inplace_version_counter =
static_cast<phi::DenseTensor *>(impl_.get())->InplaceVersionCounter();
return inplace_version_counter.CurrentVersion();
} else if (is_dist_tensor()) {
auto &inplace_version_counter =
static_cast<phi::distributed::DistTensor *>(impl_.get())
->unsafe_mutable_value()
->InplaceVersionCounter();
return inplace_version_counter.CurrentVersion();
} else {
LOG_FIRST_N(WARNING, 1)
<< "current_inplace_version is only supported on DenseTensor now.";
LOG_FIRST_N(WARNING, 1) << "current_inplace_version is only supported on "
"DenseTensor DistTensor now.";
}
return 0;
}
Expand All @@ -461,6 +474,12 @@ void Tensor::reset_inplace_version(bool set_to_zero) {
auto &inplace_version_counter =
static_cast<phi::DenseTensor *>(impl_.get())->InplaceVersionCounter();
inplace_version_counter.SetInplaceVersionToZero();
} else if (is_dist_tensor()) {
auto &inplace_version_counter =
static_cast<phi::distributed::DistTensor *>(impl_.get())
->unsafe_mutable_value()
->InplaceVersionCounter();
return inplace_version_counter.SetInplaceVersionToZero();
}
}
}
Expand Down

0 comments on commit 84f3532

Please sign in to comment.