From 7ab4a711859b680b68efb4bc2b9ce55eef03f296 Mon Sep 17 00:00:00 2001
From: qingshui
Date: Wed, 4 Nov 2020 16:17:39 +0800
Subject: [PATCH] Optimize fused_seqpool_cvm to support nncross, and add an
 nncross sync_stats configuration (#52)

---
 paddle/fluid/operators/batch_fc_op.cu         |   2 +-
 .../fluid/operators/cross_norm_hadamard.cu.h  |   4 +-
 .../fluid/operators/cross_norm_hadamard_op.cc |   2 +
 .../fluid/operators/cross_norm_hadamard_op.cu |   2 +-
 .../operators/fused/fused_seqpool_cvm_op.cc   |  16 +-
 .../operators/fused/fused_seqpool_cvm_op.cu   | 242 +++++++++---------
 python/paddle/fluid/contrib/layers/nn.py      |  10 +-
 7 files changed, 141 insertions(+), 137 deletions(-)

diff --git a/paddle/fluid/operators/batch_fc_op.cu b/paddle/fluid/operators/batch_fc_op.cu
index bcdb1912c1d0b..1a5b2bc2b8827 100644
--- a/paddle/fluid/operators/batch_fc_op.cu
+++ b/paddle/fluid/operators/batch_fc_op.cu
@@ -29,7 +29,7 @@ using framework::Tensor;
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
        i += blockDim.x * gridDim.x)
 
-const int CUDA_NUM_THREADS = 1024;
+const int CUDA_NUM_THREADS = paddle::platform::PADDLE_CUDA_NUM_THREADS;
 static inline int GET_BLOCKS(const int N) {
   return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
 }
diff --git a/paddle/fluid/operators/cross_norm_hadamard.cu.h b/paddle/fluid/operators/cross_norm_hadamard.cu.h
index 7d5074d6e5f18..7efa3a1ad5dd4 100644
--- a/paddle/fluid/operators/cross_norm_hadamard.cu.h
+++ b/paddle/fluid/operators/cross_norm_hadamard.cu.h
@@ -16,6 +16,8 @@ limitations under the License. */
 #include
 #include "cub/cub.cuh"
 #include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/gpu_info.h"
 
 #define NORM_POS(idx, row, col) (((idx)*block_cols + (col)) * ins_num + (row))
 #define SCALE_MEAN_POS(idx, col) ((idx)*block_cols + (col))
@@ -29,7 +31,7 @@ limitations under the License. */
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
        i += blockDim.x * gridDim.x)
 
-const int CUDA_NUM_THREADS = 1024;
+const int CUDA_NUM_THREADS = paddle::platform::PADDLE_CUDA_NUM_THREADS;
 static inline int GET_BLOCKS(const int N) {
   return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
 }
diff --git a/paddle/fluid/operators/cross_norm_hadamard_op.cc b/paddle/fluid/operators/cross_norm_hadamard_op.cc
index fc3158fe1e99c..07097098707a2 100644
--- a/paddle/fluid/operators/cross_norm_hadamard_op.cc
+++ b/paddle/fluid/operators/cross_norm_hadamard_op.cc
@@ -113,6 +113,8 @@ class CrossNormHadamardOpMaker : public framework::OpProtoAndCheckerMaker {
           PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
                          "'epsilon' should be between 0.0 and 0.001.");
         });
+    AddAttr<bool>("sync_stats", "(bool, default false) only used in multi-GPU")
+        .SetDefault(false);
     AddOutput("Out", "Output tensor of cross_norm_hadamard_op operator.");
     AddOutput("CudaMeans", "Output tensor of cross_norm_hadamard_op operator.");
     AddOutput("CudaScales",
diff --git a/paddle/fluid/operators/cross_norm_hadamard_op.cu b/paddle/fluid/operators/cross_norm_hadamard_op.cu
index 2ca2ca9e5bcf5..5e71b0b173eb9 100644
--- a/paddle/fluid/operators/cross_norm_hadamard_op.cu
+++ b/paddle/fluid/operators/cross_norm_hadamard_op.cu
@@ -85,6 +85,7 @@ class CrossNormHadamardOpCUDAKernel : public framework::OpKernel<T> {
     auto embed_dim = ctx.Attr<int>("embed_dim");
     const float epsilon = ctx.Attr<float>("epsilon");
     const float dr = ctx.Attr<float>("summary_decay_rate");
+    const bool need_sync_stats = ctx.Attr<bool>("sync_stats");
 
     auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
     auto* summary_grad =
@@ -173,7 +174,6 @@ class CrossNormHadamardOpCUDAKernel : public framework::OpKernel<T> {
     T* summary_input_data =
         ctx.Output<Tensor>("SummaryInput")->mutable_data<T>(ctx.GetPlace());
 
-    bool need_sync_stats = true;
     if (need_sync_stats) {
 #if defined(PADDLE_WITH_NCCL)
       auto comm = platform::NCCLCommContext::Instance().Get(0, ctx.GetPlace());
diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
index 1894e6102cbea..52129bba2e8f9 100644
--- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
+++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/fused/fused_seqpool_cvm_op.h"
-
+#include <string>
 namespace paddle {
 namespace operators {
 
@@ -30,12 +30,12 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(
         cvm_dims.size(), 2UL,
         platform::errors::InvalidArgument("Input(CVM)'s rank should be 2."));
-    PADDLE_ENFORCE_EQ(
-        cvm_dims[1], 2UL,
-        platform::errors::InvalidArgument("The 2nd dimension of "
-                                          "Input(CVM) should be 2."));
+    PADDLE_ENFORCE_EQ(cvm_dims[1], 2UL, platform::errors::InvalidArgument(
+                                            "The 2nd dimension of "
+                                            "Input(CVM) should be 2."));
 
     auto ins_dims = ctx->GetInputsDim("X");
+    const int cvm_offset = ctx->Attrs().Get<int>("cvm_offset");
     const size_t num_inputs = ins_dims.size();
     std::vector<framework::DDim> outs_dims;
     outs_dims.resize(num_inputs);
@@ -69,7 +69,7 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel {
       if (ctx->Attrs().Get<bool>("use_cvm")) {
         out_dim = {-1, dims[rank - 1]};
       } else {
-        out_dim = {-1, dims[rank - 1] - 2};
+        out_dim = {-1, dims[rank - 1] - cvm_offset};
       }
       outs_dims[i] = framework::make_ddim(out_dim);
     }
@@ -111,6 +111,7 @@ class FusedSeqpoolCVMOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<float>("show_coeff", "(float, default 0.2)").SetDefault(0.2);
     AddAttr<float>("clk_coeff", "(float, default 1)").SetDefault(1);
     AddAttr<float>("threshold", "(float, default 0.96)").SetDefault(0.96);
+    AddAttr<int>("cvm_offset", "(int, default 2)").SetDefault(2);
 
     AddComment(R"DOC(
 Fuse multiple pairs of Sequence Pool and CVM Operator.
@@ -127,6 +128,7 @@ class FusedSeqpoolCVMGradOp : public framework::OperatorWithKernel {
     auto og_dims = ctx->GetInputsDim(framework::GradVarName("Out"));
     auto x_dims = ctx->GetInputsDim("X");
     auto cvm_dims = ctx->GetInputDim("CVM");
+    const int cvm_offset = ctx->Attrs().Get<int>("cvm_offset");
     PADDLE_ENFORCE_EQ(
         cvm_dims.size(), 2,
         platform::errors::InvalidArgument("Input(CVM)'s rank should be 2."));
@@ -151,7 +153,7 @@ class FusedSeqpoolCVMGradOp : public framework::OperatorWithKernel {
       } else {
         PADDLE_ENFORCE_EQ(
             og_dims[i][og_dims[i].size() - 1],
-            x_dims[i][og_dims[i].size() - 1] - 2,
+            x_dims[i][og_dims[i].size() - 1] - cvm_offset,
             platform::errors::InvalidArgument(
                 "The dimension mismatch between Input(OUT@GRAD) and "
                 "Input(X). Received Input(OUT@GRAD): input rank %u, "
diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu
index eae1120200c43..cd208f38e765e 100644
--- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu
+++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu
@@ -53,8 +53,9 @@ __global__ void FusedSeqpoolKernel(
         if (tid <= 1) {  // show & click
           val += *(input_values[x] + k * embedding_size + tid);
         } else {
-          val += ((int)(*(input_values[x] + k * embedding_size + tid) * 128 +
-                        0.5)) /
+          val += (static_cast<int>(
+                     *(input_values[x] + k * embedding_size + tid) * 128 +
+                     0.5)) /
                  128.0;
         }
       }
@@ -74,18 +75,17 @@ template <typename T>
 __global__ void FusedCVMKernel(T **output_values, T **seqpool_output_values,
                                const int64_t *data_lens, const int batch_size,
                                int64_t total_len, const int embedding_size,
-                               bool use_cvm) {
+                               const bool use_cvm, const int cvm_offset) {
   CUDA_KERNEL_LOOP(i, total_len * embedding_size) {
     int key = i / embedding_size;
     int offset = i % embedding_size;
     int x = key / batch_size;
     int y = key - (x ? data_lens[x - 1] : 0);
-    int cvm_offset = 2;
     if (use_cvm) {
-      if (offset == 0) {
+      if (offset == 0) {  // show
         *(output_values[x] + y * embedding_size) =
             log(*(seqpool_output_values[x] + y * embedding_size) + 1);
-      } else if (offset == 1) {
+      } else if (offset == 1) {  // click
         *(output_values[x] + y * embedding_size + offset) =
             log(*(seqpool_output_values[x] + y * embedding_size + 1) + 1) -
             log(*(seqpool_output_values[x] + y * embedding_size) + 1);
@@ -108,15 +108,13 @@ __global__ void FusedSeqpoolCVMGradKernel(
     T **out_grads_values, T **out_seqpool_grads_values, T **in_grads_values,
     T **cvm_values, size_t **lods_values, const int64_t *data_lens,
     const int batch_size, int64_t total_len, const int embedding_size,
-    bool use_cvm) {
+    const bool use_cvm, const int cvm_offset) {
   CUDA_KERNEL_LOOP(i, total_len * embedding_size) {
     int key = i / embedding_size;
     int offset = i % embedding_size;
     int x = key / batch_size;
     int y = key - (x ? data_lens[x - 1] : 0);
 
-    int cvm_offset = 2;
-
     if (offset < cvm_offset) {
       *(out_seqpool_grads_values[x] + y * embedding_size + offset) =
           *(cvm_values[x] + y * cvm_offset + offset);
@@ -146,9 +144,9 @@ void DoFusedSeqpoolCVM(const paddle::platform::Place &place,
                        T **gpu_seqpool_output_values, size_t **lods_values,
                        const int64_t *data_lens, int slot_num,
                        int64_t total_len, const int embedding_size,
-                       const float padding_value, bool use_cvm,
-                       bool need_filter, float show_coeff, float clk_coeff,
-                       float threshold) {
+                       const float padding_value, const bool use_cvm,
+                       const int cvm_offset, bool need_filter,
+                       float show_coeff, float clk_coeff, float threshold) {
   auto stream = dynamic_cast<platform::CUDADeviceContext *>(
                     platform::DeviceContextPool::Instance().Get(
                         BOOST_GET_CONST(platform::CUDAPlace, place)))
                     ->stream();
@@ -165,7 +163,7 @@ void DoFusedSeqpoolCVM(const paddle::platform::Place &place,
                    PADDLE_CUDA_NUM_THREADS, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
       gpu_output_values, gpu_seqpool_output_values, data_lens, batch_size,
-      total_len, embedding_size, use_cvm);
+      total_len, embedding_size, use_cvm, cvm_offset);
 }
 
 template <typename T>
@@ -176,8 +174,9 @@ void FusedSeqpoolCVM(const paddle::platform::Place &place,
                      std::vector<const size_t *> lods,
                      const std::vector<int64_t> &data_lengths,
                      const int embedding_size, const float padding_value,
-                     const bool use_cvm, float need_filter, float show_coeff,
-                     float clk_coeff, float threshold) {
+                     const bool use_cvm, const int cvm_offset,
+                     float need_filter, float show_coeff, float clk_coeff,
+                     float threshold) {
   auto data_lengths_lod = data_lengths;
   int slot_num = static_cast<int>(data_lengths.size());
   for (int i = 1; i < slot_num; i++) {
@@ -228,60 +227,7 @@ void FusedSeqpoolCVM(const paddle::platform::Place &place,
   DoFusedSeqpoolCVM(place, gpu_input_values, gpu_output_values,
                     gpu_seqpool_output_values, lods_values, data_lens,
                     slot_num, total_length, embedding_size, padding_value, use_cvm,
-                    need_filter, show_coeff, clk_coeff, threshold);
-}
-
-template <typename T>
-static void FusedSeqpoolCVMFunctor(const framework::ExecutionContext &ctx) {
-  auto inputs = ctx.MultiInput<LoDTensor>("X");
-  auto outputs = ctx.MultiOutput<LoDTensor>("Out");
-
-  const auto slot_size = inputs.size();
-  std::vector<const T *> input_data(slot_size);
-  std::vector<int64_t> data_lens(slot_size);
-  std::vector<const size_t *> lods_data(slot_size);
-  std::vector<T *> output_data(slot_size);
-
-  std::vector<LoDTensor> seqpool_outputs(slot_size);
-  std::vector<T *> seqpool_output_data(slot_size);
-
-  auto padding_value = ctx.Attr<float>("pad_value");
-  auto use_cvm = ctx.Attr<bool>("use_cvm");
-  bool need_filter = ctx.Attr<bool>("need_filter");
-  float show_coeff = ctx.Attr<float>("show_coeff");
-  float clk_coeff = ctx.Attr<float>("clk_coeff");
-  float threshold = ctx.Attr<float>("threshold");
-
-  int embedding_size = inputs[0]->numel() / inputs[0]->dims()[0];
-
-  for (size_t i = 0; i < slot_size; ++i) {
-    const auto *input = inputs[i];
-    auto dims = input->dims();
-
-    auto lod = input->lod();
-    auto lod_level = lod.size();
-    int batch_size = lod[lod_level - 1].size() - 1;  // -1 to real batch size
-
-    input_data[i] = reinterpret_cast<const T *>(input->data<T>());
-    auto *output = outputs[i];
-    if (use_cvm) {
-      output->Resize({batch_size, embedding_size});
-    } else {
-      output->Resize({batch_size, embedding_size - 2});
-    }
-    output_data[i] =
-        reinterpret_cast<T *>(output->mutable_data<T>(ctx.GetPlace()));
-    data_lens[i] = lod[lod_level - 1].size() - 1;
-    lods_data[i] = lod[lod_level - 1].CUDAData(ctx.GetPlace());
-
-    seqpool_output_data[i] =
-        reinterpret_cast<T *>(seqpool_outputs[i].mutable_data<T>(
-            {batch_size, embedding_size}, ctx.GetPlace()));
-  }
-
-  FusedSeqpoolCVM(ctx.GetPlace(), input_data, output_data, seqpool_output_data,
-                  lods_data, data_lens, embedding_size, padding_value, use_cvm,
-                  need_filter, show_coeff, clk_coeff, threshold);
+                    cvm_offset, need_filter, show_coeff, clk_coeff, threshold);
 }
 
 template <typename T>
@@ -290,7 +236,8 @@ void DoFusedSeqpoolCVMGrad(const paddle::platform::Place &place,
                            T **in_grads_values, T **gpu_cvm_values,
                            size_t **lods_values, const int64_t *slot_lens,
                            int slot_num, int64_t total_len,
-                           const int embedding_size, bool use_cvm) {
+                           const int embedding_size, const bool use_cvm,
+                           const int cvm_offset) {
   auto stream = dynamic_cast<platform::CUDADeviceContext *>(
                     platform::DeviceContextPool::Instance().Get(
                         BOOST_GET_CONST(platform::CUDAPlace, place)))
@@ -302,7 +249,7 @@ void DoFusedSeqpoolCVMGrad(const paddle::platform::Place &place,
                               PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
       out_grads_values, out_seqpool_grads_values, in_grads_values,
       gpu_cvm_values, lods_values, slot_lens, batch_size, total_len,
-      embedding_size, use_cvm);
+      embedding_size, use_cvm, cvm_offset);
 }
 
 template <typename T>
@@ -311,9 +258,10 @@ void FusedSeqpoolCVMGrad(const paddle::platform::Place &place,
                          const std::vector<T *> &out_seqpool_grads_data,
                          const std::vector<T *> &in_grads_data,
                          const std::vector<const T *> &cvm_data,
-                         std::vector<const size_t *> &lods,
+                         const std::vector<const size_t *> &lods,
                          const std::vector<int64_t> &data_lengths,
-                         const int embedding_size, const bool use_cvm) {
+                         const int embedding_size, const bool use_cvm,
+                         const int cvm_offset) {
   auto data_lengths_lod = data_lengths;
   int slot_num = static_cast<int>(data_lengths.size());
   for (int i = 1; i < slot_num; i++) {
@@ -370,62 +318,64 @@ void FusedSeqpoolCVMGrad(const paddle::platform::Place &place,
   DoFusedSeqpoolCVMGrad(place, gpu_out_grads_values,
                         gpu_out_seqpool_grads_values, gpu_in_grads_values,
                         gpu_cvm_values, lods_values, data_lens, slot_num,
-                        total_length, embedding_size, use_cvm);
-}
-
-template <typename T>
-static void FusedSeqpoolCVMGradFunctor(const framework::ExecutionContext &ctx) {
-  auto out_grads = ctx.MultiInput<LoDTensor>(framework::GradVarName("Out"));
-  auto in_grads = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
-  auto *cvm = ctx.Input<LoDTensor>("CVM");
-
-  std::string pooltype = ctx.Attr<std::string>("pooltype");
-  auto use_cvm = ctx.Attr<bool>("use_cvm");
-
-  const auto slot_size = in_grads.size();
-  std::vector<const T *> out_grads_data(slot_size);
-  std::vector<T *> in_grads_data(slot_size);
-  std::vector<const T *> cvm_data(slot_size);
-  std::vector<const size_t *> lods_data(slot_size);
-  std::vector<int64_t> data_lengths(slot_size);
-
-  std::vector<LoDTensor> out_seqpool_grads(slot_size);
-  std::vector<T *> out_seqpool_grads_data(slot_size);
-
-  int embedding_size = in_grads[0]->numel() / in_grads[0]->dims()[0];
-
-  for (size_t i = 0; i < slot_size; ++i) {
-    auto *in_grad = in_grads[i];
-    auto dims = in_grad->dims();
-
-    auto lod = in_grad->lod();
-    auto lod_level = lod.size();
-    int batch_size = lod[lod_level - 1].size() - 1;  // -1 to real batch size
-
-    auto *out_grad = out_grads[i];
-    out_grads_data[i] = reinterpret_cast<const T *>(out_grad->data<T>());
-
-    in_grads_data[i] =
-        reinterpret_cast<T *>(in_grad->mutable_data<T>(ctx.GetPlace()));
-    lods_data[i] = lod[lod_level - 1].CUDAData(ctx.GetPlace());
-    data_lengths[i] = lod[lod_level - 1].size() - 1;
-    cvm_data[i] = reinterpret_cast<const T *>(cvm->data<T>());
-
-    out_seqpool_grads_data[i] =
-        reinterpret_cast<T *>(out_seqpool_grads[i].mutable_data<T>(
-            {batch_size, embedding_size}, ctx.GetPlace()));
-  }
-
-  FusedSeqpoolCVMGrad(ctx.GetPlace(), out_grads_data, out_seqpool_grads_data,
-                      in_grads_data, cvm_data, lods_data, data_lengths,
-                      embedding_size, use_cvm);
+                        total_length, embedding_size, use_cvm, cvm_offset);
 }
 
 template <typename T>
 class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    FusedSeqpoolCVMFunctor<T>(ctx);
+    auto inputs = ctx.MultiInput<LoDTensor>("X");
+    auto outputs = ctx.MultiOutput<LoDTensor>("Out");
+
+    const auto slot_size = inputs.size();
+    std::vector<const T *> input_data(slot_size);
+    std::vector<int64_t> data_lens(slot_size);
+    std::vector<const size_t *> lods_data(slot_size);
+    std::vector<T *> output_data(slot_size);
+
+    std::vector<LoDTensor> seqpool_outputs(slot_size);
+    std::vector<T *> seqpool_output_data(slot_size);
+
+    auto padding_value = ctx.Attr<float>("pad_value");
+    auto use_cvm = ctx.Attr<bool>("use_cvm");
+    bool need_filter = ctx.Attr<bool>("need_filter");
+    float show_coeff = ctx.Attr<float>("show_coeff");
+    float clk_coeff = ctx.Attr<float>("clk_coeff");
+    float threshold = ctx.Attr<float>("threshold");
+    const int cvm_offset = ctx.Attr<int>("cvm_offset");
+
+    int embedding_size = inputs[0]->numel() / inputs[0]->dims()[0];
+
+    for (size_t i = 0; i < slot_size; ++i) {
+      const auto *input = inputs[i];
+      auto dims = input->dims();
+
+      auto lod = input->lod();
+      auto lod_level = lod.size();
+      int batch_size = lod[lod_level - 1].size() - 1;  // -1 to real batch size
+
+      input_data[i] = reinterpret_cast<const T *>(input->data<T>());
+      auto *output = outputs[i];
+      if (use_cvm) {
+        output->Resize({batch_size, embedding_size});
+      } else {
+        output->Resize({batch_size, embedding_size - cvm_offset});
+      }
+      output_data[i] =
+          reinterpret_cast<T *>(output->mutable_data<T>(ctx.GetPlace()));
+      data_lens[i] = lod[lod_level - 1].size() - 1;
+      lods_data[i] = lod[lod_level - 1].CUDAData(ctx.GetPlace());
+
+      seqpool_output_data[i] =
+          reinterpret_cast<T *>(seqpool_outputs[i].mutable_data<T>(
+              {batch_size, embedding_size}, ctx.GetPlace()));
+    }
+
+    FusedSeqpoolCVM(ctx.GetPlace(), input_data, output_data,
+                    seqpool_output_data, lods_data, data_lens, embedding_size,
+                    padding_value, use_cvm, cvm_offset, need_filter, show_coeff,
+                    clk_coeff, threshold);
   }
 };
 
@@ -433,7 +383,51 @@ template <typename T>
 class FusedSeqpoolCVMGradCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    FusedSeqpoolCVMGradFunctor<T>(ctx);
+    auto out_grads = ctx.MultiInput<LoDTensor>(framework::GradVarName("Out"));
+    auto in_grads = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
+    auto *cvm = ctx.Input<LoDTensor>("CVM");
+
+    std::string pooltype = ctx.Attr<std::string>("pooltype");
+    auto use_cvm = ctx.Attr<bool>("use_cvm");
+    const int cvm_offset = ctx.Attr<int>("cvm_offset");
+
+    const auto slot_size = in_grads.size();
+    std::vector<const T *> out_grads_data(slot_size);
+    std::vector<T *> in_grads_data(slot_size);
+    std::vector<const T *> cvm_data(slot_size);
+    std::vector<const size_t *> lods_data(slot_size);
+    std::vector<int64_t> data_lengths(slot_size);
+
+    std::vector<LoDTensor> out_seqpool_grads(slot_size);
+    std::vector<T *> out_seqpool_grads_data(slot_size);
+
+    int embedding_size = in_grads[0]->numel() / in_grads[0]->dims()[0];
+
+    for (size_t i = 0; i < slot_size; ++i) {
+      auto *in_grad = in_grads[i];
+      auto dims = in_grad->dims();
+
+      auto lod = in_grad->lod();
+      auto lod_level = lod.size();
+      int batch_size = lod[lod_level - 1].size() - 1;  // -1 to real batch size
+
+      auto *out_grad = out_grads[i];
+      out_grads_data[i] = reinterpret_cast<const T *>(out_grad->data<T>());
+
+      in_grads_data[i] =
+          reinterpret_cast<T *>(in_grad->mutable_data<T>(ctx.GetPlace()));
+      lods_data[i] = lod[lod_level - 1].CUDAData(ctx.GetPlace());
+      data_lengths[i] = lod[lod_level - 1].size() - 1;
+      cvm_data[i] = reinterpret_cast<const T *>(cvm->data<T>());
+
+      out_seqpool_grads_data[i] =
+          reinterpret_cast<T *>(out_seqpool_grads[i].mutable_data<T>(
+              {batch_size, embedding_size}, ctx.GetPlace()));
+    }
+
+    FusedSeqpoolCVMGrad(ctx.GetPlace(), out_grads_data, out_seqpool_grads_data,
+                        in_grads_data, cvm_data, lods_data, data_lengths,
+                        embedding_size, use_cvm, cvm_offset);
   }
 };
diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py
index e8e6f437e6548..f5f19d53bd0c0 100644
--- a/python/paddle/fluid/contrib/layers/nn.py
+++ b/python/paddle/fluid/contrib/layers/nn.py
@@ -1431,7 +1431,8 @@ def fused_seqpool_cvm(input,
                       need_filter=False,
                       show_coeff=0.2,
                       clk_coeff=1.0,
-                      threshold=0.96):
+                      threshold=0.96,
+                      cvm_offset=2):
     """
     **Notes: The Op only receives List of LoDTensor as input, only support SUM pooling now.
     :attr:`input`.
@@ -1474,6 +1475,7 @@ def fused_seqpool_cvm(input,
             "pooltype": pool_type.upper(),
             "pad_value": pad_value,
             "use_cvm": use_cvm,
+            "cvm_offset": cvm_offset,
             "need_filter": need_filter,
             "show_coeff": show_coeff,
             "clk_coeff": clk_coeff,
@@ -1489,7 +1491,8 @@ def cross_norm_layer_hadamard(input,
                               param_dict={},
                               summary_decay_rate=0.9999999,
                               epsilon=1e-04,
-                              name=None):
+                              name=None,
+                              sync_stats=False):
     """
     **Cross Norm Layer Hadamard**
     """
@@ -1538,6 +1541,7 @@ def cross_norm_layer_hadamard(input,
             "fields_num": fields_num,
             "embed_dim": embed_dim,
             "epsilon": epsilon,
-            "summary_decay_rate": summary_decay_rate
+            "summary_decay_rate": summary_decay_rate,
+            "sync_stats": sync_stats
         })
     return out
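
Usage sketch (reviewer note, not part of the patch): the snippet below shows how the two knobs introduced here, cvm_offset on fused_seqpool_cvm and sync_stats on cross_norm_layer_hadamard, might be wired up from user code. It assumes the leading positional parameters are (input, pool_type, cvm) and (input, fields_num, embed_dim) respectively, which matches the attrs the diff sets but is not visible in the hunks; the slot count, shapes, and layer names are invented for illustration. cvm_offset=2 reproduces the old hard-coded behavior, while sync_stats used to be unconditionally on and now defaults to off.

    import paddle.fluid as fluid
    from paddle.fluid.contrib.layers import (cross_norm_layer_hadamard,
                                             fused_seqpool_cvm)

    # Three sparse slots; each row is [show, click, 11 embedding values],
    # i.e. width = cvm_offset + real embedding dim = 2 + 11 = 13.
    slots = [
        fluid.layers.data(name="slot_%d" % i, shape=[13], dtype="float32",
                          lod_level=1) for i in range(3)
    ]
    show_click = fluid.layers.data(name="show_click", shape=[2],
                                   dtype="float32")

    # cvm_offset tells the fused kernel how many leading columns of each
    # pooled embedding are CVM statistics (show/click); nncross-style slots
    # can pass a different offset instead of the former hard-coded 2.
    outs = fused_seqpool_cvm(slots, "sum", show_click, use_cvm=True,
                             cvm_offset=2)

    # sync_stats=True keeps the NCCL all-reduce of the summary mean/scale
    # statistics across GPUs; with the default False each device updates its
    # own summary. (Input layout here is illustrative only; the op's real
    # field packing is defined by cross_norm_hadamard_op.cc.)
    concat = fluid.layers.concat(outs, axis=1)
    norm = cross_norm_layer_hadamard(concat, fields_num=3, embed_dim=13,
                                     sync_stats=True)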