From 1e7a33b5d9488c9fa368f6237edaafc8e8d91514 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Wed, 3 Aug 2022 14:07:29 +0800 Subject: [PATCH] opt allreduce --- .../optimizers/distributed_fused_lamb_op.cu | 106 +++++++++--------- 1 file changed, 56 insertions(+), 50 deletions(-) diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu index f1b852301d4d9..53a5fa4706cc0 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu @@ -1697,37 +1697,39 @@ class DistributedFusedLambOpKernel // (1) ReduceScater first if (local_shard) { if (use_hierarchical_allreduce) { - NCCLAllReduceWithScale(fp32_grad, - fp32_sum_grad, - fp32_numel, - nranks / num_devices, - external_comm, - stream, - dev_ctx); NCCLReduceScatterWithScale( - fp32_sum_grad, + fp32_grad, fp32_sum_grad + local_rank * fp32_numel_each_device, fp32_numel_each_device, num_devices, local_comm, stream, dev_ctx); + NCCLAllReduceWithScale( + fp32_sum_grad + local_rank * fp32_numel_each_device, + fp32_sum_grad + local_rank * fp32_numel_each_device, + fp32_numel_each_device, + nranks / num_devices, + external_comm, + stream, + dev_ctx); - NCCLAllReduceWithScale(fp16_grad, - fp16_sum_grad, - fp16_numel, - nranks / num_devices, - external_comm, - stream, - dev_ctx); NCCLReduceScatterWithScale( - fp16_sum_grad, + fp16_grad, fp16_sum_grad + local_rank * fp16_numel_each_device, fp16_numel_each_device, num_devices, local_comm, stream, dev_ctx); + NCCLAllReduceWithScale( + fp16_sum_grad + local_rank * fp16_numel_each_device, + fp16_sum_grad + local_rank * fp16_numel_each_device, + fp16_numel_each_device, + nranks / num_devices, + external_comm, + stream, + dev_ctx); } else { NCCLAllReduceWithScale(fp32_grad, fp32_sum_grad, @@ -1839,38 +1841,40 @@ class DistributedFusedLambOpKernel // (3) Do ReduceScatter with scale if (local_shard) { if (use_hierarchical_allreduce) { - NCCLAllReduceWithScale(fp32_grad, - fp32_sum_grad, - fp32_numel, - nranks / num_devices, - external_comm, - stream, - dev_ctx, - fp32_scale); NCCLReduceScatterWithScale( - fp32_sum_grad, + fp32_grad, fp32_sum_grad + local_rank * fp32_numel_each_device, fp32_numel_each_device, num_devices, local_comm, stream, + dev_ctx, + fp32_scale); + NCCLAllReduceWithScale( + fp32_sum_grad + local_rank * fp32_numel_each_device, + fp32_sum_grad + local_rank * fp32_numel_each_device, + fp32_numel_each_device, + nranks / num_devices, + external_comm, + stream, dev_ctx); - NCCLAllReduceWithScale(fp16_grad, - fp16_sum_grad, - fp16_numel, - nranks / num_devices, - external_comm, - stream, - dev_ctx, - fp16_scale); NCCLReduceScatterWithScale( - fp16_sum_grad, + fp16_grad, fp16_sum_grad + local_rank * fp16_numel_each_device, fp16_numel_each_device, num_devices, local_comm, stream, + dev_ctx, + fp16_scale); + NCCLAllReduceWithScale( + fp16_sum_grad + local_rank * fp16_numel_each_device, + fp16_sum_grad + local_rank * fp16_numel_each_device, + fp16_numel_each_device, + nranks / num_devices, + external_comm, + stream, dev_ctx); } else { NCCLAllReduceWithScale(fp32_grad, @@ -1917,37 +1921,39 @@ class DistributedFusedLambOpKernel } else { if (local_shard) { if (use_hierarchical_allreduce) { - NCCLAllReduceWithScale(fp32_grad, - fp32_sum_grad, - fp32_numel, - nranks / num_devices, - external_comm, - stream, - dev_ctx); NCCLReduceScatterWithScale( - fp32_sum_grad, + fp32_grad, fp32_sum_grad + local_rank * fp32_numel_each_device, fp32_numel_each_device, num_devices, local_comm, stream, dev_ctx); + NCCLAllReduceWithScale( + fp32_sum_grad + local_rank * fp32_numel_each_device, + fp32_sum_grad + local_rank * fp32_numel_each_device, + fp32_numel_each_device, + nranks / num_devices, + external_comm, + stream, + dev_ctx); - NCCLAllReduceWithScale(fp16_grad, - fp16_sum_grad, - fp16_numel, - nranks / num_devices, - external_comm, - stream, - dev_ctx); NCCLReduceScatterWithScale( - fp16_sum_grad, + fp16_grad, fp16_sum_grad + local_rank * fp16_numel_each_device, fp16_numel_each_device, num_devices, local_comm, stream, dev_ctx); + NCCLAllReduceWithScale( + fp16_sum_grad + local_rank * fp16_numel_each_device, + fp16_sum_grad + local_rank * fp16_numel_each_device, + fp16_numel_each_device, + nranks / num_devices, + external_comm, + stream, + dev_ctx); } else { NCCLAllReduceWithScale(fp32_grad, fp32_sum_grad,