Skip to content

Commit

Permalink
Adding META['enable_tma']
Browse files Browse the repository at this point in the history
  • Loading branch information
plotfi committed Sep 4, 2024
1 parent ffbfb01 commit 306492f
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions ops/triton/triton_ragged_hstu_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,8 @@ def grid_tma(META):
META['BLOCK_N'],
BLOCK_D_Q,
k.element_size(),
k_buf.numpy()
k_buf.numpy(),
META['enable_tma']
)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
v.data_ptr(),
Expand All @@ -783,7 +784,8 @@ def grid_tma(META):
META['BLOCK_N'],
BLOCK_D_V,
v.element_size(),
v_buf.numpy()
v_buf.numpy(),
META['enable_tma']
)
desc_k.copy_(k_buf, non_blocking=True)
desc_v.copy_(v_buf, non_blocking=True)
Expand Down Expand Up @@ -905,7 +907,8 @@ def grid_tma(META):
META['BLOCK_N'],
BLOCK_D_Q,
k.element_size(),
k_buf.numpy()
k_buf.numpy(),
META['enable_tma']
)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
v.data_ptr(),
Expand All @@ -914,7 +917,8 @@ def grid_tma(META):
META['BLOCK_N'],
BLOCK_D_V,
v.element_size(),
v_buf.numpy()
v_buf.numpy(),
META['enable_tma']
)
desc_k.copy_(k_buf, non_blocking=True)
desc_v.copy_(v_buf, non_blocking=True)
Expand Down

0 comments on commit 306492f

Please sign in to comment.