Skip to content

Commit

Permalink
Add fused_attention_op: add impl wrappers. (#35903)
Browse files Browse the repository at this point in the history
  • Loading branch information
limin2021 committed Oct 24, 2021
1 parent 6840cf5 commit a9618bf
Show file tree
Hide file tree
Showing 6 changed files with 487 additions and 8 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {

template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
__device__ inline void operator()(Functor func, InT **args, OutT *result) {
__device__ inline void operator()(Functor func, InT (*args)[VecSize],
OutT *result) {
kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], args[1], args[2], func);
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/fused/attention_layer_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class AttnLayerNorm {
}
}

void ComputeBackward(const T* x_data const T* y_data,
void ComputeBackward(const T* x_data, const T* d_y_data,
const LayerNormParamType<T>* scale_data,
const LayerNormParamType<T>* mean_data,
const LayerNormParamType<T>* var_data, T* d_x_data,
Expand Down
6 changes: 1 addition & 5 deletions paddle/fluid/operators/fused/attn_bias_add.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ namespace cub = hipcub;
#define LAUNCH_BOUNDS(BlockDim)
#endif

#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
Expand All @@ -51,11 +52,6 @@ using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using ReduceParamType = typename CudnnDataType<T>::BatchNormParamType;

template <typename T>
struct AddFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a + b; }
};

template <typename InT, typename OutT, int ShapeSize, int VecSize,
int DATA_PER_THREAD, typename Functor>
__global__ void BroadcastKernelBinary(
Expand Down
159 changes: 159 additions & 0 deletions paddle/fluid/operators/fused/attn_gemm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/operators/fused/attn_bias_add.cu.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

// support gemm-nt and gemm-nn, which is used in fused_attention_op.
template <typename T>
class AttnMatMul {
public:
// (m, n, k) = bsz_seq, output_size, input_size
AttnMatMul(const platform::CUDADeviceContext& dev_ctx, bool transA,
bool transB, int bsz_seq, int output_size, int input_size,
bool compute_bias)
: dev_ctx_(dev_ctx),
transA_(transA),
transB_(transB),
bsz_seq_(bsz_seq),
output_size_(output_size),
input_size_(input_size),
compute_bias_(compute_bias) {}

~AttnMatMul() {}

void ComputeForward(const T* weight_data, const T* input_data,
const T* bias_data, T* output_data, T* bias_out_data) {
// Note: for blas.GEMM API in Paddle, it treats all inputs as row-major.
// here: (transa, transb): nt, input * weight.
CBLAS_TRANSPOSE transA = CblasNoTrans;
CBLAS_TRANSPOSE transB = CblasNoTrans;
if (transA_) {
transA = CblasTrans;
}
if (transB_) {
transB = CblasTrans;
}
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);

// here: (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha,
input_data, weight_data, beta, output_data);
if (compute_bias_) {
// compute output + bias
LaunchBiasAddFwKernel(dev_ctx_, bsz_seq_, output_size_, output_data,
bias_data, bias_out_data);
}
}

void ComputeBackward(const T* input, const T* weight, const T* d_output,
T* d_input, T* d_weight, T* d_bias) {
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);

CBLAS_TRANSPOSE dB_transA = CblasNoTrans;
CBLAS_TRANSPOSE dB_transB = CblasNoTrans;
CBLAS_TRANSPOSE dA_transA = CblasNoTrans;
CBLAS_TRANSPOSE dA_transB = CblasNoTrans;
int dB_m = 1;
int dB_n = 1;
int dB_k = 1;
int dA_m = 1;
int dA_n = 1;
int dA_k = 1;

T* dB_input_1_ptr = nullptr;
T* dB_input_2_ptr = nullptr;
T* dB_output_ptr = d_weight;

T* dA_input_1_ptr = nullptr;
T* dA_input_2_ptr = nullptr;
T* dA_output_ptr = d_input;

if (!transA_) {
// fw: gemm-nt
if (transB_) {
// bw: gemm-tn, dB = (dC)^t * A
dB_transA = CblasTrans;
dB_transB = CblasNoTrans;
dB_m = output_size_;
dB_n = input_size_;
dB_k = bsz_seq_;

// bw: gemm-nn, dA = dC * B
dA_transA = CblasNoTrans;
dA_transB = CblasNoTrans;
dA_m = bsz_seq_;
dA_n = input_size_;
dA_k = output_size_;

blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, d_output,
input, beta, dB_output_ptr);
blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
weight, beta, dA_output_ptr);
} else { // fw: gemm-nn
// bw: gemm-tn, dB = A^t * dC
dB_transA = CblasTrans;
dB_transB = CblasNoTrans;
dB_m = input_size_;
dB_n = output_size_;
dB_k = bsz_seq_;

// bw: gemm-nt, dA = dC * B^t
dA_transA = CblasNoTrans;
dA_transB = CblasTrans;
dA_m = bsz_seq_;
dA_n = input_size_;
dA_k = output_size_;

blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, input,
d_output, beta, dB_output_ptr);
blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
weight, beta, dA_output_ptr);
}
} else if (transB_) {
PADDLE_THROW(platform::errors::InvalidArgument(
"AttnMatMul wrapper do not support (transA=T, transB=T)"
"parameters."));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"AttnMatMul wrapper do not support (transA=T, transB=N)"
"parameters."));
}
if (compute_bias_) {
LaunchBiasAddBwKernel(dev_ctx_, bsz_seq_, output_size_, d_output, d_bias);
}
}

private:
const platform::CUDADeviceContext& dev_ctx_;

bool transA_;
bool transB_;

int bsz_seq_;
int output_size_;
int input_size_;

int compute_bias_;
};

} // namespace operators
} // namespace paddle
Loading

1 comment on commit a9618bf

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.