Commit

Fix tritonsrc/performance_forward.py
xinyazhang committed Aug 9, 2024
1 parent cb9427d · commit b5f8997
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions tritonsrc/performance_forward.py
@@ -7,7 +7,7 @@
 
 import os
 import triton
-from attn_torch_function import attention
+from attn_torch_function import attention, AttentionExtraArgs
 
 try:
     from flash_attn.flash_attn_interface import \
@@ -72,9 +72,11 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
     v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
     b = None
     sm_scale = 1.3
-    autotune = True
-    return_encoded_softmax = False
-    fn = lambda: attention(q, k, v, b, causal, sm_scale, split_kernel, return_encoded_softmax, autotune)
+    dropout_p = 0.0
+    ext = AttentionExtraArgs(return_encoded_softmax=False,
+                             autotune=True,
+                             return_autotune=False)
+    fn = lambda: attention(q, k, v, b, causal, sm_scale, dropout_p, ext)
     if mode == 'bwd':
         o = fn()
         do = torch.randn_like(o)
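For context: the change drops the positional split_kernel, return_encoded_softmax, and autotune arguments, passing an explicit dropout_p plus an AttentionExtraArgs bundle instead. Below is a minimal usage sketch based only on this diff; the tensor shapes, dtype, and scale values are illustrative benchmark settings, and any AttentionExtraArgs fields beyond the three shown here are not confirmed by this commit.

    # Minimal sketch of the new calling convention, inferred from this diff only.
    # Shapes/values are illustrative; AttentionExtraArgs fields beyond these
    # three are not confirmed by the commit.
    import torch
    from attn_torch_function import attention, AttentionExtraArgs

    BATCH, H, N_CTX, D_HEAD = 4, 48, 1024, 64  # example sizes (assumption)
    q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=torch.float16,
                    device="cuda", requires_grad=True)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    ext = AttentionExtraArgs(return_encoded_softmax=False,
                             autotune=True,
                             return_autotune=False)

    # Old call: attention(q, k, v, b, causal, sm_scale, split_kernel,
    #                     return_encoded_softmax, autotune)
    # New call: b=None, causal=False, sm_scale=1.3, dropout_p=0.0
    o = attention(q, k, v, None, False, 1.3, 0.0, ext)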
