
Improved TMA descriptor setup
plotfi committed Sep 4, 2024
1 parent 0446bb0 commit 3e2fc8c
Showing 1 changed file with 28 additions and 49 deletions.
ops/triton/triton_ragged_hstu_attention.py (28 additions, 49 deletions)
@@ -17,7 +17,6 @@
 # pyre-unsafe
 
 from typing import List, Optional
-import numpy as np
 
 import torch

@@ -758,11 +757,11 @@ def triton_ragged_attention(
     has_attn_scale = attn_scale is not None
     has_max_attn_len = max_attn_len is not None
 
-    # TMA SETUP:
     TMA_SIZE = 128
-    BLOCK_N, BLOCK_D_V, BLOCK_D_Q = 64, DimV, DimQ
-    desc_k = np.empty(TMA_SIZE, dtype=np.int8)
-    desc_v = np.empty(TMA_SIZE, dtype=np.int8)
+    BLOCK_N = 64
+    BLOCK_D_V, BLOCK_D_Q = DimV, DimQ
+    desc_k = torch.empty((TMA_SIZE), device="cuda", dtype=torch.int8)
+    desc_v = torch.empty((TMA_SIZE), device="cuda", dtype=torch.int8)
     triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
         k.data_ptr(),
         L,
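Note on this hunk: the descriptors move from host-side NumPy buffers to device-resident torch tensors, which is what lets the kernel consume them directly and what makes the `import numpy as np` removal above possible. Below is a minimal, self-contained sketch of the resulting setup pattern, assuming a Hopper GPU and a Triton build that exposes the experimental `fill_2d_tma_descriptor` utility; the sizes and the toy tensor are illustrative stand-ins, not the real call sites.

```python
import torch
import triton

TMA_SIZE = 128  # a CUtensorMap TMA descriptor is 128 bytes

# Stand-in tensor and sizes mirroring the names in the diff (k, L, H, DimQ).
L, H, DimQ = 1024, 4, 64
BLOCK_N, BLOCK_D_Q = 64, DimQ
k = torch.randn(L, H * DimQ, device="cuda", dtype=torch.bfloat16)

# The descriptor lives on the GPU so the kernel can read it via TMA ...
desc_k = torch.empty(TMA_SIZE, device="cuda", dtype=torch.int8)
# ... but the fill utility writes into host memory, so stage through the CPU.
k_buf = torch.empty_like(desc_k, device="cpu")

# Encode the (L, H*DimQ) global shape and the (BLOCK_N, BLOCK_D_Q) block box.
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
    k.data_ptr(), L, H * DimQ, BLOCK_N, BLOCK_D_Q, k.element_size(), k_buf.numpy()
)
desc_k.copy_(k_buf)  # host-to-device copy of the descriptor bytes
```

This replaces the old `np.empty(...)` plus `torch.tensor(desc, device=...)` round trip that the next hunk deletes.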
@@ -896,51 +895,32 @@ def triton_ragged_attention_relative_bias(
     BLOCK_D_V, BLOCK_D_Q = DimV, DimQ
     desc_k = torch.empty((TMA_SIZE), device="cuda", dtype=torch.int8)
     desc_v = torch.empty((TMA_SIZE), device="cuda", dtype=torch.int8)
-    '''
-    desc_k = np.empty(TMA_SIZE, dtype=np.int8)
-    desc_v = np.empty(TMA_SIZE, dtype=np.int8)
-    triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
-        k.data_ptr(),
-        L,
-        H * DimQ,
-        BLOCK_N,
-        BLOCK_D_Q,
-        k.element_size(),
-        desc_k,
-    )
-    triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
-        v.data_ptr(),
-        L,
-        H * DimV,
-        BLOCK_N,
-        BLOCK_D_V,
-        v.element_size(),
-        desc_v,
-    )
-    desc_k = torch.tensor(desc_k, device=v.device)
-    desc_v = torch.tensor(desc_v, device=v.device)
-    '''

-    # TODO: ???
-    def grid2(META):
+    def grid_tma(META):
         nonlocal desc_k
         nonlocal desc_v
-        #a_buf = torch.empty(TMA_SIZE, dtype=torch.int8)
-        k_buf = torch.empty_like(desc_k, device="cpu")
-        v_buf = torch.empty_like(desc_v, device="cpu")
-        #desc_a = desc_a.numpy() # if start with cuda, will need cpu() here
-        #desc_b = desc_b.numpy()
-        #desc_c = desc_c.numpy()
-        #print("enter grid2", META['BLOCK_M'], META['BLOCK_K'])
-        triton.runtime.driver.active.utils.fill_2d_tma_descriptor(k.data_ptr(), L, H * DimQ, META['BLOCK_N'], BLOCK_D_Q, k.element_size(),
-                                                                  k_buf.numpy())
-        triton.runtime.driver.active.utils.fill_2d_tma_descriptor(v.data_ptr(), L, H * DimV, META['BLOCK_N'], BLOCK_D_V, v.element_size(),
-                                                                  v_buf.numpy())
-        #desc_a = torch.tensor(desc_a, device="cuda")
-        #desc_b = torch.tensor(desc_b, device="cuda")
-        #desc_c = torch.tensor(desc_c, device="cuda")
-        desc_k.copy_(k_buf)
-        desc_v.copy_(v_buf)
+        k_buf = torch.empty_like(desc_k, device="cpu", pin_memory=True)
+        v_buf = torch.empty_like(desc_v, device="cpu", pin_memory=True)
+        triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
+            k.data_ptr(),
+            L,
+            H * DimQ,
+            META['BLOCK_N'],
+            BLOCK_D_Q,
+            k.element_size(),
+            k_buf.numpy()
+        )
+        triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
+            v.data_ptr(),
+            L,
+            H * DimV,
+            META['BLOCK_N'],
+            BLOCK_D_V,
+            v.element_size(),
+            v_buf.numpy()
+        )
+        desc_k.copy_(k_buf, non_blocking=True)
+        desc_v.copy_(v_buf, non_blocking=True)
         return (triton.cdiv(N, META["BLOCK_M"]), Z * H, 1)
 
     stride_sz = 0
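Two things are worth spelling out about grid_tma. First, the descriptor fill has to happen inside the grid callable because the autotuner supplies META['BLOCK_N'] per configuration, and the box shape baked into a TMA descriptor must match the block size the chosen config actually loads. Second, the staging buffers are now pinned, which makes the host-to-device copy genuinely asynchronous. A hedged sketch of just that copy pattern, with the toy size from the diff and the fill step elided:

```python
import torch

# Device-resident descriptor and a page-locked (pinned) host staging buffer.
desc_k = torch.empty(128, device="cuda", dtype=torch.int8)
k_buf = torch.empty_like(desc_k, device="cpu", pin_memory=True)

# ... fill k_buf.numpy() on the host, e.g. with fill_2d_tma_descriptor ...

# From pinned memory, non_blocking=True enqueues a true async H2D copy on
# the current CUDA stream; the kernel launch that consumes desc_k follows
# on the same stream, so stream ordering delivers the bytes in time.
desc_k.copy_(k_buf, non_blocking=True)
```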
@@ -954,8 +934,7 @@ def grid2(META):
     use_time_bias = relative_bias_type == "TIME" or relative_bias_type == "ALL"
     use_pos_bias = relative_bias_type == "POSITION" or relative_bias_type == "ALL"
 
-    # TODO: grid ???
-    _ragged_hstu_attn_fwd[grid](
+    _ragged_hstu_attn_fwd[grid_tma](
         Q=q,
         K=k,
         V=v,
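The fold hides the rest of the launch arguments, but the key change is visible: the kernel is now launched with the callable grid_tma instead of a precomputed grid. To illustrate how a callable grid interacts with resolved meta-parameters, here is a minimal runnable toy; the kernel and names below are hypothetical, not from this file.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _scale_kernel(X, Y, n, scale, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(Y + offs, tl.load(X + offs, mask=mask) * scale, mask=mask)


x = torch.randn(1000, device="cuda")
y = torch.empty_like(x)

# Triton calls the grid function after meta-parameters are resolved, so it
# can size the launch from META, or, as grid_tma does above, rebuild TMA
# descriptors to match the block shape the chosen config will use.
def grid(META):
    return (triton.cdiv(x.numel(), META["BLOCK"]),)

_scale_kernel[grid](x, y, x.numel(), 2.0, BLOCK=256)
```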
