fix FlashAttnOpInferSymbolicShape and FlashAttnInferMeta (#63816)
* fix FlashAttnOpInferSymbolicShape and FlashAttnInferMeta

* use int for simple value

* add check and constraint

* fix

* add shape constraint for attention_mask
Hongqing-work authored Apr 26, 2024
1 parent 3289d25 commit 52db8e4
Showing 2 changed files with 72 additions and 2 deletions.
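
For context, here is the shape layout the new constraints assume. This sketch is inferred from the checks in the diff below; the dim names are illustrative and do not appear in the commit itself:

// Tensor layouts assumed by the constraints added below (illustrative):
//   q:         [batch_size, seqlen_q, num_heads,   head_dim]
//   k:         [batch_size, seqlen_k, num_heads_k, head_dim]
//   v:         [batch_size, seqlen_k, num_heads_k, head_dim_v]
//   attn_mask: dims 0, 2, 3 tied to batch_size, seqlen_q, seqlen_k
//   out:       [batch_size, seqlen_q, num_heads,   head_dim_v]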
@@ -278,15 +278,67 @@ bool FlashAttnOpInferSymbolicShape(
  const symbol::ShapeOrDataDimExprs &q =
      shape_analysis->GetShapeOrDataForValue(operand_source);

  const symbol::ShapeOrDataDimExprs &k =
      shape_analysis->GetShapeOrDataForValue(op->operand_source(1));

  const symbol::ShapeOrDataDimExprs &v =
      shape_analysis->GetShapeOrDataForValue(op->operand_source(2));

  PADDLE_ENFORCE_EQ(q.shape().size(),
                    4,
                    phi::errors::InvalidArgument(
                        "flash_attn receives input with dim "
                        "[batch_size, seq_len, num_heads, head_dim]"));

  // q, k, and v must share a batch size; k and v must share a sequence length.
  shape_analysis->AddEqualCstr(q.shape()[0], k.shape()[0]);
  shape_analysis->AddEqualCstr(q.shape()[0], v.shape()[0]);
  shape_analysis->AddEqualCstr(k.shape()[1], v.shape()[1]);

  // attention_mask must match batch_size on dim 0, seqlen_q on dim 2, and
  // seqlen_k on dim 3.
  if (op->operand_source(4)) {
    const symbol::ShapeOrDataDimExprs &attn_mask =
        shape_analysis->GetShapeOrDataForValue(op->operand_source(4));
    shape_analysis->AddEqualCstr(attn_mask.shape()[0], q.shape()[0]);
    shape_analysis->AddEqualCstr(attn_mask.shape()[2], q.shape()[1]);
    shape_analysis->AddEqualCstr(attn_mask.shape()[3], k.shape()[1]);
  }

  std::vector<symbol::DimExpr> out_shape = q.shape();
  out_shape.back() = v.shape().back();

  shape_analysis->SetShapeOrDataForValue(
      op->result(0), symbol::TensorShapeOrDataDimExprs(out_shape));
  // The GPU kernel rounds seqlen up to a multiple of 128, but the XPU kernel
  // does not. Here we align with the GPU version.
  auto round_multiple = [](symbol::DimExpr x) {
    auto m = symbol::DimExpr{128};
    auto m_minus_one = symbol::DimExpr{127};
    return (x + m_minus_one) / m * m;
  };
  auto batch_size_expr = q.shape()[0];
  auto num_heads_expr = q.shape()[2];
  auto seqlen_q_rounded_expr = round_multiple(q.shape()[1]);
  auto seqlen_k_rounded_expr = round_multiple(k.shape()[1]);
  if (op->result(1)) {
    std::vector<symbol::DimExpr> softmax_shape{batch_size_expr,
                                               num_heads_expr,
                                               seqlen_q_rounded_expr,
                                               seqlen_k_rounded_expr};
    shape_analysis->SetShapeOrDataForValue(
        op->result(1), symbol::TensorShapeOrDataDimExprs(softmax_shape));
  }
  if (op->result(2)) {
    std::vector<symbol::DimExpr> softmax_lse_shape{
        batch_size_expr, num_heads_expr, seqlen_q_rounded_expr};
    shape_analysis->SetShapeOrDataForValue(
        op->result(2), symbol::TensorShapeOrDataDimExprs(softmax_lse_shape));
  }
  if (op->result(3)) {
    // seed_offset is a fixed-size [2] tensor, not the attention output shape.
    std::vector<symbol::DimExpr> seed_offset_shape{symbol::DimExpr{2}};
    shape_analysis->SetShapeOrDataForValue(
        op->result(3), symbol::TensorShapeOrDataDimExprs(seed_offset_shape));
  }
  return true;
}
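
The rounding helper computes a ceiling division by 128 and scales back up. A minimal standalone sketch of the same arithmetic on plain ints (the values are made up for illustration and not part of the commit):

#include <cassert>

// Round x up to the nearest multiple of 128, mirroring round_multiple above.
int round_up_128(int x) { return (x + 127) / 128 * 128; }

int main() {
  assert(round_up_128(1) == 128);    // a partial block rounds up
  assert(round_up_128(128) == 128);  // an exact multiple is unchanged
  assert(round_up_128(300) == 384);  // e.g. seqlen_q = 300 -> 384
  return 0;
}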

22 changes: 20 additions & 2 deletions paddle/phi/infermeta/ternary.cc
@@ -380,14 +380,32 @@ void FlashAttnInferMeta(const MetaTensor& q,
                        MetaTensor* softmax_lse,
                        MetaTensor* seed_offset) {
  auto out_dims = q.dims();
  PADDLE_ENFORCE_EQ(out_dims.size(),
                    4,
                    phi::errors::InvalidArgument(
                        "flash_attn receives input with dim "
                        "[batch_size, seq_len, num_heads, head_dim]"));
  out_dims[3] = v.dims()[3];
  out->set_dims(out_dims);
  out->set_dtype(q.dtype());
  out->set_layout(q.layout());
  // Round seqlen up to a multiple of 128, matching the GPU kernel.
  auto round_multiple = [](int x) { return (x + 127) / 128 * 128; };
  int batch_size = q.dims()[0];
  int num_heads = q.dims()[2];
  int seqlen_q_rounded = round_multiple(q.dims()[1]);
  int seqlen_k_rounded = round_multiple(k.dims()[1]);
  if (softmax) {
    softmax->set_dtype(q.dtype());
    softmax->set_dims(
        {batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded});
  }
  if (softmax_lse) {
    softmax_lse->set_dtype(q.dtype());
    softmax_lse->set_dims({batch_size, num_heads, seqlen_q_rounded});
  }
  if (seed_offset) {
    seed_offset->set_dtype(phi::DataType::INT64);
    seed_offset->set_dims({2});
  }
}
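
To make the dim propagation concrete, here is a self-contained sketch that traces the same computation with made-up input shapes and plain arrays in place of MetaTensor (none of these values come from the commit):

#include <cstdio>

int main() {
  // Hypothetical inputs: q is [2, 300, 8, 64]; k and v are [2, 500, 8, 64].
  int q_dims[4] = {2, 300, 8, 64};
  int k_dims[4] = {2, 500, 8, 64};
  int v_dims[4] = {2, 500, 8, 64};

  auto round_multiple = [](int x) { return (x + 127) / 128 * 128; };

  // out keeps q's dims except the last, which follows v's head_dim.
  int out[4] = {q_dims[0], q_dims[1], q_dims[2], v_dims[3]};
  // softmax is [batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded].
  int softmax[4] = {q_dims[0], q_dims[2], round_multiple(q_dims[1]),
                    round_multiple(k_dims[1])};
  // softmax_lse drops the last softmax dim.
  int lse[3] = {softmax[0], softmax[1], softmax[2]};

  // Prints out [2, 300, 8, 64], softmax [2, 8, 384, 512], lse [2, 8, 384].
  std::printf("out: [%d, %d, %d, %d]\n", out[0], out[1], out[2], out[3]);
  std::printf("softmax: [%d, %d, %d, %d]\n",
              softmax[0], softmax[1], softmax[2], softmax[3]);
  std::printf("lse: [%d, %d, %d]\n", lse[0], lse[1], lse[2]);
  return 0;
}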

