diff --git a/tritonsrc/performance_forward.py b/tritonsrc/performance_forward.py
index 097a0825..2373a949 100644
--- a/tritonsrc/performance_forward.py
+++ b/tritonsrc/performance_forward.py
@@ -7,7 +7,7 @@
 import os
 
 import triton
-from attn_torch_function import attention
+from attn_torch_function import attention, AttentionExtraArgs
 
 try:
     from flash_attn.flash_attn_interface import \
@@ -72,9 +72,11 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
         v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
         b = None
         sm_scale = 1.3
-        autotune = True
-        return_encoded_softmax = False
-        fn = lambda: attention(q, k, v, b, causal, sm_scale, split_kernel, return_encoded_softmax, autotune)
+        dropout_p = 0.0
+        ext = AttentionExtraArgs(return_encoded_softmax=False,
+                                 autotune=True,
+                                 return_autotune=False)
+        fn = lambda: attention(q, k, v, b, causal, sm_scale, dropout_p, ext)
         if mode == 'bwd':
             o = fn()
             do = torch.randn_like(o)