Skip to content

Commit

Permalink
add infer_symbol_shape for memory_efficient_attention (#63999)
Browse files Browse the repository at this point in the history
  • Loading branch information
zyfncg authored Apr 30, 2024
1 parent 6bf12c7 commit 5fd3e2e
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,74 @@ bool NearestInterpOpInferSymbolicShape(
return BicubicInterpOpInferSymbolicShape(op, shape_analysis);
}

bool MemoryEfficientAttentionOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
const auto &q_shape =
shape_analysis->GetShapeOrDataForValue(op->operand_source(0)).shape();
const auto &k_shape =
shape_analysis->GetShapeOrDataForValue(op->operand_source(1)).shape();
const auto &v_shape =
shape_analysis->GetShapeOrDataForValue(op->operand_source(2)).shape();
PADDLE_ENFORCE_EQ(
q_shape.size(),
4,
phi::errors::InvalidArgument("Query should be a 4-D tensor"
"But received Query dimension(%d)",
q_shape.size()));
PADDLE_ENFORCE_EQ(
k_shape.size(),
4,
phi::errors::InvalidArgument("Key should be a 4-D tensor"
"But received Key dimension(%d)",
k_shape.size()));
PADDLE_ENFORCE_EQ(
v_shape.size(),
4,
phi::errors::InvalidArgument("Value should be a 4-D tensor"
"But received Value dimension(%d)",
v_shape.size()));

const auto &query_batch_size = q_shape[0];
const auto &query_seq_length = q_shape[1];
const auto &query_num_head = q_shape[2];
const auto &query_head_size = q_shape[3];

const auto &key_batch_size = k_shape[0];
const auto &key_seq_length = k_shape[1];
const auto &key_num_head = k_shape[2];
const auto &key_head_size = k_shape[3];

const auto &value_batch_size = v_shape[0];
const auto &value_seq_length = v_shape[1];
const auto &value_num_head = v_shape[2];
const auto &value_head_size = v_shape[3];

shape_analysis->AddEqualCstr(query_batch_size, key_batch_size);
shape_analysis->AddEqualCstr(key_batch_size, value_batch_size);

shape_analysis->AddEqualCstr(query_num_head, key_num_head);
shape_analysis->AddEqualCstr(key_num_head, value_num_head);

shape_analysis->AddEqualCstr(query_head_size, key_head_size);

shape_analysis->AddEqualCstr(key_seq_length, value_seq_length);

const std::vector<symbol::DimExpr> out_dims{
query_batch_size, query_seq_length, query_num_head, value_head_size};
const std::vector<symbol::DimExpr> logsumexp_dims{query_num_head,
query_batch_size};
const std::vector<symbol::DimExpr> seed_and_offset_dims{2};

shape_analysis->SetShapeOrDataForValue(
op->result(0), symbol::TensorShapeOrDataDimExprs(out_dims));
shape_analysis->SetShapeOrDataForValue(
op->result(1), symbol::TensorShapeOrDataDimExprs(logsumexp_dims));
shape_analysis->SetShapeOrDataForValue(
op->result(2), symbol::TensorShapeOrDataDimExprs(seed_and_offset_dims));

return true;
}

bool MeshgridOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
const symbol::TensorListShapeOrDataDimExprs &shape_data_list =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(GroupNorm)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Linspace)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LinearInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logspace)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MemoryEfficientAttention)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Meshgrid)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(NearestInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Stack)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2007,6 +2007,7 @@
data_type : query
optional : bias, cu_seqlens_q, cu_seqlens_k, causal_diagonal, seqlen_k
backward : memory_efficient_attention_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : merge_selected_rows
args : (Tensor x)
Expand Down

0 comments on commit 5fd3e2e

Please sign in to comment.