Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Infer Symbolic Shape No.208】[BUAA] Add masked_multihead_attention_ op #67861

Merged
merged 12 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1747,12 +1747,63 @@ bool NearestInterpOpInferSymbolicShape(
return BicubicInterpOpInferSymbolicShape(op, infer_context);
}

// bool MaskedMultiheadAttention_OpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }
bool MaskedMultiheadAttentionOpInferSymbolicShape(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没有不带下划线版本的op了,yaml中的op名是masked_multihead_attention_

pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const symbol::ShapeOrDataDimExprs &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const symbol::ShapeOrDataDimExprs &cache_kv_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(1));

const std::vector<symbol::DimExpr> &x_shape = x_shape_or_data.shape();
const std::vector<symbol::DimExpr> &cache_kv_shape =
cache_kv_shape_or_data.shape();

std::string compute_dtype =
op->attribute<pir::StrAttribute>("compute_dtype").AsString();

PADDLE_ENFORCE_EQ(
cache_kv_shape.size(),
5,
phi::errors::InvalidArgument("The cache_kv must be 5 dims."));
infer_context->AddEqualCstr(cache_kv_shape[0], symbol::DimExpr(2));
// TODO(Luohongzhige, Buaa): add constrain for the num_head and k_num_head

symbol::DimExpr bsz = x_shape[0];
symbol::DimExpr dim_head = cache_kv_shape[4];
symbol::DimExpr k_num_head = cache_kv_shape[2];
symbol::DimExpr v_num_head = k_num_head;
symbol::DimExpr num_head =
(x_shape[x_shape.size() - 1] / dim_head - k_num_head - v_num_head);
std::vector<symbol::DimExpr> out_shape = {bsz, num_head * dim_head};

infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(out_shape)});

infer_context->SetShapeOrDataForValue(
op->result(1),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(cache_kv_shape)});

if (op->operand_source(7) != nullptr) {
const symbol::ShapeOrDataDimExprs &beam_cache_offset_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(7));
const std::vector<symbol::DimExpr> &beam_cache_offset_shape =
beam_cache_offset_shape_or_data.shape();
infer_context->SetShapeOrDataForValue(
op->result(2),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(beam_cache_offset_shape)});
}

return true;
}

bool MaskedMultiheadAttention_OpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
return MaskedMultiheadAttentionOpInferSymbolicShape(op, infer_context);
}

bool MemoryEfficientAttentionOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Linspace)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LinearInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logspace)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lstm)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaskedMultiheadAttention_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaskedMultiheadAttention_)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MergedAdam)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MergedAdam_)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MergedMomentum)
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3081,7 +3081,7 @@
data_type : x
optional : bias, src_mask, cum_offsets, sequence_lengths, rotary_tensor, beam_cache_offset, qkv_out_scale, out_shift, out_smooth
inplace : (cache_kv -> cache_kv_out), (beam_cache_offset -> beam_cache_offset_out)
# interfaces : paddle::dialect::InferSymbolicShapeInterface
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : masked_select
args : (Tensor x, Tensor mask)
Expand Down