From e2e1c57b1db6fc94be8ff0bdbc5a62103b6a5498 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Mon, 12 Jul 2021 11:08:47 +0800 Subject: [PATCH] softmax mask fuse upper triangle (#33981) * softmax mask fuse upper triangle * cover not implemented cpu code --- .../softmax_mask_fuse_upper_triangle_op.cc | 107 ++++ .../softmax_mask_fuse_upper_triangle_op.cu | 546 ++++++++++++++++++ .../softmax_mask_fuse_upper_triangle_op.h | 30 + ...est_softmax_mask_fuse_upper_triangle_op.py | 117 ++++ python/paddle/incubate/__init__.py | 3 +- python/paddle/incubate/operators/__init__.py | 15 + .../softmax_mask_fuse_upper_triangle.py | 42 ++ python/setup.py.in | 1 + 8 files changed, 860 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.cc create mode 100644 paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.cu create mode 100644 paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_softmax_mask_fuse_upper_triangle_op.py create mode 100644 python/paddle/incubate/operators/__init__.py create mode 100644 python/paddle/incubate/operators/softmax_mask_fuse_upper_triangle.py diff --git a/paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.cc b/paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.cc new file mode 100644 index 0000000000000..fa5f996f5c150 --- /dev/null +++ b/paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.cc @@ -0,0 +1,107 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include "paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.h" +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/op_registry.h" +namespace paddle { +namespace operators { + +using framework::Tensor; + +class SoftmaxMaskFuseUpperTriangleOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", + "SoftmaxMaskFuseUpperTriangle"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", + "SoftmaxMaskFuseUpperTriangle"); + + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ( + x_dims.size(), 4, + platform::errors::InvalidArgument("Input x must be in 4D dimension but " + "received the dimension of X is %d", + x_dims.size())); + + ctx->SetOutputDim("Out", x_dims); + ctx->ShareLoD("X", "Out"); + } +}; + +class SoftmaxMaskFuseUpperTriangleOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "The input of softmax_mask_fuse_upper_triangle op, " + "which is the result of matmul(QK)/sqrt(dk)."); + AddOutput("Out", "The result of softmax_mask_fuse_upper_triangle op."); + + AddComment(R"DOC( +Softmax Mask Fuse Operator. +product = matmul(QK)/sqrt(dk) +output = softmax_mask_fuse_upper_triangle(product) +to get the final output. 
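+The strictly upper-triangular part of each score matrix is masked out (treated
+as -infinity) before the softmax, which corresponds to the causal attention
+mask used in GPT-style models, so no explicit mask input is needed.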
+)DOC");
+  }
+};
+
+class SoftmaxMaskFuseUpperTriangleOpGrad
+    : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"),
+                   "SoftmaxMaskFuseUpperTriangleGrad");
+
+    auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+    ctx->SetOutputDim(framework::GradVarName("X"), out_dims);
+    ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X"));
+  }
+};
+
+template <typename T>
+class SoftmaxMaskFuseUpperTriangleGradOpMaker
+    : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("softmax_mask_fuse_upper_triangle_grad");
+    op->SetInput("Softmax", this->Output("Out"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(
+    softmax_mask_fuse_upper_triangle, ops::SoftmaxMaskFuseUpperTriangleOp,
+    ops::SoftmaxMaskFuseUpperTriangleOpMaker,
+    ops::SoftmaxMaskFuseUpperTriangleGradOpMaker<paddle::framework::OpDesc>,
+    ops::SoftmaxMaskFuseUpperTriangleGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(softmax_mask_fuse_upper_triangle_grad,
+                  ops::SoftmaxMaskFuseUpperTriangleOpGrad);
+REGISTER_OP_CPU_KERNEL(softmax_mask_fuse_upper_triangle,
+                       ops::SoftmaxMaskFuseUpperTriangleCPUKernel<
+                           paddle::platform::CPUDeviceContext, float>,
+                       ops::SoftmaxMaskFuseUpperTriangleCPUKernel<
+                           paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.cu b/paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.cu
new file mode 100644
index 0000000000000..9a1b4332e8b9f
--- /dev/null
+++ b/paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.cu
@@ -0,0 +1,546 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ +// this file is inspired by: +// https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h + +#ifdef PADDLE_WITH_CUDA +#include +#include +#endif +#ifdef PADDLE_WITH_HIP +#include +#include +#endif +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { +using framework::Tensor; + +#ifdef PADDLE_WITH_HIP +#define WARP_SIZE 64 +#else +#define WARP_SIZE 32 +#endif + +#define MASK 0xffffffff + +namespace plat = paddle::platform; + +__device__ __inline__ void load_data_upper_tri(plat::float16* dst, + const plat::float16* src) { + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); +} + +__device__ __inline__ void load_data_upper_tri(float* dst, const float* src) { + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); +} + +__device__ __inline__ void load_zero_vector_upper_tri(plat::float16* dst) { + *(reinterpret_cast(dst)) = make_float2(0.0f, 0.0f); +} + +__device__ __inline__ void load_zero_vector_upper_tri(float* dst) { + *(reinterpret_cast(dst)) = make_float4(0.0f, 0.0f, 0.0f, 0.0f); +} + +int get_pow2_index_value(int value) { + int pow2_index = 0; + while ((1 << pow2_index) < value) { + ++pow2_index; + } + return pow2_index; +} + +template +struct AddOP_upper_tri { + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } +}; + +template +struct MaxOP_upper_tri { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T warp_shfl_xor_upper_tri(T value, int laneMask, + int width, + unsigned int mask = MASK) { +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce_upper_tri(T* sum) { + ReduceOp r; +#pragma unroll + for (int offset = width / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < batch; ++i) { + T b = warp_shfl_xor_upper_tri(sum[i], offset, width); + sum[i] = r(sum[i], b); + } + } +} + +template +__global__ void SoftmaxMaskFuseUpperTriangleGPUKernel(const T* src, T* dst, + int batch_count, + int key_seq_len) { + constexpr int next_pow2 = 1 << pow2_index; + constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; + constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4); + constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 
2 : 1; + constexpr int kOneLoadingCounts = 4; + int key_seq_len_pow_2 = key_seq_len * key_seq_len; + + int first_idx = + (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * kLocalBatchSize + + blockIdx.x; + int local_block_idx = blockIdx.x + 1; + int warp_iter_upper_bound = + (local_block_idx + kOneLoadingCounts * warp_size - 1) / warp_size; + + int local_batches = batch_count - first_idx; + if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize; + + int local_idx = threadIdx.x; + + src += first_idx * key_seq_len + kOneLoadingCounts * local_idx; + dst += first_idx * key_seq_len + kOneLoadingCounts * local_idx; + + float data[kLocalBatchSize][kLocalIterations]; + T temp_in[kOneLoadingCounts]; + +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + int batch_total_number = (i >= local_batches) ? 0 : local_block_idx; + +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { + int element_index = kOneLoadingCounts * local_idx + ii * warp_size; + + if (element_index < batch_total_number) { + load_data_upper_tri(temp_in, + src + i * key_seq_len_pow_2 + ii * warp_size); + +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + if ((element_index + counter) < batch_total_number) { + data[i][ii + counter] = static_cast(temp_in[counter]); + } else { + data[i][ii + counter] = -std::numeric_limits::infinity(); + } + } + } else { +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + data[i][ii + counter] = -std::numeric_limits::infinity(); + } + } + } + } + + float max_value[kLocalBatchSize]; +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + max_value[i] = data[i][0]; +#pragma unroll + for (int ii = 1; ii < kLocalIterations; ++ii) { + max_value[i] = (max_value[i] > data[i][ii]) ? max_value[i] : data[i][ii]; + } + } + warp_reduce_upper_tri( + max_value); + + float sum[kLocalBatchSize]{0.0f}; +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ++ii) { + if (ii < warp_iter_upper_bound) { + data[i][ii] = std::exp((data[i][ii] - max_value[i])); + sum[i] += data[i][ii]; + } + } + } + warp_reduce_upper_tri( + sum); + + T out[kOneLoadingCounts]; +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { + int element_index = kOneLoadingCounts * local_idx + ii * warp_size; + + if (element_index < local_block_idx) { +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + if (element_index + counter < local_block_idx) { + out[counter] = data[i][ii + counter] / sum[i]; + } else { + out[counter] = 0; + } + } + load_data_upper_tri(dst + i * key_seq_len_pow_2 + ii * warp_size, out); + } else if (element_index < key_seq_len) { + load_zero_vector_upper_tri(dst + i * key_seq_len_pow_2 + + ii * warp_size); + } else { + break; + } + } + } +} + +template +__global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input, + T* grad_output, + const T* softmax_rst, + int batch_count, + int key_seq_len) { + constexpr int next_pow2 = 1 << pow2_index; + constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; + constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4); + constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 
2 : 1; + constexpr int kOneLoadingCounts = 4; + int key_seq_len_pow_2 = key_seq_len * key_seq_len; + + int first_idx = + (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * kLocalBatchSize + + blockIdx.x; + int local_block_idx = blockIdx.x + 1; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = batch_count - first_idx; + if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int offset = first_idx * key_seq_len + kOneLoadingCounts * local_idx; + grad_input += offset; + grad_output += offset; + softmax_rst += offset; + + // load data from global memory + float grad_input_reg[kLocalBatchSize][kLocalIterations]{0.0f}; + float softmax_rst_reg[kLocalBatchSize][kLocalIterations]{0.0f}; + T temp_grad_input[kOneLoadingCounts]; + T temp_softmax_rst[kOneLoadingCounts]; + +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + int batch_total_number = (i >= local_batches) ? 0 : local_block_idx; + +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { + int element_index = kOneLoadingCounts * local_idx + ii * warp_size; + if (element_index < batch_total_number) { + load_data_upper_tri( + temp_grad_input, + grad_input + i * key_seq_len_pow_2 + ii * warp_size); + load_data_upper_tri( + temp_softmax_rst, + softmax_rst + i * key_seq_len_pow_2 + ii * warp_size); + +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + if (element_index + counter < batch_total_number) { + softmax_rst_reg[i][ii + counter] = + static_cast(temp_softmax_rst[counter]); + } + } +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + if (element_index + counter < batch_total_number) { + grad_input_reg[i][ii + counter] = + static_cast(temp_grad_input[counter]) * + softmax_rst_reg[i][ii + counter]; + } + } + } + } + } + + float sum[kLocalBatchSize]; +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + sum[i] = grad_input_reg[i][0]; +#pragma unroll + for (int ii = 1; ii < kLocalIterations; ++ii) { + sum[i] += grad_input_reg[i][ii]; + } + } + warp_reduce_upper_tri( + sum); + +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { + int element_index = kOneLoadingCounts * local_idx + ii * warp_size; + if (element_index < key_seq_len) { + // compute gradients + T samples_out[kOneLoadingCounts]; +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + samples_out[counter] = grad_input_reg[i][ii + counter] - + softmax_rst_reg[i][ii + counter] * sum[i]; + } + load_data_upper_tri( + grad_output + i * key_seq_len_pow_2 + ii * warp_size, samples_out); + } + } + } +} + +template +class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* y = context.Output("Out"); + + auto* x_data = x->data(); + auto* y_data = y->mutable_data(context.GetPlace()); + + auto x_dim = x->dims(); + auto batches = x_dim[0]; + auto attn_heads = x_dim[1]; + auto attn_mul_batch = batches * attn_heads; + auto query_seq_len = x_dim[2]; + auto key_seq_len = x_dim[3]; + + 
PADDLE_ENFORCE_EQ(key_seq_len, query_seq_len, + platform::errors::InvalidArgument( + "Key seq len must be equal with query seq len " + "received key len: %d, query len: %d", + key_seq_len, query_seq_len)); + + PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len < 8192, true, + platform::errors::InvalidArgument( + "Input x's last dim must be between [32, 8192) " + "received the last dimension of x is %d", + key_seq_len)); + + auto& place = *context.template device_context().eigen_device(); + auto stream = context.cuda_device_context().stream(); + + int pow2_index = get_pow2_index_value(key_seq_len); + const int next_pow2 = 1 << pow2_index; + int batch_count = attn_mul_batch * query_seq_len; + int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; + int batches_per_warp = (next_pow2 <= 128) ? 2 : 1; + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + PADDLE_ENFORCE_EQ( + query_seq_len % batches_per_block, 0, + platform::errors::InvalidArgument( + "The query seq len (third dim of input X) must can divide the " + "number of batches per block. The query seq len is %d, while " + "the number of batches per block is %d.", + query_seq_len, batches_per_block)); + dim3 blocks(query_seq_len, + (attn_mul_batch + batches_per_block) / batches_per_block, 1); + dim3 threads(warp_size, warps_per_block, 1); + + switch (pow2_index) { + case 5: // 32 + SoftmaxMaskFuseUpperTriangleGPUKernel< + T, 5><<>>(x_data, y_data, batch_count, + key_seq_len); + break; + case 6: // 64 + SoftmaxMaskFuseUpperTriangleGPUKernel< + T, 6><<>>(x_data, y_data, batch_count, + key_seq_len); + break; + case 7: // 128 + SoftmaxMaskFuseUpperTriangleGPUKernel< + T, 7><<>>(x_data, y_data, batch_count, + key_seq_len); + break; + case 8: // 256 + SoftmaxMaskFuseUpperTriangleGPUKernel< + T, 8><<>>(x_data, y_data, batch_count, + key_seq_len); + break; + case 9: // 512 + SoftmaxMaskFuseUpperTriangleGPUKernel< + T, 9><<>>(x_data, y_data, batch_count, + key_seq_len); + break; + case 10: // 1024 + SoftmaxMaskFuseUpperTriangleGPUKernel< + T, 10><<>>(x_data, y_data, batch_count, + key_seq_len); + break; + case 11: // 2048 + SoftmaxMaskFuseUpperTriangleGPUKernel< + T, 11><<>>(x_data, y_data, batch_count, + key_seq_len); + break; + case 12: // 4096 + SoftmaxMaskFuseUpperTriangleGPUKernel< + T, 12><<>>(x_data, y_data, batch_count, + key_seq_len); + break; + case 13: // 8192 + SoftmaxMaskFuseUpperTriangleGPUKernel< + T, 13><<>>(x_data, y_data, batch_count, + key_seq_len); + break; + default: + break; + } + } +}; + +template +class SoftmaxMaskFuseUpperTriangleGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* grad_x = context.Output(framework::GradVarName("X")); + auto* grad_y = context.Input(framework::GradVarName("Out")); + auto* softmax_rst = context.Input("Softmax"); + + auto* grad_x_data = grad_x->mutable_data(context.GetPlace()); + auto* grad_y_data = grad_y->data(); + auto* softmax_rst_data = softmax_rst->data(); + + auto y_dim = grad_y->dims(); + auto batches = y_dim[0]; + auto attn_heads = y_dim[1]; + auto attn_mul_batch = batches * attn_heads; + auto query_seq_len = y_dim[2]; + auto key_seq_len = y_dim[3]; + + auto& place = *context.template device_context().eigen_device(); + auto stream = context.cuda_device_context().stream(); + + int pow2_index = get_pow2_index_value(key_seq_len); + const int next_pow2 = 1 << pow2_index; + int batch_count 
= attn_mul_batch * query_seq_len; + int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; + int batches_per_warp = (next_pow2 <= 128) ? 2 : 1; + // use 128 threads per block to maximum gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + dim3 blocks(query_seq_len, + (attn_mul_batch + batches_per_block) / batches_per_block, 1); + dim3 threads(warp_size, warps_per_block, 1); + + switch (pow2_index) { + case 5: // 32 + SoftmaxMaskFuseUpperTriangleGradGPUKernel< + T, 5><<>>(grad_y_data, grad_x_data, + softmax_rst_data, batch_count, + key_seq_len); + break; + case 6: // 64 + SoftmaxMaskFuseUpperTriangleGradGPUKernel< + T, 6><<>>(grad_y_data, grad_x_data, + softmax_rst_data, batch_count, + key_seq_len); + break; + case 7: // 128 + SoftmaxMaskFuseUpperTriangleGradGPUKernel< + T, 7><<>>(grad_y_data, grad_x_data, + softmax_rst_data, batch_count, + key_seq_len); + break; + case 8: // 256 + SoftmaxMaskFuseUpperTriangleGradGPUKernel< + T, 8><<>>(grad_y_data, grad_x_data, + softmax_rst_data, batch_count, + key_seq_len); + break; + case 9: // 512 + SoftmaxMaskFuseUpperTriangleGradGPUKernel< + T, 9><<>>(grad_y_data, grad_x_data, + softmax_rst_data, batch_count, + key_seq_len); + break; + case 10: // 1024 + SoftmaxMaskFuseUpperTriangleGradGPUKernel< + T, 10><<>>(grad_y_data, grad_x_data, + softmax_rst_data, + batch_count, key_seq_len); + break; + case 11: // 2048 + SoftmaxMaskFuseUpperTriangleGradGPUKernel< + T, 11><<>>(grad_y_data, grad_x_data, + softmax_rst_data, + batch_count, key_seq_len); + break; + case 12: // 4096 + SoftmaxMaskFuseUpperTriangleGradGPUKernel< + T, 12><<>>(grad_y_data, grad_x_data, + softmax_rst_data, + batch_count, key_seq_len); + break; + case 13: // 8192 + SoftmaxMaskFuseUpperTriangleGradGPUKernel< + T, 13><<>>(grad_y_data, grad_x_data, + softmax_rst_data, + batch_count, key_seq_len); + break; + default: + break; + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL( + softmax_mask_fuse_upper_triangle, + ops::SoftmaxMaskFuseUpperTriangleKernel, + ops::SoftmaxMaskFuseUpperTriangleKernel); +REGISTER_OP_CUDA_KERNEL( + softmax_mask_fuse_upper_triangle_grad, + ops::SoftmaxMaskFuseUpperTriangleGradKernel, + ops::SoftmaxMaskFuseUpperTriangleGradKernel); diff --git a/paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.h b/paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.h new file mode 100644 index 0000000000000..61dc571066d2b --- /dev/null +++ b/paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.h @@ -0,0 +1,30 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { +template +class SoftmaxMaskFuseUpperTriangleCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::Unimplemented( + "Softmax mask fuse op only supports GPU now.")); + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_softmax_mask_fuse_upper_triangle_op.py b/python/paddle/fluid/tests/unittests/test_softmax_mask_fuse_upper_triangle_op.py new file mode 100644 index 0000000000000..a5f59c6d1f261 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_softmax_mask_fuse_upper_triangle_op.py @@ -0,0 +1,117 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid.core as core +from op_test import OpTest +import paddle +import paddle.fluid as fluid +import paddle.incubate as incubate + +paddle.enable_static() + + +def _get_softmax_upper(x, fp16=True): + x_lower = np.tril(x) + masked_x = np.where(x_lower == 0, -10000.0, x_lower).astype("float32") + max_value = np.max(masked_x, axis=-1, keepdims=True) + before_exp = masked_x - max_value + exp = np.exp(before_exp) + exp_sum = np.sum(exp, axis=-1, keepdims=True) + rst = exp / exp_sum + if fp16: + rst = rst.astype("float16") + return rst + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxMaskFuseOp(OpTest): + def setUp(self): + self.op_type = "softmax_mask_fuse_upper_triangle" + x = np.random.random((1, 1, 32, 32)).astype("float16") + self.inputs = {'X': x} + rst = _get_softmax_upper(x) + self.outputs = {'Out': rst} + + def test_check_output(self): + self.check_output_with_place(core.CUDAPlace(0)) + + def test_check_grad(self): + self.check_grad_with_place(core.CUDAPlace(0), ["X"], "Out") + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxMaskFuseOp1(OpTest): + def setUp(self): + self.op_type = "softmax_mask_fuse_upper_triangle" + x = np.random.random((1, 1, 32, 32)) + self.inputs = {'X': x} + rst = _get_softmax_upper(x) + self.outputs = {'Out': rst} + + def test_check_output(self): + try: + self.check_output_with_place(core.CPUPlace()) + except NotImplementedError: + pass + + def test_check_grad(self): + try: + self.check_grad_with_place(core.CPUPlace(), ["X"], "Out") + except NotImplementedError: + pass + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestDropoutBiasFuseOp2(unittest.TestCase): + # test the python side API for softmax_mask_fuse op + def setUp(self): + np.random.seed(123) + self.dtypes = ['float16', 'float32'] + + def test_static(self): + for dtype in 
self.dtypes: + with fluid.program_guard(fluid.Program(), fluid.Program()): + input_x = fluid.data( + name="x", shape=[1, 1, 32, 32], dtype=dtype) + rst = incubate.softmax_mask_fuse_upper_triangle(input_x) + + x_in_np = np.random.random((1, 1, 32, 32)).astype(dtype) + rst_np = _get_softmax_upper(x_in_np, dtype == 'float16') + + exe = fluid.Executor(fluid.CUDAPlace(0)) + fetches = exe.run(fluid.default_main_program(), + feed={"x": x_in_np}, + fetch_list=[rst]) + self.assertTrue(np.allclose(fetches[0], rst_np)) + + def test_dygraph(self): + for dtype in self.dtypes: + with fluid.dygraph.guard(fluid.CUDAPlace(0)): + x_in_np = np.random.random((1, 1, 32, 32)).astype(dtype) + rst_np = _get_softmax_upper(x_in_np, dtype == 'float16') + input_x = fluid.dygraph.to_variable(x_in_np) + + rst = incubate.softmax_mask_fuse_upper_triangle(input_x) + self.assertTrue(np.allclose(rst, rst_np)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/incubate/__init__.py b/python/paddle/incubate/__init__.py index 22769053b1ac9..9b9797ede717e 100644 --- a/python/paddle/incubate/__init__.py +++ b/python/paddle/incubate/__init__.py @@ -16,7 +16,8 @@ from .optimizer import ModelAverage # noqa: F401 from .checkpoint import auto_checkpoint # noqa: F401 from ..fluid.layer_helper import LayerHelper # noqa: F401 +from .operators import softmax_mask_fuse_upper_triangle # noqa: F401 __all__ = [ # noqa - 'LookAhead', 'ModelAverage' + 'LookAhead', 'ModelAverage', 'softmax_mask_fuse_upper_triangle' ] diff --git a/python/paddle/incubate/operators/__init__.py b/python/paddle/incubate/operators/__init__.py new file mode 100644 index 0000000000000..026bf32d81250 --- /dev/null +++ b/python/paddle/incubate/operators/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .softmax_mask_fuse_upper_triangle import softmax_mask_fuse_upper_triangle # noqa: F401 diff --git a/python/paddle/incubate/operators/softmax_mask_fuse_upper_triangle.py b/python/paddle/incubate/operators/softmax_mask_fuse_upper_triangle.py new file mode 100644 index 0000000000000..b81ad4ecdc82a --- /dev/null +++ b/python/paddle/incubate/operators/softmax_mask_fuse_upper_triangle.py @@ -0,0 +1,42 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import print_function
+
+from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.framework import in_dygraph_mode
+from paddle.fluid import core
+
+
+def softmax_mask_fuse_upper_triangle(x):
+    """
+    Fuse the softmax and mask steps together without taking an explicit mask.
+    For GPT-style models the mask is always an upper triangle, so we can simply
+    mask the upper-triangular part of x to get the masked softmax result.
+    :param x: the input x (the result of QK)
+    :return: the result of softmax mask fuse (upper triangle)
+    """
+    if in_dygraph_mode():
+        out = core.ops.softmax_mask_fuse_upper_triangle(x)
+        return out
+
+    helper = LayerHelper('softmax_mask_fuse_upper_triangle', **locals())
+
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+    helper.append_op(
+        type='softmax_mask_fuse_upper_triangle',
+        inputs={'X': [x]},
+        outputs={'Out': [out]})
+    return out
diff --git a/python/setup.py.in b/python/setup.py.in
index 787317acb6d44..ba7ea88dd43b9 100644
--- a/python/setup.py.in
+++ b/python/setup.py.in
@@ -146,6 +146,7 @@ packages=['paddle',
           'paddle.incubate',
           'paddle.incubate.optimizer',
           'paddle.incubate.checkpoint',
+          'paddle.incubate.operators',
           'paddle.distributed.fleet',
           'paddle.distributed.fleet.base',
           'paddle.distributed.fleet.meta_optimizers',
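A minimal NumPy sketch of the computation the fused op performs, for reference only and not part of the patch. It mirrors the _get_softmax_upper helper in the unit test above, except that it masks with -inf instead of -10000.0; the function name below is illustrative.

import numpy as np

def upper_triangle_softmax_reference(x):
    # x has shape [batch, heads, seq_len, seq_len], like the op's input X.
    seq_len = x.shape[-1]
    # Causal mask: True for the strictly upper-triangular positions
    # (column index > row index), which are excluded from the softmax.
    mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, x.astype("float32"))
    # Numerically stable softmax over the last dimension.
    scores = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(scores)
    return exp / exp.sum(axis=-1, keepdims=True)

# Hypothetical usage against the new incubate API (requires a CUDA build):
# import paddle
# x_np = np.random.random((1, 1, 32, 32)).astype("float32")
# out = paddle.incubate.softmax_mask_fuse_upper_triangle(paddle.to_tensor(x_np))
# assert np.allclose(out.numpy(), upper_triangle_softmax_reference(x_np), atol=1e-6)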