diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py
index fa3c3202c120..70197689e44a 100644
--- a/python/perf-kernels/flash-attention.py
+++ b/python/perf-kernels/flash-attention.py
@@ -902,8 +902,8 @@ def forward(ctx, q, k, v, o, metadata):
                        ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=metadata.max_seqlens_q,
                        MAX_SEQLENS_K=metadata.max_seqlens_k, IS_CAUSAL=metadata.causal, VARLEN=metadata.varlen,
                        BLOCK_DMODEL=padded_d_model, BIAS_TYPE=0 if metadata.bias is None else 1,
-                       USE_ALIBI=False if metadata.alibi_slopes is None else True,
-                       ENABLE_DROPOUT=metadata.dropout_p > 0.0,
+                       USE_ALIBI=False if metadata.alibi_slopes is None else True,
+                       ENABLE_DROPOUT=metadata.dropout_p > 0.0,
                        RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax)

         ctx.save_for_backward(q, k, v, o, M)
@@ -1467,7 +1467,7 @@ def parse_args():
     parser.add_argument("-hk", type=int, default=0)
     parser.add_argument("-sq", type=int, default=0)
     parser.add_argument("-sk", type=int, default=0)
-    parser.add_argument("-equal_seqlens", action='store_true', default=False,
+    parser.add_argument("-equal_seqlens", action='store_true', default=False,
                         help='If specified, each context within the thd layout' \
                             ' has same seqlen as sq and sk')
     parser.add_argument("-d", type=int, default=0)