fix p2p comm memory release logic (#47497) (#47517)
FeixLiu authored Nov 1, 2022
1 parent 4b3589f commit 0201ccc
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -448,7 +448,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
 
   platform::CUDADeviceGuard cuda_guard;
 
-  if (FLAGS_use_stream_safe_cuda_allocator) {
+  {
+    platform::NCCLGroupGuard nccl_guard;
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
       gpuStream_t nccl_stream;
@@ -460,12 +461,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
       } else {
         nccl_stream = places_to_ctx_[key][i]->stream();
       }
-      memory::RecordStream(tensors[i].Holder(), nccl_stream);
+      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
     }
   }
 
-  {
-    platform::NCCLGroupGuard nccl_guard;
+  if (FLAGS_use_stream_safe_cuda_allocator) {
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
       gpuStream_t nccl_stream;
@@ -477,7 +477,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
       } else {
         nccl_stream = places_to_ctx_[key][i]->stream();
       }
-      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
+      memory::RecordStream(tensors[i].Holder(), nccl_stream);
     }
   }
 
@@ -516,20 +516,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
   // construct uninitialize guard for device
   platform::CUDADeviceGuard cuda_guard;
 
-  if (FLAGS_use_stream_safe_cuda_allocator) {
+  {
+    platform::NCCLGroupGuard nccl_guard;
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
-      memory::RecordStream(tensors[i].Holder(),
-                           places_to_ctx_[key][i]->stream());
+      const auto& nccl_stream = places_to_ctx_[key][i]->stream();
+      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
     }
   }
 
-  {
-    platform::NCCLGroupGuard nccl_guard;
+  if (FLAGS_use_stream_safe_cuda_allocator) {
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
-      const auto& nccl_stream = places_to_ctx_[key][i]->stream();
-      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
+      memory::RecordStream(tensors[i].Holder(),
+                           places_to_ctx_[key][i]->stream());
     }
   }
 
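What the diff does: in both PointToPoint overloads the two per-tensor loops are swapped, so the NCCL send/recv callback fn(...) is now enqueued first, inside the platform::NCCLGroupGuard, and memory::RecordStream(...) is called on the communication stream afterwards, still gated by FLAGS_use_stream_safe_cuda_allocator. The sketch below illustrates the general record-stream idea in plain CUDA, under the assumption that it mirrors what a stream-safe allocator needs: enqueue the asynchronous work on the foreign stream, then record that stream on the allocation so its release is deferred until the work finishes. It is an illustration only, not Paddle's allocator; TrackedBuffer, RecordStream, and ReleaseWhenDone are hypothetical stand-ins.

// Minimal sketch (not Paddle code) of "record the stream, then defer release".
// Only plain CUDA runtime calls are used; compile with nvcc.
#include <cuda_runtime.h>
#include <cstdio>

struct TrackedBuffer {
  void* ptr = nullptr;
  cudaEvent_t last_use = nullptr;  // marks the latest work enqueued on a foreign stream
};

// Analogue of memory::RecordStream: remember that `buf` is in use on `stream`
// by recording an event after the async work has been enqueued there.
void RecordStream(TrackedBuffer* buf, cudaStream_t stream) {
  if (buf->last_use == nullptr) {
    cudaEventCreateWithFlags(&buf->last_use, cudaEventDisableTiming);
  }
  cudaEventRecord(buf->last_use, stream);
}

// Analogue of the allocator's deferred release: free only after the recorded
// work has completed (a real allocator would poll events instead of blocking).
void ReleaseWhenDone(TrackedBuffer* buf) {
  if (buf->last_use != nullptr) {
    cudaEventSynchronize(buf->last_use);
    cudaEventDestroy(buf->last_use);
  }
  cudaFree(buf->ptr);
  buf->ptr = nullptr;
}

__global__ void FillKernel(float* data, int n, float value) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] = value;
}

int main() {
  cudaStream_t comm_stream;
  cudaStreamCreate(&comm_stream);

  TrackedBuffer buf;
  cudaMalloc(&buf.ptr, 1024 * sizeof(float));

  // Enqueue the "communication" work first, then record the stream on the
  // buffer -- the same order the commit enforces around fn(...) and
  // memory::RecordStream(...).
  FillKernel<<<4, 256, 0, comm_stream>>>(static_cast<float*>(buf.ptr), 1024, 1.0f);
  RecordStream(&buf, comm_stream);

  // The owner may drop the buffer right away; the release is deferred until
  // the work recorded on comm_stream has finished.
  ReleaseWhenDone(&buf);

  cudaStreamDestroy(comm_stream);
  printf("done\n");
  return 0;
}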

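For context on the platform::NCCLGroupGuard that now wraps the fn(...) calls: NCCL point-to-point operations issued between ncclGroupStart() and ncclGroupEnd() are aggregated and only launched on their streams when the group closes. Below is a minimal single-process, two-GPU sketch of grouped send/recv; it is a hypothetical standalone example, assuming NCCL is installed and two devices are visible, with error checking omitted.

// Sketch: grouped NCCL point-to-point between two GPUs in one process.
#include <cuda_runtime.h>
#include <nccl.h>
#include <cstdio>

int main() {
  const int ndev = 2;
  int devs[ndev] = {0, 1};
  ncclComm_t comms[ndev];
  ncclCommInitAll(comms, ndev, devs);  // one communicator per visible GPU

  float* buf[ndev];
  cudaStream_t streams[ndev];
  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(i);
    cudaMalloc(&buf[i], 1024 * sizeof(float));
    cudaStreamCreate(&streams[i]);
  }

  // Like platform::NCCLGroupGuard: the send and recv are aggregated and only
  // issued on their streams when ncclGroupEnd() is reached.
  ncclGroupStart();
  ncclSend(buf[0], 1024, ncclFloat, /*peer=*/1, comms[0], streams[0]);
  ncclRecv(buf[1], 1024, ncclFloat, /*peer=*/0, comms[1], streams[1]);
  ncclGroupEnd();

  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(i);
    cudaStreamSynchronize(streams[i]);
    cudaFree(buf[i]);
    cudaStreamDestroy(streams[i]);
    ncclCommDestroy(comms[i]);
  }
  printf("p2p done\n");
  return 0;
}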