opt allreduce (#44843)
sneaxiy authored Aug 4, 2022
1 parent d3e9068 commit 1f9e274
Showing 1 changed file with 56 additions and 50 deletions.
paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
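In every hierarchical-allreduce branch below, the old code first allreduced the full gradient buffer (fp32_numel / fp16_numel elements) across nodes over external_comm and only then reduce-scattered it inside the node; the new code reduce-scatters inside the node first (over local_comm) and then allreduces only the calling rank's shard (numel_each_device elements) across nodes. The reduced result is identical, but the slow inter-node collective now carries 1/num_devices of the data. Below is a minimal sketch of the new ordering using raw NCCL; the function and parameter names are illustrative, not Paddle's NCCL*WithScale helpers, and error handling is omitted.

    #include <cuda_runtime.h>
    #include <nccl.h>

    // Sketch: hierarchical gradient sum, reduce-scatter-first order.
    // local_comm groups the num_devices GPUs inside one node; external_comm
    // groups the GPUs with the same local rank across nodes. `grad` is
    // assumed to hold num_devices * numel_each_device floats.
    void HierarchicalGradSum(const float* grad, float* sum_grad,
                             size_t numel_each_device, int local_rank,
                             ncclComm_t local_comm, ncclComm_t external_comm,
                             cudaStream_t stream) {
      float* shard =
          sum_grad + static_cast<size_t>(local_rank) * numel_each_device;
      // (a) Intra-node reduce-scatter over fast NVLink/PCIe links: each GPU
      //     receives the node-local sum of its own slice of the gradient.
      ncclReduceScatter(grad, shard, numel_each_device, ncclFloat, ncclSum,
                        local_comm, stream);
      // (b) Inter-node allreduce on that slice only, so cross-node traffic
      //     is numel / num_devices elements instead of numel.
      ncclAllReduce(shard, shard, numel_each_device, ncclFloat, ncclSum,
                    external_comm, stream);
    }

Paddle's wrappers in the diff follow the same shape but add dtype dispatch and an optional scale, as the hunks below show.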
@@ -1697,37 +1697,39 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
       // (1) ReduceScatter first
       if (local_shard) {
         if (use_hierarchical_allreduce) {
-          NCCLAllReduceWithScale(fp32_grad,
-                                 fp32_sum_grad,
-                                 fp32_numel,
-                                 nranks / num_devices,
-                                 external_comm,
-                                 stream,
-                                 dev_ctx);
           NCCLReduceScatterWithScale(
-              fp32_sum_grad,
+              fp32_grad,
               fp32_sum_grad + local_rank * fp32_numel_each_device,
               fp32_numel_each_device,
               num_devices,
               local_comm,
               stream,
               dev_ctx);
+          NCCLAllReduceWithScale(
+              fp32_sum_grad + local_rank * fp32_numel_each_device,
+              fp32_sum_grad + local_rank * fp32_numel_each_device,
+              fp32_numel_each_device,
+              nranks / num_devices,
+              external_comm,
+              stream,
+              dev_ctx);

-          NCCLAllReduceWithScale(fp16_grad,
-                                 fp16_sum_grad,
-                                 fp16_numel,
-                                 nranks / num_devices,
-                                 external_comm,
-                                 stream,
-                                 dev_ctx);
           NCCLReduceScatterWithScale(
-              fp16_sum_grad,
+              fp16_grad,
               fp16_sum_grad + local_rank * fp16_numel_each_device,
               fp16_numel_each_device,
               num_devices,
               local_comm,
               stream,
               dev_ctx);
+          NCCLAllReduceWithScale(
+              fp16_sum_grad + local_rank * fp16_numel_each_device,
+              fp16_sum_grad + local_rank * fp16_numel_each_device,
+              fp16_numel_each_device,
+              nranks / num_devices,
+              external_comm,
+              stream,
+              dev_ctx);
         } else {
           NCCLAllReduceWithScale(fp32_grad,
                                  fp32_sum_grad,
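Why this ordering wins: the inter-node link (external_comm) is the bottleneck, and the old order pushed the entire gradient through it. As a rough worked example, assuming 4 nodes of 8 GPUs and a 100M-element fp32 gradient (400 MB): the old code allreduced the full 400 MB across nodes before reduce-scattering locally, while the new code reduce-scatters over NVLink/PCIe first and then allreduces only fp32_numel_each_device = 100M / 8 = 12.5M elements (50 MB) per rank across nodes, roughly a num_devices-fold (8x) cut in inter-node volume; ring-collective constant factors apply equally to both orders.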
@@ -1839,38 +1841,40 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
         // (3) Do ReduceScatter with scale
         if (local_shard) {
           if (use_hierarchical_allreduce) {
-            NCCLAllReduceWithScale(fp32_grad,
-                                   fp32_sum_grad,
-                                   fp32_numel,
-                                   nranks / num_devices,
-                                   external_comm,
-                                   stream,
-                                   dev_ctx,
-                                   fp32_scale);
             NCCLReduceScatterWithScale(
-                fp32_sum_grad,
+                fp32_grad,
                 fp32_sum_grad + local_rank * fp32_numel_each_device,
                 fp32_numel_each_device,
                 num_devices,
                 local_comm,
                 stream,
                 dev_ctx,
                 fp32_scale);
+            NCCLAllReduceWithScale(
+                fp32_sum_grad + local_rank * fp32_numel_each_device,
+                fp32_sum_grad + local_rank * fp32_numel_each_device,
+                fp32_numel_each_device,
+                nranks / num_devices,
+                external_comm,
+                stream,
+                dev_ctx);

-            NCCLAllReduceWithScale(fp16_grad,
-                                   fp16_sum_grad,
-                                   fp16_numel,
-                                   nranks / num_devices,
-                                   external_comm,
-                                   stream,
-                                   dev_ctx,
-                                   fp16_scale);
             NCCLReduceScatterWithScale(
-                fp16_sum_grad,
+                fp16_grad,
                 fp16_sum_grad + local_rank * fp16_numel_each_device,
                 fp16_numel_each_device,
                 num_devices,
                 local_comm,
                 stream,
                 dev_ctx,
                 fp16_scale);
+            NCCLAllReduceWithScale(
+                fp16_sum_grad + local_rank * fp16_numel_each_device,
+                fp16_sum_grad + local_rank * fp16_numel_each_device,
+                fp16_numel_each_device,
+                nranks / num_devices,
+                external_comm,
+                stream,
+                dev_ctx);
           } else {
             NCCLAllReduceWithScale(fp32_grad,
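Note how the scale argument moves with the reordering: in the new code it rides on the first collective (the intra-node reduce-scatter), and the shard-sized inter-node allreduce that follows runs unscaled, so the scale is applied exactly once. The sketch below is a guess at the shape of such a helper, not Paddle's actual NCCLReduceScatterWithScale: it assumes the scale is an optional single-value device pointer and that the caller supplies a scratch buffer; the real helper may fuse the multiply differently.

    #include <cuda_runtime.h>
    #include <nccl.h>

    // Multiply every element by *scale (a single value resident on the GPU).
    __global__ void ScaleKernel(const float* in, float* out,
                                const float* scale, size_t n) {
      size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
      if (i < n) out[i] = in[i] * (*scale);
    }

    // Reduce-scatter that optionally pre-scales its input. `tmp` is a
    // caller-provided scratch buffer of `numel` floats; errors unchecked.
    void ReduceScatterWithScale(const float* grad, float* tmp, float* shard,
                                size_t numel, size_t shard_numel,
                                const float* scale,  // may be nullptr
                                ncclComm_t comm, cudaStream_t stream) {
      const float* src = grad;
      if (scale != nullptr) {
        ScaleKernel<<<static_cast<unsigned>((numel + 255) / 256), 256, 0,
                      stream>>>(grad, tmp, scale, numel);
        src = tmp;
      }
      ncclReduceScatter(src, shard, shard_numel, ncclFloat, ncclSum, comm,
                        stream);
    }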
@@ -1917,37 +1921,39 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
     } else {
       if (local_shard) {
         if (use_hierarchical_allreduce) {
-          NCCLAllReduceWithScale(fp32_grad,
-                                 fp32_sum_grad,
-                                 fp32_numel,
-                                 nranks / num_devices,
-                                 external_comm,
-                                 stream,
-                                 dev_ctx);
           NCCLReduceScatterWithScale(
-              fp32_sum_grad,
+              fp32_grad,
               fp32_sum_grad + local_rank * fp32_numel_each_device,
               fp32_numel_each_device,
               num_devices,
               local_comm,
               stream,
               dev_ctx);
+          NCCLAllReduceWithScale(
+              fp32_sum_grad + local_rank * fp32_numel_each_device,
+              fp32_sum_grad + local_rank * fp32_numel_each_device,
+              fp32_numel_each_device,
+              nranks / num_devices,
+              external_comm,
+              stream,
+              dev_ctx);

-          NCCLAllReduceWithScale(fp16_grad,
-                                 fp16_sum_grad,
-                                 fp16_numel,
-                                 nranks / num_devices,
-                                 external_comm,
-                                 stream,
-                                 dev_ctx);
           NCCLReduceScatterWithScale(
-              fp16_sum_grad,
+              fp16_grad,
               fp16_sum_grad + local_rank * fp16_numel_each_device,
               fp16_numel_each_device,
               num_devices,
               local_comm,
               stream,
               dev_ctx);
+          NCCLAllReduceWithScale(
+              fp16_sum_grad + local_rank * fp16_numel_each_device,
+              fp16_sum_grad + local_rank * fp16_numel_each_device,
+              fp16_numel_each_device,
+              nranks / num_devices,
+              external_comm,
+              stream,
+              dev_ctx);
         } else {
           NCCLAllReduceWithScale(fp32_grad,
                                  fp32_sum_grad,
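All three hunks rely on the same algebraic fact: summing gradients within each node and then across nodes yields the same shard as summing across nodes and then within a node, so the cheaper ordering is safe. A tiny host-side simulation (no GPUs; the node/device counts and sizes are made up) that checks the two orders agree element-wise:

    #include <cassert>
    #include <cstdio>
    #include <vector>

    // Host-side check: a local reduce-scatter followed by an inter-node
    // allreduce on the shard equals an inter-node allreduce on the full
    // buffer followed by a local reduce-scatter.
    int main() {
      const int nodes = 2, devices = 4, shard = 3;
      const int numel = devices * shard;  // numel_each_device == shard

      // grad[n][d] is the gradient buffer held by device d on node n.
      std::vector<std::vector<std::vector<double>>> grad(nodes);
      for (int n = 0; n < nodes; ++n) {
        grad[n].assign(devices, std::vector<double>(numel));
        for (int d = 0; d < devices; ++d)
          for (int i = 0; i < numel; ++i)
            grad[n][d][i] = 100 * n + 10 * d + i;  // distinct values
      }

      // Device d on every node owns elements [d * shard, (d + 1) * shard).
      for (int d = 0; d < devices; ++d) {
        for (int i = d * shard; i < (d + 1) * shard; ++i) {
          // Old order: sum across nodes first (full-size external
          // allreduce), then across the devices of one node.
          double old_order = 0;
          for (int d2 = 0; d2 < devices; ++d2)
            for (int n = 0; n < nodes; ++n) old_order += grad[n][d2][i];
          // New order: sum within each node first (local reduce-scatter),
          // then across nodes (shard-size external allreduce).
          double new_order = 0;
          for (int n = 0; n < nodes; ++n) {
            double within_node = 0;
            for (int d2 = 0; d2 < devices; ++d2)
              within_node += grad[n][d2][i];
            new_order += within_node;
          }
          assert(old_order == new_order);  // exact: integer-valued doubles
        }
      }
      std::puts("both orders agree");
      return 0;
    }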