Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XPU] support cross attention for decoder model #63203

Merged
merged 4 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ if(WITH_XPU)
${XPU_PASS_DEPS})
pass_library(decoder_attention_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(cross_attention_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(multi_encoder_xpu_adaptive_seqlen_fuse_pass inference DIR xpu
Expand Down
666 changes: 666 additions & 0 deletions paddle/fluid/framework/ir/xpu/cross_attention_xpu_fuse_pass.cc

Large diffs are not rendered by default.

126 changes: 126 additions & 0 deletions paddle/fluid/framework/ir/xpu/cross_attention_xpu_fuse_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace phi {
class DenseTensor;
} // namespace phi

namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace framework {
namespace ir {

/*
This pass is used to fuse the cross attention op into one op in decoder.
models .

Origin subgraph:

mask input_q input_kv
| | | |
| | |-----------|
| matmul matmul matmul
| |q |k |v
| | | |
| | | |
| add add add
| | | |
| | | |
| reshape reshape reshape
| | | |
| | | |
| transpose transpose transpose
| | | |
| | | |
| (scale) | |
| | | |
\ |(x) |(y) |
\ \ / |
\ qk_matmul |
\ | |
\ | |
add /
| /
| /
softmax /
\ /
\ /
qkv_matmul
|
|
transpose
|
|
reshape
|
|
output

-------------------------------------------------------
Fused subgraph:
input_q input_kv
| |
| |
| |
cross_attention_xpu
|
|
|
output

*/

class CrossAttentionXPUFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;

private:
void ApplyCrossAttentionXPUFuse(ir::Graph* graph, bool with_q_scale) const;

// 1. Generate q/k/v_w_max tensor
// 2. Quant q/k/v_w to int16
void PrepareQKVWeight(Graph* graph,
Scope* scope,
BlockDesc* block,
Node* w,
Node** real_w,
Node** w_max) const;

// Cast fc_bias to fp32
void PrepareQKVBias(Graph* graph,
Scope* scope,
BlockDesc* block,
Node* q_bias,
Node* k_bias,
Node* v_bias,
Node** real_q_bias,
Node** real_k_bias,
Node** real_v_bias) const;

const std::string name_scope_{"cross_attention_xpu_fuse_pass"};
};

} // namespace ir
} // namespace framework
} // namespace paddle
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,15 @@ DecoderAttentionFusePattern::DecoderAttentionFusePattern(

// link nodes
reshape2_1->LinksFrom({input_q}).LinksTo({reshape2_1_out});
reshape2_2->LinksFrom({input_k}).LinksTo({reshape2_2_out});
reshape2_3->LinksFrom({input_v}).LinksTo({reshape2_3_out});
transpose2_1->LinksFrom({reshape2_1_out}).LinksTo({transpose2_1_out});
reshape2_2->LinksFrom({input_k}).LinksTo({reshape2_2_out});
transpose2_2->LinksFrom({reshape2_2_out}).LinksTo({transpose2_2_out});
transpose2_3->LinksFrom({reshape2_3_out}).LinksTo({transpose2_3_out});
qk_matmul->LinksFrom({transpose2_1_out, transpose2_2_out})
.LinksTo({qk_matmul_out});
scale->LinksFrom({qk_matmul_out}).LinksTo({scale_out});
qk_softmax->LinksFrom({scale_out}).LinksTo({qk_softmax_out});
reshape2_3->LinksFrom({input_v}).LinksTo({reshape2_3_out});
transpose2_3->LinksFrom({reshape2_3_out}).LinksTo({transpose2_3_out});
qkv_matmul->LinksFrom({qk_softmax_out, transpose2_3_out})
.LinksTo({qkv_matmul_out});
transpose2_4->LinksFrom({qkv_matmul_out}).LinksTo({transpose2_4_out});
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"multi_encoder_xpu_slice_fuse_pass",
"fused_multi_transformer_cachekv_layout_trans_pass",
"fused_multi_transformer_int8_cachekv_layout_trans_pass",
"cross_attention_xpu_fuse_pass",
"decoder_attention_xpu_fuse_pass",
"one_beam_size_fuse_pass",
"fold_interp_outsize_fuse_pass",
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/fused_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@
data_type : x
optional : bias, branch, branch_max ,x_max, scale_max, out_max_in

- op : cross_attention_xpu
args : (Tensor input_q, Tensor input_kv, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor mask, int head_num, int head_dim, float alpha, DataType out_dtype)
output : Tensor(qkv), Tensor(qkv_max)
infer_meta :
func : CrossAttentionXPUInferMeta
kernel :
func : cross_attention_xpu
data_type : input_q

- op : dequantize_xpu
args : (Tensor x, DataType out_dtype, float scale = 1.0f)
output : Tensor(y)
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/backends/xpu/xpu2_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1229,6 +1229,8 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
{"roformer_relative_embedding_xpu",
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
{"cross_attention_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"variable_length_memory_efficient_attention",
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
{"flash_attn_unpadded",
Expand Down
89 changes: 89 additions & 0 deletions paddle/phi/infermeta/fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3828,6 +3828,95 @@ void SinePosXPUInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}

void CrossAttentionXPUInferMeta(
const MetaTensor& input_q,
const MetaTensor& input_kv,
const std::vector<const MetaTensor*>& fc_weight,
const std::vector<const MetaTensor*>& fc_weight_max,
const std::vector<const MetaTensor*>& fc_bias,
const MetaTensor& mask,
int head_num,
int head_dim,
float alpha,
DataType out_dtype,
MetaTensor* qkv,
MetaTensor* qkv_max) {
auto input_q_dims = input_q.dims();
auto input_kv_dims = input_kv.dims();
auto mask_dims = mask.dims();
// input shape : {B, L, H*D}
PADDLE_ENFORCE_EQ(input_q_dims.size(),
3,
phi::errors::InvalidArgument(
"The dim of input_q should be 3! But received ",
input_q_dims.size()));
PADDLE_ENFORCE_EQ(input_kv_dims.size(),
3,
phi::errors::InvalidArgument(
"The dim of input_kv should be 3! But received ",
input_kv_dims.size()));
// sequece length of q and k/v not requied to be eqaul
// but batch size and dim should be the same
PADDLE_ENFORCE_EQ(
input_q_dims[0],
input_kv_dims[0],
phi::errors::InvalidArgument("The batch size of input_q and input_kv "
"should be the same! Received ",
input_q_dims[0],
" vs ",
input_kv_dims[0]));
PADDLE_ENFORCE_EQ(
input_q_dims[2],
input_kv_dims[2],
phi::errors::InvalidArgument("The hidden_dim of input_q and input_kv "
"should be the same! Received ",
input_q_dims[2],
" vs ",
input_kv_dims[2]));
int hidden_dim = head_num * head_dim;
PADDLE_ENFORCE_EQ(
input_q_dims[2],
hidden_dim,
phi::errors::InvalidArgument(
"The last dimension of input_q should be [H*D]! Received ",
input_q_dims[2],
" != expected ",
hidden_dim));
PADDLE_ENFORCE_EQ(fc_weight.size(),
3,
phi::errors::InvalidArgument(
"The size of fc_weight should be 3! But received ",
fc_weight.size()));
PADDLE_ENFORCE_EQ(fc_weight_max.size(),
3,
phi::errors::InvalidArgument(
"The size of fc_weight_max should be 3! But received ",
fc_weight_max.size()));
PADDLE_ENFORCE_EQ(
fc_bias.size(),
3,
phi::errors::InvalidArgument(
"The size of fc_bias should be 3! But received ", fc_bias.size()));
PADDLE_ENFORCE_LE(
mask_dims.size(),
4,
phi::errors::InvalidArgument(
"The dim of mask should be not greater than 4!", mask_dims.size()));

// output shape: {B, qL, H*D}
qkv->set_dims(
phi::make_ddim({input_q_dims[0], input_q_dims[1], head_num * head_dim}));
qkv->set_dtype(out_dtype);
qkv->set_layout(input_q.layout());
// TODO(Terry) optmize the max value num
// unable to pass few PR-CIs, so just use a constant value
// int xpu2_max_value_num = phi::backends::xpu::get_xpu_max_ptr_size(-1);
const int xpu2_max_value_num = 6;
qkv_max->set_dims(phi::make_ddim({xpu2_max_value_num}));
qkv_max->set_dtype(out_dtype);
qkv_max->set_layout(input_q.layout());
}

void MultiGruInferMeta(
const MetaTensor& x,
const std::vector<const MetaTensor*>& weight_x,
Expand Down
13 changes: 13 additions & 0 deletions paddle/phi/infermeta/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,19 @@ void RoformerRelativePosXPUInferMeta(const MetaTensor& x,
const MetaTensor& cos_emb,
int max_pos_len,
MetaTensor* out);
void CrossAttentionXPUInferMeta(
const MetaTensor& input_q,
const MetaTensor& input_kv,
const std::vector<const MetaTensor*>& fc_weight,
const std::vector<const MetaTensor*>& fc_weight_max,
const std::vector<const MetaTensor*>& fc_bias,
const MetaTensor& mask,
int head_num,
int head_dim,
float alpha,
DataType out_dtype,
MetaTensor* qkv,
MetaTensor* qkv_max);

void MultiGruInferMeta(
const MetaTensor& x,
Expand Down
Loading