Commit: Support full masked out blocks

drisspg committed Mar 26, 2024
1 parent e0403d1 commit 740e924
Showing 4 changed files with 161 additions and 83 deletions.
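Roughly, the problem this commit targets: when every key position in a BLOCK_M x BLOCK_N tile is masked out for a given query row, that row's maximum score is -inf and the online-softmax update computes exp(-inf - (-inf)) = NaN. The plain-PyTorch sketch below (variable names are illustrative, not from the kernel) shows the failure and the convention the kernel adopts of zeroing such rows instead:

import torch

# A score row where every key is masked out.
scores = torch.full((1, 4), float("-inf"))

# Plain softmax turns the row into NaNs: exp(-inf - (-inf)) is NaN.
print(torch.softmax(scores, dim=-1))  # tensor([[nan, nan, nan, nan]])

# Convention adopted in the kernel below: detect fully masked rows via their
# row max and force their probabilities (hence their contribution) to zero.
row_max = scores.max(dim=-1).values               # -inf for fully masked rows
fully_masked = row_max == float("-inf")
probs = torch.softmax(scores, dim=-1)
probs = torch.where(fully_masked[:, None], torch.zeros_like(probs), probs)
print(probs)  # tensor([[0., 0., 0., 0.]])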
72 changes: 66 additions & 6 deletions test/test_flash.py
@@ -5,12 +5,12 @@
from transformer_nuggets.flash import attention, BiasMode, build_rel_mask


@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [(6, 8, 128, 16)])
@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [(6, 8, 256, 16)])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("bias_choice", [BiasMode.rel_pos, BiasMode.none, BiasMode.alibi])
@pytest.mark.parametrize("sm_scale", [None, 1])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
-def test_op(Z, H, N_CTX, D_HEAD, causal, bias_choice, sm_scale, dtype=torch.float16):
+def test_flash_all(Z, H, N_CTX, D_HEAD, causal, bias_choice, sm_scale, dtype=torch.float16):
torch.manual_seed(20)
q = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
@@ -57,21 +57,81 @@ def test_op(Z, H, N_CTX, D_HEAD, causal, bias_choice, sm_scale, dtype=torch.floa
tri_dq, q.grad = q.grad.clone(), None
# Check attn_bias equivalence
if bias_choice != BiasMode.none:
-torch.testing.assert_close(attn_bias, mask.half(), atol=4e-2, rtol=0)
+BLOCK_M = 128
+BLOCK_N = 64
+mask = mask.half()
+if N_CTX > BLOCK_M and causal:
+# Since the kernel does not iterate over all of seq_len_kv when causal,
+# we only check the minimum rectangular block
+attn_bias = attn_bias[:, :, :, :BLOCK_M]
+mask = mask[:, :, :, :BLOCK_M]
+torch.testing.assert_close(attn_bias, mask, atol=4e-2, rtol=0)

# compare
-torch.testing.assert_close(ref_out, tri_out, atol=5.5e-2, rtol=0)
+torch.testing.assert_close(ref_out, tri_out, atol=5.8e-2, rtol=0)
if bias_choice != BiasMode.none:
fudge_factor = 6.1
else:
fudge_factor = 1
-atol = 1e-2 * fudge_factor
+atol = 2e-2 * fudge_factor
if bias_choice == BiasMode.rel_pos and not causal:
-atol *= 3
+atol *= 4.5
torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)


def test_flash_masked_block(dtype=torch.float16):
torch.manual_seed(20)
Z, H, N_CTX, D_HEAD = (6, 8, 256, 16)
q = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)
k = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)
v = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)

sm_scale = 1 / (D_HEAD**0.5)

temp_mask = torch.ones((Z, H, N_CTX, N_CTX)).tril_(-1).bool()
ref_mask = torch.zeros_like(temp_mask, dtype=torch.float32)
ref_mask.masked_fill_(temp_mask, float("-inf"))
ref_mask = ref_mask.to(q.device).to(q.dtype)
dout = torch.randn_like(q)
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False):
ref_out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, scale=sm_scale, is_causal=False, attn_mask=ref_mask
)

ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None

tri_out, mask = attention(q, k, v, False, sm_scale, BiasMode.inverse_causal, True) # type: ignore

tri_out.half()
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# Check output, attn_bias, and gradient equivalence
atol = 2e-2 * 6
torch.testing.assert_close(ref_out, tri_out, atol=5.8e-2, rtol=0)
torch.testing.assert_close(ref_mask, mask.half(), atol=4e-2, rtol=0)
torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)

if __name__ == "__main__":
pytest.main([__file__])
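For reference, the inverse-causal mask built in test_flash_masked_block keeps the diagonal and everything to its right, so each query only attends to keys at its own position or later; for later query rows, whole key tiles on the left are fully masked. A small sketch of the same construction, with the sequence length shrunk to 4 for readability:

import torch

N = 4
temp_mask = torch.ones((N, N)).tril_(-1).bool()   # True strictly below the diagonal
ref_mask = torch.zeros((N, N)).masked_fill_(temp_mask, float("-inf"))
print(ref_mask)
# tensor([[0., 0., 0., 0.],
#         [-inf, 0., 0., 0.],
#         [-inf, -inf, 0., 0.],
#         [-inf, -inf, -inf, 0.]])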
3 changes: 2 additions & 1 deletion transformer_nuggets/flash/__init__.py
@@ -1 +1,2 @@
-from transformer_nuggets.flash.flash_attention import * # noqa: F403
+from transformer_nuggets.flash.flash_attention import *
+from transformer_nuggets.flash.masks import *
111 changes: 35 additions & 76 deletions transformer_nuggets/flash/flash_attention.py
@@ -13,75 +13,31 @@
import triton
import triton.language as tl

import torch
import enum
from transformer_nuggets.flash.masks import (
alibi_attention_triton, rel_attention_triton, inverse_causal_mask_triton
)

class BiasMode(enum.Enum):
none = 0
rel_pos = 1
alibi = 2


def build_causal_mask(seq_len_q, seq_len_kv):
temp_mask = torch.ones((seq_len_q, seq_len_kv)).tril_().bool()
mask = torch.zeros_like(temp_mask, dtype=torch.float32)
mask.masked_fill_(temp_mask.logical_not(), float("-inf"))
return mask


def build_rel_mask(
n_queries: int,
n_keys: int,
n_heads: int,
mode: BiasMode,
causal=True,
):
"""Builds torch equivalent mask
Args:
n_queries: Number of queries.
n_keys: Number of keys.
n_heads: Number of attention heads.
mode: Bias mode for the attention mask.
causal: Whether to include causal mask. Defaults to True.
Returns:
torch.Tensor: The alibi attention mask.
"""
if mode == BiasMode.alibi:
assert n_heads % 8 == 0
m_0 = 2.0 ** (-8.0 / n_heads)
slopes = torch.pow(m_0, torch.arange(1, 1 + n_heads))[:, None, None]
base = -1 * (torch.arange(n_queries)[:, None] - torch.arange(n_keys)[None, :])
mask = base
mask = mask * slopes if mode == BiasMode.alibi else mask
mask = mask.expand(n_heads, n_queries, n_keys)
if causal:
causal_mask = build_causal_mask(n_queries, n_keys)
causal_mask = causal_mask.expand(n_heads, n_queries, n_keys)
full_mask = mask + causal_mask
else:
full_mask = mask
return full_mask


@triton.jit
def rel_attention_triton(cur, m, n, head_num, num_heads):
bias = n - m
cur = cur + bias
return cur


@triton.jit
def alibi_attention_triton(cur, m, n, head_num, num_heads):
# 0 Indexing
alibi_scale = tl.math.exp2(-((head_num + 1) * 8.0 / num_heads))
bias = n - m
cur = cur + (alibi_scale * bias)
return cur

inverse_causal = 3

@triton.jit
def max_fn(x, y):
return tl.math.max(x, y)

@triton.jit
def masked_row(rows):
""" rows is BLOCK_M slice of the QK score
Returns:
BLOCK_M vector of boolean values indicating whether this
Query x Key position is fully masked
"""
return rows == float("-inf")

@triton.jit
def _fwd_kernel(
@@ -182,13 +138,12 @@ def _fwd_kernel(
qk += tl.dot(q, k)
# ~~~~~~~~~~~~~~~~~~~ This is all mask stuff ~~~~~~~~~~~~~~~~~~~
if BIAS_CHOICE == 1:
-qk = rel_attention_triton(
-qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz % H, H
-)
+qk = rel_attention_triton(qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz % H, H)
elif BIAS_CHOICE == 2:
-qk = alibi_attention_triton(
-qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz % H, H
-)
+qk = alibi_attention_triton(qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz % H, H)
+elif BIAS_CHOICE == 3:
+# This should only be used for debugging
+qk = inverse_causal_mask_triton(qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz % H, H)
if DEBUG_MASK and BIAS_CHOICE != BiasMode.none:
mask = qk - tl.dot(q, k)
if IS_CAUSAL:
@@ -200,12 +155,16 @@
if IS_CAUSAL:
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# -- compute scaling constant ---
-m_i_new = tl.maximum(m_i, tl.max(qk, 1))
+row_max = tl.max(qk, 1)
+masked_out_rows = masked_row(row_max)
+m_i_new = tl.maximum(m_i, row_max)
# TODO FIX ME
# alpha = tl.math.exp2(m_i - m_i_new)
# p = tl.math.exp2(qk - m_i_new[:, None])
alpha = tl.math.exp(m_i - m_i_new)
+alpha = tl.where(masked_out_rows, 0, alpha)
p = tl.math.exp(qk - m_i_new[:, None])
+p = tl.where(masked_out_rows[:, None], 0, p)
# -- scale and update acc --
acc_scale = l_i * 0 + alpha # workaround some compiler bug
acc *= acc_scale[:, None]
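To see why the masked-row handling above keeps the streaming softmax finite, here is a plain-PyTorch sketch of the same update for a single query row processed in two key tiles, the first of which is fully masked. It mirrors the kernel's alpha/p zeroing (variable names are borrowed for readability; this is not the kernel itself) and assumes, as with the inverse-causal test mask, that fully masked tiles only occur before any unmasked tile for a given row:

import torch

torch.manual_seed(0)
tile0 = torch.full((1, 4), float("-inf"))       # fully masked key tile
tile1 = torch.tensor([[0.3, -0.1, 0.7, 0.2]])   # ordinary key tile
v0, v1 = torch.randn(4, 8), torch.randn(4, 8)

m_i = torch.full((1,), float("-inf"))   # running row max
l_i = torch.zeros(1)                    # running softmax denominator
acc = torch.zeros(1, 8)                 # running (unnormalized) output

for qk, v in ((tile0, v0), (tile1, v1)):
    row_max = qk.max(dim=1).values
    masked_out_rows = row_max == float("-inf")          # masked_row(row_max)
    m_i_new = torch.maximum(m_i, row_max)
    alpha = torch.exp(m_i - m_i_new)                    # NaN when both maxima are -inf
    alpha = torch.where(masked_out_rows, torch.zeros_like(alpha), alpha)
    p = torch.exp(qk - m_i_new[:, None])                # NaN rows when fully masked
    p = torch.where(masked_out_rows[:, None], torch.zeros_like(p), p)
    acc = acc * alpha[:, None] + p @ v
    l_i = l_i * alpha + p.sum(dim=1)
    m_i = m_i_new

out = acc / l_i[:, None]
ref = torch.softmax(torch.cat([tile0, tile1], dim=1), dim=1) @ torch.cat([v0, v1], dim=0)
print(torch.allclose(out, ref, atol=1e-6))  # True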
@@ -345,18 +304,20 @@ def _bwd_kernel(
qk *= qk_scale
# ~~~~~~~~~~~~~~~~~~~ This is all mask stuff ~~~~~~~~~~~~~~~~~~~
if BIAS_CHOICE == 1:
-qk = rel_attention_triton(
-qk, offs_m_curr[:, None], (offs_n[None, :]), off_hz % H, H
-)
+qk = rel_attention_triton(qk, offs_m_curr[:, None], (offs_n[None, :]), off_hz % H, H)
elif BIAS_CHOICE == 2:
-qk = alibi_attention_triton(
-qk, offs_m_curr[:, None], (offs_n[None, :]), off_hz % H, H
-)
+qk = alibi_attention_triton(qk, offs_m_curr[:, None], (offs_n[None, :]), off_hz % H, H)
+elif BIAS_CHOICE == 3:
+# This should only be used for debugging
+qk = inverse_causal_mask_triton(qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz % H, H)
# ~~~~~~~~~~~~~~~~~~~ This is the end of mask stuff ~~~~~~~~~~~~~~~~~~~
l_i = tl.load(l_ptrs + offs_m_curr)
+row_max = tl.max(qk, 1)
+masked_out_rows = masked_row(row_max)
# TODO fix me
# p = tl.math.exp2(qk - l_i[:, None])
p = tl.math.exp(qk - l_i[:, None])
+p = tl.where(masked_out_rows[:, None], 0, p)
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
@@ -394,10 +355,8 @@ def forward(ctx, q, k, v, causal, sm_scale, bias_choice: BiasMode, debug_mask=Fa
o = torch.empty_like(q)
BLOCK_M = 128
BLOCK_N = 64
-grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
-L = torch.empty(
-(q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
-)
+grid = (triton.cdiv(seq_len_qv, BLOCK_M), batch_size * num_heads, 1)
+L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

scratch_space = None
if debug_mask:
58 changes: 58 additions & 0 deletions transformer_nuggets/flash/masks.py
@@ -0,0 +1,58 @@
import torch
import triton
import triton.language as tl


def build_causal_mask(seq_len_q, seq_len_kv):
temp_mask = torch.ones((seq_len_q, seq_len_kv)).tril_().bool()
mask = torch.zeros_like(temp_mask, dtype=torch.float32)
mask.masked_fill_(temp_mask.logical_not(), float("-inf"))
return mask


def build_alibi_mask(n_queries, n_keys, n_heads, scale=None, causal=True):
if scale is None:
assert n_heads % 8 == 0
m_0 = 2.0 ** (-8.0 / n_heads)
slopes = torch.pow(m_0, torch.arange(1, 1 + n_heads))[:, None, None]
base = -1 * (torch.arange(n_queries)[:, None] - torch.arange(n_keys)[None, :])
if scale is not None:
alibi_base = base * scale
else:
alibi_base = base * slopes
alibi_base = alibi_base.expand(n_heads, n_queries, n_keys)
if causal:
causal_mask = build_causal_mask(n_queries, n_keys)
causal_mask = causal_mask.expand(n_heads, n_queries, n_keys)
full_mask = alibi_base + causal_mask
else:
full_mask = alibi_base
return full_mask


@triton.jit
def rel_attention_triton(cur, m, n, head_num, num_heads):
bias = n - m
cur = cur + bias
return cur


@triton.jit
def alibi_attention_triton(cur, m, n, head_num, num_heads):
# 0 Indexing
alibi_scale = tl.math.exp2(-((head_num + 1) * 8.0 / num_heads))
bias = n - m
cur = cur + (alibi_scale * bias)
return cur


@triton.jit
def causal_mask_triton(cur, m, n, head_num, num_heads):
cur = tl.where(m >= n, cur, float("-inf"))
return cur


@triton.jit
def inverse_causal_mask_triton(cur, m, n, head_num, num_heads):
cur = tl.where(m > n, float("-inf"), cur)
return cur
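As a quick reference for what each of these helpers does to a score tile, a torch rendering with explicit query/key index grids (m for queries, n for keys; illustrative only, not part of the module):

import torch

m = torch.arange(4)[:, None]   # query positions
n = torch.arange(4)[None, :]   # key positions
scores = torch.zeros(4, 4)

rel = scores + (n - m)                                  # rel_attention_triton
head_num, num_heads = 0, 8
alibi_scale = 2.0 ** (-((head_num + 1) * 8.0 / num_heads))
alibi = scores + alibi_scale * (n - m)                  # alibi_attention_triton
neg_inf = torch.tensor(float("-inf"))
causal = torch.where(m >= n, scores, neg_inf)           # causal_mask_triton
inverse_causal = torch.where(m > n, neg_inf, scores)    # inverse_causal_mask_triton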
