Zhalei/gated rel position bias (#14553)

Co-authored-by: Lei Zhang <zhang.huanning@hotmail.com>
2 people authored and Ubuntu committed Feb 3, 2023
1 parent 0d44a7b commit da875fa
Showing 8 changed files with 566 additions and 4 deletions.
121 changes: 120 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc
@@ -3,7 +3,15 @@

#include "core/providers/cuda/cuda_common.h"
#include "relative_attn_bias.h"
#include "core/common/safeint.h"
#include "relative_attn_bias_impl.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
#include "contrib_ops/cuda/bert/add_bias_transpose.h"

using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;


namespace onnxruntime {
namespace contrib {
@@ -20,7 +28,16 @@ namespace cuda {
.InputMemoryType(OrtMemTypeCPUInput, 1) \
.InputMemoryType(OrtMemTypeCPUInput, 2) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
RelPosAttnBias<T>);
RelPosAttnBias<T>); \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
GatedRelativePositionBias, \
kMSDomain, \
1, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
GatedRelativePositionBias<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
@@ -69,6 +86,108 @@ Status RelPosAttnBias<T>::ComputeInternal(OpKernelContext* context) const {
device_prop.maxThreadsPerBlock);
}

template <typename T>
GatedRelativePositionBias<T>::GatedRelativePositionBias(const OpKernelInfo& info) : CudaKernel(info) {
int64_t num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
num_heads_ = SafeInt<int>(num_heads);
}

template <typename T>
Status GatedRelativePositionBias<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor& query_tensor = *context->Input<Tensor>(0);
const Tensor& query_bias_tensor = *context->Input<Tensor>(1);
const Tensor& rel_pos_tensor = *context->Input<Tensor>(2);
const Tensor& weight_tensor = *context->Input<Tensor>(3);
const Tensor& bias_tensor = *context->Input<Tensor>(4);
const Tensor& eco_a_tensor = *context->Input<Tensor>(5);

const auto& query_dims = query_tensor.Shape().GetDims();
ORT_ENFORCE(query_dims.size() == 3);
ORT_ENFORCE(query_dims[2] > 0);
ORT_ENFORCE(query_dims[2] % num_heads_ == 0);
const auto batch_size = SafeInt<int>(query_dims[0]);
const auto seq_len = SafeInt<int>(query_dims[1]);
const auto head_size = SafeInt<int>(query_dims[2] / num_heads_);

ORT_ENFORCE(query_bias_tensor.Shape().NumDimensions() == 1);
ORT_ENFORCE(query_bias_tensor.Shape()[0] == query_dims[2]);

const auto& rel_pos_dims = rel_pos_tensor.Shape().GetDims();
ORT_ENFORCE(rel_pos_dims.size() == 4);
ORT_ENFORCE(rel_pos_dims[0] == 1);
ORT_ENFORCE(rel_pos_dims[1] == num_heads_);
ORT_ENFORCE(rel_pos_dims[2] == seq_len);
ORT_ENFORCE(rel_pos_dims[3] == seq_len);

const auto& weight_dims = weight_tensor.Shape().GetDims();
ORT_ENFORCE(weight_dims.size() == 2);
ORT_ENFORCE(weight_dims[0] == head_size);
ORT_ENFORCE((weight_dims[1] > 0) && (weight_dims[1] % 2 == 0));

ORT_ENFORCE(bias_tensor.Shape().NumDimensions() == 1);
ORT_ENFORCE(bias_tensor.Shape()[0] == weight_dims[1]);

const auto D = SafeInt<int>(weight_dims[1]);

const auto& eco_a_dims = eco_a_tensor.Shape().GetDims();
ORT_ENFORCE(eco_a_dims.size() == 4);
ORT_ENFORCE(eco_a_dims[0] == 1);
ORT_ENFORCE(eco_a_dims[1] == num_heads_);
ORT_ENFORCE(eco_a_dims[2] == 1);
ORT_ENFORCE(eco_a_dims[3] == 1);

Tensor* output = context->Output(0, {batch_size, num_heads_, seq_len, seq_len});

auto& device_prop = GetDeviceProp();
cublasHandle_t cublas = GetCublasHandle(context);

typedef typename ToCudaType<T>::MappedType CudaT;
const auto BNS = batch_size * num_heads_ * seq_len;
const size_t elements_in_query = (size_t)BNS * (size_t)head_size;
const size_t elements_after_gemm = (size_t)BNS * (size_t)D;
// Workspace holds the transposed query (B, N, S, H); when seq_len < D the GEMM result cannot
// reuse the output tensor, so room for a (B, N, S, D) buffer is appended.
size_t workspace_size = sizeof(T) * (elements_in_query + ((seq_len < D) ? elements_after_gemm : (size_t)0));
auto workspace = GetScratchBuffer<void>(workspace_size, context->GetComputeStream());

// format 1: BxSx(NH * total_matrix) => matrix_to_transpose * (BxNxSxH)
constexpr int format = 1;
constexpr int total_matrix = 1;
constexpr int num_matrix_to_transpose = 1;
LaunchAddBiasTranspose(Stream(context), num_matrix_to_transpose, format, device_prop.maxThreadsPerBlock,
batch_size, seq_len, num_heads_, head_size,
reinterpret_cast<const CudaT*>(query_tensor.template Data<T>()),
reinterpret_cast<const CudaT*>(query_bias_tensor.template Data<T>()),
reinterpret_cast<CudaT*>(workspace.get()),
false, head_size, reinterpret_cast<CudaT*>(nullptr), total_matrix);

// Reuse the output tensor as the GEMM buffer when it is large enough (seq_len >= D);
// otherwise use the extra workspace region reserved above.
CudaT* gemm_output = (seq_len < D) ? (reinterpret_cast<CudaT*>(workspace.get()) + elements_in_query)
: reinterpret_cast<CudaT*>(output->template MutableData<T>());
int ld_gemm_output = max(seq_len, D);

const CudaT one = ToCudaType<T>::FromFloat(1.0f);
const CudaT zero = ToCudaType<T>::FromFloat(0.0f);

// ([b*n*s, h] * [h, D]), CUDA assumes col-major
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
D, BNS, head_size, &one,
reinterpret_cast<const CudaT*>(weight_tensor.template Data<T>()), (int)D,
reinterpret_cast<const CudaT*>(workspace.get()), (int)head_size,
&zero, gemm_output, ld_gemm_output, device_prop));

auto status = LaunchGatedRelativePositionBiasKernel<CudaT>(
device_prop, Stream(context),
reinterpret_cast<CudaT*>(output->template MutableData<T>()),
reinterpret_cast<const CudaT*>(rel_pos_tensor.template Data<T>()),
reinterpret_cast<const CudaT*>(gemm_output),
reinterpret_cast<const CudaT*>(bias_tensor.template Data<T>()),
reinterpret_cast<const CudaT*>(eco_a_tensor.template Data<T>()),
batch_size, num_heads_, seq_len, D, ld_gemm_output);

return status;
}

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
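For orientation, here is a minimal NumPy sketch (not part of the PR) of what `ComputeInternal` above computes. Shapes follow the new op schema; the function and variable names are illustrative only.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_rel_pos_bias(query, query_bias, rel_pos, weight, bias, eco_a, num_heads):
    # query:      (batch, seq_len, num_heads * head_size)
    # query_bias: (num_heads * head_size,)
    # rel_pos:    (1, num_heads, seq_len, seq_len)
    # weight:     (head_size, D), D even
    # bias:       (D,)
    # eco_a:      (1, num_heads, 1, 1)
    b, s, hidden = query.shape
    h = hidden // num_heads
    d = weight.shape[1]

    # AddBiasTranspose: (b, s, n*h) -> (b, n, s, h)
    q = (query + query_bias).reshape(b, s, num_heads, h).transpose(0, 2, 1, 3)

    # gate_ur_linear GEMM: (b, n, s, h) @ (h, D) + (D,) -> (b, n, s, D)
    qw = q @ weight + bias

    # Sum each half of D and apply sigmoid gates (this is what the CUDA kernel fuses).
    u = sigmoid(qw[..., : d // 2].sum(-1, keepdims=True))  # (b, n, s, 1)
    r = sigmoid(qw[..., d // 2 :].sum(-1, keepdims=True))  # (b, n, s, 1)

    gate = u * (r * eco_a - 1.0) + 2.0                      # (b, n, s, 1)
    return gate * rel_pos                                   # (b, n, s, seq_len)
```

The CUDA path realizes the same math as AddBiasTranspose into the workspace, a cuBLAS GEMM, and the gating kernel added in relative_attn_bias_impl.cu; the sketch is only meant to pin down the math.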
12 changes: 12 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h
Expand Up @@ -22,6 +22,18 @@ class RelPosAttnBias final : public CudaKernel {
bool is_bidirectional_;
};

template <typename T>
class GatedRelativePositionBias final : public CudaKernel {
public:
GatedRelativePositionBias(const OpKernelInfo& op_kernel_info);

Status ComputeInternal(OpKernelContext* ctx) const override;

private:
int num_heads_;
};


} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
116 changes: 116 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu
@@ -149,6 +149,122 @@ template Status LaunchRelPosAttnBiasKernel<half>(cudaStream_t stream,
const bool is_bidirectional,
const int max_threads_per_block);

template <typename T>
__global__ void GatedRelativePositionBiasKernelSmallD(
T* output, // (batch_size, num_heads, seq_len, seq_len)
const T* rel_pos, // (1, num_heads, seq_len, seq_len)
const T* qw, // (batch_size, num_heads, seq_len, D)
const T* bias, // (D)
const T* eco_a, // (1, num_heads, 1, 1)
const int D,
const int ldqw) {
__shared__ float gate[1];

const int seq_len = gridDim.x;
const int num_heads = gridDim.y;
const int s = blockIdx.x;
const int n = blockIdx.y;
const int b = blockIdx.z;

rel_pos += ((int64_t)n * seq_len + s) * seq_len;
output += ((int64_t)b * num_heads * seq_len + (int64_t)n * seq_len + s) * seq_len;
qw += ((int64_t)b * num_heads * seq_len + (int64_t)n * seq_len + s) * ldqw;

float val = 0.0f;
if (threadIdx.x < D) {
val = (float)qw[threadIdx.x] + (bias ? (float)bias[threadIdx.x] : 0.0f);
}
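// D is limited to (0, 32] and even (enforced in the launcher below), so the first D lanes of
// warp 0 hold qw + bias. Lanes [0, D/2) feed the u gate and lanes [D/2, D) feed the r gate;
// each half-sum is reduced into lane 0 with warp shuffles. Threads in other warps also run
// the shuffles, but only thread 0 writes the gate.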

float u = (threadIdx.x < D / 2) ? val : 0.0f;
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
u += __shfl_down_sync(0xffffffff, u, offset);
}

float r = (threadIdx.x >= D / 2) ? val : 0.0f;
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
r += __shfl_down_sync(0xffffffff, r, offset);
}

if (threadIdx.x == 0) {
u = 1.0f / (1.0f + expf(-u));
r = 1.0f / (1.0f + expf(-r));
gate[0] = u * (r * (float)eco_a[n] - 1.0f) + 2.0f;
}
__syncthreads();

for (int idx = threadIdx.x; idx < seq_len; idx += blockDim.x) {
output[idx] = (T)(gate[0] * (float)rel_pos[idx]);
}
}

template <typename T>
Status LaunchGatedRelativePositionBiasKernel(
const cudaDeviceProp& device_prop,
cudaStream_t stream,
T* output,
const T* rel_pos,
const T* qw, // query * weight
const T* bias,
const T* eco_a,
const int batch_size,
const int num_heads,
const int seq_len,
const int D,
const int ldqw) {
ORT_ENFORCE(D <= 32 && D > 0 && (D % 2 == 0));
ORT_ENFORCE(ldqw == seq_len || ldqw == D);

int tpb = std::max(32, std::max(D, seq_len));
tpb = std::min(tpb, device_prop.maxThreadsPerBlock);

// round up tpb to power of 2
--tpb;
tpb |= (tpb >> 1);
tpb |= (tpb >> 2);
tpb |= (tpb >> 4);
tpb |= (tpb >> 8);
tpb |= (tpb >> 16);
tpb++;

dim3 block(tpb);
dim3 grid(seq_len, num_heads, batch_size);
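// One block per (query position, head, batch): each block computes a single gate scalar and
// scales one seq_len-long row of rel_pos into the output.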

GatedRelativePositionBiasKernelSmallD<<<grid, block, sizeof(float), stream>>>(
output, rel_pos, qw, bias, eco_a, D, ldqw);

return CUDA_CALL(cudaGetLastError());
}

template Status LaunchGatedRelativePositionBiasKernel(
const cudaDeviceProp& device_prop,
cudaStream_t stream,
float* output,
const float* rel_pos,
const float* qw,
const float* bias,
const float* eco_a,
const int batch_size,
const int num_heads,
const int seq_len,
const int D,
const int ldqw);

template Status LaunchGatedRelativePositionBiasKernel(
const cudaDeviceProp& device_prop,
cudaStream_t stream,
half* output,
const half* rel_pos,
const half* qw,
const half* bias,
const half* eco_a,
const int batch_size,
const int num_heads,
const int seq_len,
const int D,
const int ldqw);

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
15 changes: 15 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h
@@ -22,6 +22,21 @@ Status LaunchRelPosAttnBiasKernel(
const int max_threads_per_block
);

template <typename T>
Status LaunchGatedRelativePositionBiasKernel(
const cudaDeviceProp& device_prop,
cudaStream_t stream,
T* output,
const T* rel_pos,
const T* qw, // from query * weight
const T* bias,
const T* eco_a,
const int batch_size,
const int num_heads,
const int seq_len,
const int D,
const int ldqw);

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -30,6 +30,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RelativePositionBias);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RemovePadding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RestorePadding);
@@ -155,6 +157,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RelativePositionBias)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RemovePadding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RestorePadding)>,
34 changes: 34 additions & 0 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -663,5 +663,39 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
RestorePaddingTypeAndShapeInference(ctx);
}));

constexpr const char* GatedRelativePositionBias_ver1_doc = R"DOC(
query_layer = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2)
gate_u, gate_r = torch.sigmoid(
self.gate_ur_linear(query_layer).view(batch_size, num_heads, seq_len, 2, D/2).sum(-1, keepdim=False)
).chunk(2, dim=-1)
gate_u_1 = gate_u * (gate_r * self.eco_a - 1.0) + 2.0
rel_pos_bias = gate_u_1 * rel_pos
)DOC";

ONNX_MS_OPERATOR_SET_SCHEMA(
GatedRelativePositionBias, 1,
OpSchema()
.SetDoc(GatedRelativePositionBias_ver1_doc)
.Attr("num_heads", "Number of attention heads", AttributeProto::INT)
.Input(0, "query_layer", "tensor with shape (batch_size, seq_len, num_heads x head_size)", "T")
.Input(1, "query_bias", "1-d tensor with shape (num_heads x head_size)", "T")
.Input(2, "rel_pos", "tensor with shape (1, num_head, seq_len, seq_len)", "T")
.Input(3, "weight", "gemm weight for the gated_ur_linear, shape (head_size, D), D is divisible by 2", "T")
.Input(4, "bias", "bias for the gated_ur_linear, shape (D)", "T")
.Input(5, "eco_a", "tensor of shape (1, num_heads, 1, 1)", "T")
.Output(0, "output", "output tensor with shape (batch_size, num_heads, seq_len, seq_len)", "T")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
int64_t num_heads = getAttribute(ctx, "num_heads", -1L);
auto& query_layer_shape = getInputShape(ctx, 0);
TensorShapeProto output_shape;
*output_shape.add_dim() = query_layer_shape.dim(0);
output_shape.add_dim()->set_dim_value(num_heads);
*output_shape.add_dim() = query_layer_shape.dim(1);
*output_shape.add_dim() = query_layer_shape.dim(1);
updateOutputShape(ctx, 0, output_shape);
}));

} // namespace contrib
} // namespace onnxruntime
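As an illustration of how the new schema is meant to be used, here is a hypothetical end-to-end check (not part of this PR). It assumes an onnxruntime build that registers this kernel and has the CUDA execution provider available; shapes and values are made up, and only standard onnx/onnxruntime APIs are used.

```python
import numpy as np
from onnx import TensorProto, helper
import onnxruntime as ort

batch, seq_len, num_heads, head_size, D = 2, 4, 2, 8, 4  # D must be even and <= 32

node = helper.make_node(
    "GatedRelativePositionBias",
    inputs=["query_layer", "query_bias", "rel_pos", "weight", "bias", "eco_a"],
    outputs=["output"],
    domain="com.microsoft",
    num_heads=num_heads,
)

shapes = {
    "query_layer": [batch, seq_len, num_heads * head_size],
    "query_bias": [num_heads * head_size],
    "rel_pos": [1, num_heads, seq_len, seq_len],
    "weight": [head_size, D],
    "bias": [D],
    "eco_a": [1, num_heads, 1, 1],
}
graph = helper.make_graph(
    [node], "gated_rel_pos_bias_test",
    [helper.make_tensor_value_info(n, TensorProto.FLOAT, s) for n, s in shapes.items()],
    [helper.make_tensor_value_info("output", TensorProto.FLOAT,
                                   [batch, num_heads, seq_len, seq_len])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17),
                                                helper.make_opsetid("com.microsoft", 1)])

sess = ort.InferenceSession(model.SerializeToString(), providers=["CUDAExecutionProvider"])
feeds = {name: np.random.randn(*shape).astype(np.float32) for name, shape in shapes.items()}
out = sess.run(None, feeds)[0]
print(out.shape)  # (batch, num_heads, seq_len, seq_len)
```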
2 changes: 2 additions & 0 deletions onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -79,6 +79,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RelativePositionBias);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GatedRelativePositionBias);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft);
@@ -167,6 +168,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QAttention)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QEmbedLayerNormalization)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RelativePositionBias)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GatedRelativePositionBias)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft)>());
