Optimize fuse_seqpool_cvm supports nncross and add nncross sync stats configuration (#52)
qingshui authored Nov 4, 2020
1 parent 9ca71a8 commit 7ab4a71
Showing 7 changed files with 141 additions and 137 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/batch_fc_op.cu
@@ -29,7 +29,7 @@ using framework::Tensor;
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
        i += blockDim.x * gridDim.x)
 
-const int CUDA_NUM_THREADS = 1024;
+const int CUDA_NUM_THREADS = paddle::platform::PADDLE_CUDA_NUM_THREADS;
 static inline int GET_BLOCKS(const int N) {
   return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
 }
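
Both this file and cross_norm_hadamard.cu.h previously hard-coded 1024 threads per block; they now reuse the framework-wide paddle::platform::PADDLE_CUDA_NUM_THREADS constant, and GET_BLOCKS is plain ceiling division over that block size. A minimal self-contained sketch of the launch idiom (the kernel, its name, and the constant value of 512 are illustrative assumptions, not part of the patch):

#include <cuda_runtime.h>

// Stand-in for paddle::platform::PADDLE_CUDA_NUM_THREADS (assumed 512 here).
const int CUDA_NUM_THREADS = 512;

// Ceiling division: the smallest block count with blocks * threads >= N.
static inline int GET_BLOCKS(const int N) {
  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}

// Hypothetical element-wise kernel using the same grid-stride loop as the
// CUDA_KERNEL_LOOP macro above.
__global__ void scale_kernel(float* x, float a, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    x[i] *= a;
  }
}

// Launch example: scale_kernel<<<GET_BLOCKS(n), CUDA_NUM_THREADS>>>(d_x, 2.f, n);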
4 changes: 3 additions & 1 deletion paddle/fluid/operators/cross_norm_hadamard.cu.h
@@ -16,6 +16,8 @@ limitations under the License. */
 #include <memory.h>
 #include "cub/cub.cuh"
 #include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/gpu_info.h"
 
 #define NORM_POS(idx, row, col) (((idx)*block_cols + (col)) * ins_num + (row))
 #define SCALE_MEAN_POS(idx, col) ((idx)*block_cols + (col))
@@ -29,7 +31,7 @@ limitations under the License. */
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
        i += blockDim.x * gridDim.x)
 
-const int CUDA_NUM_THREADS = 1024;
+const int CUDA_NUM_THREADS = paddle::platform::PADDLE_CUDA_NUM_THREADS;
 static inline int GET_BLOCKS(const int N) {
   return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
 }
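
For readers new to this header, the NORM_POS macro shown in the context lines stores the normalization buffers feature-major: all ins_num rows of a given (idx, col) pair are contiguous. A tiny host-side check with made-up sizes (block_cols and ins_num values are assumptions for illustration):

#include <cstdio>

#define NORM_POS(idx, row, col) (((idx)*block_cols + (col)) * ins_num + (row))

int main() {
  const int block_cols = 4;  // hypothetical columns per hadamard block
  const int ins_num = 8;     // hypothetical batch size
  // Consecutive rows of the same (idx, col) land in adjacent slots:
  std::printf("%d %d\n", NORM_POS(1, 2, 3), NORM_POS(1, 3, 3));  // 58 59
  return 0;
}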
2 changes: 2 additions & 0 deletions paddle/fluid/operators/cross_norm_hadamard_op.cc
@@ -113,6 +113,8 @@ class CrossNormHadamardOpMaker : public framework::OpProtoAndCheckerMaker {
           PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
                          "'epsilon' should be between 0.0 and 0.001.");
         });
+    AddAttr<bool>("sync_stats", "(bool, default false) only used in multi-GPU")
+        .SetDefault(false);
     AddOutput("Out", "Output tensor of cross_norm_hadamard_op operator.");
     AddOutput("CudaMeans", "Output tensor of cross_norm_hadamard_op operator.");
     AddOutput("CudaScales",
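
Taken together with the .cu change below, the attribute plumbing is a two-line pattern: declare the flag with a default in the Maker, then read it from the execution context inside the kernel. Both lines here are lifted from this commit's hunks; the surrounding class bodies are omitted:

// In CrossNormHadamardOpMaker::Make(): declare the switch with its default.
AddAttr<bool>("sync_stats", "(bool, default false) only used in multi-GPU")
    .SetDefault(false);

// In the kernel's Compute(const framework::ExecutionContext& ctx): read it.
const bool need_sync_stats = ctx.Attr<bool>("sync_stats");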
2 changes: 1 addition & 1 deletion paddle/fluid/operators/cross_norm_hadamard_op.cu
@@ -85,6 +85,7 @@ class CrossNormHadamardOpCUDAKernel : public framework::OpKernel<T> {
     auto embed_dim = ctx.Attr<int64_t>("embed_dim");
     const float epsilon = ctx.Attr<float>("epsilon");
     const float dr = ctx.Attr<float>("summary_decay_rate");
+    const bool need_sync_stats = ctx.Attr<bool>("sync_stats");
 
     auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
     auto* summary_grad =
@@ -173,7 +174,6 @@ class CrossNormHadamardOpCUDAKernel : public framework::OpKernel<T> {
     T* summary_input_data =
         ctx.Output<Tensor>("SummaryInput")->mutable_data<T>(ctx.GetPlace());
 
-    bool need_sync_stats = true;
     if (need_sync_stats) {
 #if defined(PADDLE_WITH_NCCL)
       auto comm = platform::NCCLCommContext::Instance().Get(0, ctx.GetPlace());
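
The kernel previously hard-coded bool need_sync_stats = true, so multi-GPU statistic synchronization always ran; it is now opt-in via the attribute. When the flag is on, the locally accumulated batch statistics are typically summed across ranks before the decayed summary update. A hedged sketch of that reduction (stats_data and stats_len are placeholder names for the operator's summary buffers, not identifiers from the patch):

#if defined(PADDLE_WITH_NCCL)
    if (need_sync_stats) {
      auto comm = platform::NCCLCommContext::Instance().Get(0, ctx.GetPlace());
      auto stream = ctx.cuda_device_context().stream();
      // Sum this GPU's locally accumulated statistics across every rank in
      // the communicator, in place.
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
          stats_data, stats_data, stats_len, ncclFloat, ncclSum, comm->comm(),
          stream));
      // Block until the all-reduce finishes so the summary update that
      // follows reads fully reduced values.
      PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
    }
#endif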
16 changes: 9 additions & 7 deletions paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/fused/fused_seqpool_cvm_op.h"
-
+#include <string>
 namespace paddle {
 namespace operators {
@@ -30,12 +30,12 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(
         cvm_dims.size(), 2UL,
         platform::errors::InvalidArgument("Input(CVM)'s rank should be 2."));
-    PADDLE_ENFORCE_EQ(
-        cvm_dims[1], 2UL,
-        platform::errors::InvalidArgument("The 2nd dimension of "
-                                          "Input(CVM) should be 2."));
+    PADDLE_ENFORCE_EQ(cvm_dims[1], 2UL, platform::errors::InvalidArgument(
+                                            "The 2nd dimension of "
+                                            "Input(CVM) should be 2."));
 
     auto ins_dims = ctx->GetInputsDim("X");
+    const int cvm_offset = ctx->Attrs().Get<int>("cvm_offset");
     const size_t num_inputs = ins_dims.size();
     std::vector<framework::DDim> outs_dims;
     outs_dims.resize(num_inputs);
@@ -69,7 +69,7 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel {
       if (ctx->Attrs().Get<bool>("use_cvm")) {
         out_dim = {-1, dims[rank - 1]};
       } else {
-        out_dim = {-1, dims[rank - 1] - 2};
+        out_dim = {-1, dims[rank - 1] - cvm_offset};
       }
       outs_dims[i] = framework::make_ddim(out_dim);
     }
@@ -111,6 +111,7 @@ class FusedSeqpoolCVMOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<float>("show_coeff", "(float, default 0.2)").SetDefault(0.2);
     AddAttr<float>("clk_coeff", "(float, default 1)").SetDefault(1);
     AddAttr<float>("threshold", "(float, default 0.96)").SetDefault(0.96);
+    AddAttr<int>("cvm_offset", "(int, default 2)").SetDefault(2);
 
     AddComment(R"DOC(
 Fuse multiple pairs of Sequence Pool and CVM Operator.
@@ -127,6 +128,7 @@ class FusedSeqpoolCVMGradOp : public framework::OperatorWithKernel {
     auto og_dims = ctx->GetInputsDim(framework::GradVarName("Out"));
     auto x_dims = ctx->GetInputsDim("X");
     auto cvm_dims = ctx->GetInputDim("CVM");
+    const int cvm_offset = ctx->Attrs().Get<int>("cvm_offset");
 
     PADDLE_ENFORCE_EQ(
         cvm_dims.size(), 2,
@@ -151,7 +153,7 @@ class FusedSeqpoolCVMGradOp : public framework::OperatorWithKernel {
     } else {
       PADDLE_ENFORCE_EQ(
           og_dims[i][og_dims[i].size() - 1],
-          x_dims[i][og_dims[i].size() - 1] - 2,
+          x_dims[i][og_dims[i].size() - 1] - cvm_offset,
           platform::errors::InvalidArgument(
               "The dimension mismatch between Input(OUT@GRAD) and "
               "Input(X). Received Input(OUT@GRAD): input rank %u, "
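
The recurring edit in this file replaces the literal 2 (the show/click CVM columns) with the new cvm_offset attribute, which appears to be what lets the fused op accept nncross-style inputs whose leading CVM block has a different width. The shape rule reduces to the small function below (a standalone illustration; the widths are made up):

#include <cstdio>

// When use_cvm is false, the first cvm_offset columns are stripped from each
// pooled output; otherwise the full input width is kept.
int out_width(int in_width, bool use_cvm, int cvm_offset) {
  return use_cvm ? in_width : in_width - cvm_offset;
}

int main() {
  std::printf("%d\n", out_width(11, false, 2));  // 9: default show/click offset
  std::printf("%d\n", out_width(11, false, 3));  // 8: hypothetical wider offset
  std::printf("%d\n", out_width(11, true, 2));   // 11: use_cvm keeps all columns
  return 0;
}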
[Diffs for the remaining two changed files did not load in this view.]
