Skip to content

Commit

Permalink
Adding grid_tma for both bias and non-bias paths
Browse files Browse the repository at this point in the history
  • Loading branch information
plotfi committed Sep 4, 2024
1 parent 3e2fc8c commit ffbfb01
Showing 1 changed file with 27 additions and 30 deletions.
57 changes: 27 additions & 30 deletions ops/triton/triton_ragged_hstu_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,35 +758,36 @@ def triton_ragged_attention(
has_max_attn_len = max_attn_len is not None

TMA_SIZE = 128
BLOCK_N = 64
BLOCK_D_V, BLOCK_D_Q = DimV, DimQ
desc_k = torch.empty((TMA_SIZE), device="cuda", dtype=torch.int8)
desc_v = torch.empty((TMA_SIZE), device="cuda", dtype=torch.int8)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
k.data_ptr(),
L,
H * DimQ,
BLOCK_N,
BLOCK_D_Q,
k.element_size(),
desc_k,
)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
v.data_ptr(),
L,
H * DimV,
BLOCK_N,
BLOCK_D_V,
v.element_size(),
desc_v,
)
desc_k = torch.tensor(desc_k, device=v.device)
desc_v = torch.tensor(desc_v, device=v.device)

grid = lambda meta: ( # noqa E731
triton.cdiv(N, meta["BLOCK_M"]),
Z * H,
)
def grid_tma(META):
nonlocal desc_k
nonlocal desc_v
k_buf = torch.empty_like(desc_k, device="cpu", pin_memory=True)
v_buf = torch.empty_like(desc_v, device="cpu", pin_memory=True)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
k.data_ptr(),
L,
H * DimQ,
META['BLOCK_N'],
BLOCK_D_Q,
k.element_size(),
k_buf.numpy()
)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
v.data_ptr(),
L,
H * DimV,
META['BLOCK_N'],
BLOCK_D_V,
v.element_size(),
v_buf.numpy()
)
desc_k.copy_(k_buf, non_blocking=True)
desc_v.copy_(v_buf, non_blocking=True)
return (triton.cdiv(N, META["BLOCK_M"]), Z * H, 1)

stride_sz = 0
stride_sm = 0
Expand All @@ -797,7 +798,7 @@ def triton_ragged_attention(
stride_sz = attn_scale.stride(0)
stride_sm = attn_scale.stride(1)

_ragged_hstu_attn_fwd[grid](
_ragged_hstu_attn_fwd[grid_tma](
Q=q,
K=k,
V=v,
Expand Down Expand Up @@ -886,10 +887,6 @@ def triton_ragged_attention_relative_bias(
L, H, DimQ = q.shape
_, _, DimV = v.shape
out = torch.empty_like(v)
grid = lambda meta: ( # noqa E731
triton.cdiv(N, meta["BLOCK_M"]),
Z * H,
)

TMA_SIZE = 128
BLOCK_D_V, BLOCK_D_Q = DimV, DimQ
Expand Down

0 comments on commit ffbfb01

Please sign in to comment.