
Commit

rename arguments in _fwd_combine_kv_splits, clean unused code
iclementine committed Feb 4, 2024
1 parent 1623c58 commit 3608a91
Showing 2 changed files with 36 additions and 47 deletions.
4 changes: 2 additions & 2 deletions src/flag_attn/flash.py
@@ -26,10 +26,11 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_

# to work around https://github.com/openai/triton/issues/2441
device = torch.cuda.device_of(q)
num_sms = torch.cuda.get_device_properties(device).multi_processor_count

with torch.cuda.device(device):
config_for_split_kv = get_fwd_config_kv_split(B, H, M, N, D, causal)
S = num_splits_herustic(B, H, M, N, D, causal, config_for_split_kv, 128)
S = num_splits_herustic(B, H, M, N, config_for_split_kv[0], config_for_split_kv[1], num_sms, 128)
split_kv: bool = S > 1

if not split_kv:
@@ -58,7 +59,6 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_
else: # split kv
BLOCK_M, BLOCK_N, num_stages, num_warps = config_for_split_kv


divisible_m = M % BLOCK_M == 0
divisible_n = N % BLOCK_N == 0

79 changes: 34 additions & 45 deletions src/flag_attn/split_kv.py
@@ -4,9 +4,11 @@
import triton.language as tl

"""
This file implements flash decoding. flash attention with split_kv, which exposes another
This file implements flash decoding, flash attention with split_kv, which exposes another
dimension of parallelism when batch_size * num_heads * blocks_along_seqlen_q cannot saturate
the GPU's SMs.
For more details, refer to https://princeton-nlp.github.io/flash-decoding/.
"""

@triton.jit
@@ -62,10 +64,10 @@ def _fwd_split_kv_kernel(

# load q
if DIVISIBLE_M:
q = tl.load(q_ptrs, cache_modifier=".cg")
q = tl.load(q_ptrs)
else:
mask_m = offs_m < M
q = tl.load(q_ptrs, mask=mask_m[:, None], cache_modifier=".cg")
q = tl.load(q_ptrs, mask=mask_m[:, None])

#Dot I trick: to place q in registers, it saves shared memory
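# (illustrative note, inferred) the "dot I" trick multiplies q by an identity
# matrix via tl.dot, which encourages the compiler to keep q in registers
# rather than staging it through shared memory; it is applied only when the
# head dim is small enough (BLOCK_DMODEL < 128) for q to fit.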
if BLOCK_DMODEL < 128:
@@ -158,10 +160,10 @@ def _fwd_split_kv_kernel(

@triton.jit
def _fwd_combine_kv_splits(
O, L,
A, B,
stride_oz, stride_oh, stride_os, stride_om, stride_ok,
stride_az, stride_ah, stride_am, stride_ak,
multiple_o, multiple_l,
final_o, final_l,
stride_mul_oz, stride_mul_oh, stride_mul_os, stride_mul_om, stride_mul_ok,
stride_fin_oz, stride_fin_oh, stride_fin_om, stride_fin_ok,
Z, H, M, S,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
DIVISIBLE_M: tl.constexpr,
@@ -171,18 +173,18 @@ def _fwd_combine_kv_splits(
offs_z = tl.program_id(2)

# offset
O += offs_z * stride_oz + offs_h * stride_oh # (B, H, S, M, D)
L += (offs_z * H + offs_h) * S * M # (B, H, S, M)
multiple_o += offs_z * stride_mul_oz + offs_h * stride_mul_oh # (B, H, S, M, D)
multiple_l += (offs_z * H + offs_h) * S * M # (B, H, S, M)

offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
if not DIVISIBLE_M:
mask_m = offs_m < M

# first loop to compute max of log normalizers
# 1st loop: online logsumexp to save a sweep
l_max = tl.full([BLOCK_M], value=float("-inf"), dtype=tl.float32)
l_acc = tl.full([BLOCK_M], value=float("-inf"), dtype=tl.float32)
l_ptrs = L + offs_m
for i in range(0, S):
l_ptrs = multiple_l + offs_m
for _ in range(0, S):
if DIVISIBLE_M:
l = tl.load(l_ptrs)
else:
@@ -191,25 +193,12 @@ def _fwd_combine_kv_splits(
l_acc = tl.log(tl.exp(l_acc - l_max) + tl.exp(l - l_max)) + l_max
l_ptrs += M
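# (illustrative note) assuming l_max tracks the running maximum of the log
# normalizers seen so far (its update is folded out of this hunk), the update
# above is a numerically stable log-add-exp:
#   logaddexp(a, b) = log(exp(a - m) + exp(b - m)) + m,  with m = max(a, b),
# so after the loop l_acc equals the logsumexp over all S splits.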

# 2nd loop to compute max of log normalizers
# l_acc = tl.zeros([BLOCK_M], dtype=tl.float32)
# l_ptrs = L + offs_m
# for i in range(0, S):
# if DIVISIBLE_M:
# l = tl.load(l_ptrs)
# else:
# l = tl.load(l_ptrs, mask=mask_m)
# l_acc += tl.exp(l - l_max)
# l_ptrs += M
# l_acc = l_max + tl.log(l_acc)
# NOTE: we can also use an online algorithm to compute log normalizer

# 3rd loop to rescale and accumulate o
# 2nd loop to rescale and accumulate o
o_acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
l_ptrs = L + offs_m
l_ptrs = multiple_l + offs_m
offs_k = tl.arange(0, BLOCK_DMODEL)
o_ptrs = O + offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok
for i in range(0, S):
o_ptrs = multiple_o + offs_m[:, None] * stride_mul_om + offs_k[None, :] * stride_mul_ok
for _ in range(0, S):
l = tl.load(l_ptrs, mask=offs_m < M)
rescale = tl.exp(l - l_acc)
if DIVISIBLE_M:
@@ -219,13 +208,13 @@ def _fwd_combine_kv_splits(
o_acc += o * rescale[:, None]

l_ptrs += M
o_ptrs += stride_os
o_ptrs += stride_mul_os

# write back
A += offs_z * stride_az + offs_h * stride_ah
B += (offs_z * H + offs_h) * M
a_ptrs = A + offs_m[:, None] * stride_am + offs_k * stride_ak
b_ptrs = B + offs_m
final_o += offs_z * stride_fin_oz + offs_h * stride_fin_oh
final_l += (offs_z * H + offs_h) * M
a_ptrs = final_o + offs_m[:, None] * stride_fin_om + offs_k * stride_fin_ok
b_ptrs = final_l + offs_m

if DIVISIBLE_M:
tl.store(a_ptrs, o_acc)
@@ -235,20 +224,19 @@ def _fwd_combine_kv_splits(
tl.store(b_ptrs, l_acc, mask=mask_m)

def get_fwd_config(B, H, M, N, D, causal):
return (16, 128, 1, 4)
# BLOCK_M, BLOCK_N, num_stages, num_warps
if D <=64:
return (16, 256, 2, 4)
else:
return (16, 128, 2, 8)

# this function is adapted from https://github.com/Dao-AILab/flash-attention/blob/61a777247900f6c2a37376f3ffd7134385fdc95c/csrc/flash_attn/flash_api.cpp#L235
def num_splits_herustic(B, H, M, N, D, causal, config, max_splits):
BLOCK_M, BLOCK_N, num_stages, num_warps = config
def num_splits_herustic(B, H, M, N, BLOCK_M, BLOCK_N, num_sms, max_splits):
num_blocks_without_split_kv = B * H * triton.cdiv(M, BLOCK_M)

dev_prop = torch.cuda.get_device_properties(0)
num_sms = dev_prop.multi_processor_count
num_n_blocks = triton.cdiv(N, BLOCK_N)

if num_blocks_without_split_kv >= 0.8 * num_sms:
return 1

num_n_blocks = triton.cdiv(N, BLOCK_N)
def num_split_avaiable(s):
blocks_per_split = triton.cdiv(num_n_blocks, s)
return s == 1 or (blocks_per_split * s - num_n_blocks < blocks_per_split)
@@ -262,7 +250,7 @@ def efficiency(s):
plans = [] # (num_split, efficiency)
max_splits = min(num_sms, num_n_blocks, max_splits)

for num_split in range(1, max_splits):
for num_split in range(1, max_splits + 1):
if num_split_avaiable(num_split):
eff = efficiency(num_split)
plans.append((num_split, eff))
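# (illustrative worked example, hypothetical numbers) with the new signature
# the caller passes the block sizes and SM count directly, e.g. for decoding:
#   num_splits_herustic(B=2, H=16, M=1, N=8192,
#                       BLOCK_M=16, BLOCK_N=256, num_sms=108, max_splits=128)
# gives 2 * 16 * cdiv(1, 16) = 32 query blocks, which is below 0.8 * 108 SMs,
# so split candidates 1..min(108, cdiv(8192, 256), 128) = 1..32 are scored by
# `efficiency` and compared (the selection of the best plan is not shown in
# this hunk).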
@@ -290,11 +278,12 @@ def attention(q, k, v, causal=False, sm_scale=None):

# to work around https://github.com/openai/triton/issues/2441
device = torch.cuda.device_of(q)
num_sms = torch.cuda.get_device_properties(device).multi_processor_count

with torch.cuda.device(device):
config = get_fwd_config(B, H, M, N, D, causal)
BLOCK_M, BLOCK_N, num_stages, num_warps = config
S = num_splits_herustic(B, H, M, N, D, causal, config, 128)
# print(f"num_splits: {S}")
S = num_splits_herustic(B, H, M, N, BLOCK_M, BLOCK_N, num_sms, 128)

divisible_m = M % BLOCK_M == 0
divisible_n = N % BLOCK_N == 0
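# Usage sketch (assumes the flag_attn package from this repository is
# importable, a CUDA device is present, and the (B, H, seqlen, head_dim)
# layout suggested by the variable names above; shapes and dtype are
# hypothetical):
#
#   import torch
#   from flag_attn import split_kv
#
#   q = torch.randn(2, 16, 1, 128, device="cuda", dtype=torch.float16)    # decoding: M = 1
#   k = torch.randn(2, 16, 8192, 128, device="cuda", dtype=torch.float16)
#   v = torch.randn_like(k)
#   o = split_kv.attention(q, k, v, causal=False)                         # (B, H, M, D)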
