From 740e924f0eb73069f058728db0890019411e2760 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 9 Aug 2023 15:20:02 -0700 Subject: [PATCH] Support full masked out blocks --- test/test_flash.py | 72 +++++++++++- transformer_nuggets/flash/__init__.py | 3 +- transformer_nuggets/flash/flash_attention.py | 111 ++++++------------- transformer_nuggets/flash/masks.py | 58 ++++++++++ 4 files changed, 161 insertions(+), 83 deletions(-) create mode 100644 transformer_nuggets/flash/masks.py diff --git a/test/test_flash.py b/test/test_flash.py index c90c024..7ebc746 100644 --- a/test/test_flash.py +++ b/test/test_flash.py @@ -5,12 +5,12 @@ from transformer_nuggets.flash import attention, BiasMode, build_rel_mask -@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [(6, 8, 128, 16)]) +@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [(6, 8, 256, 16)]) @pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("bias_choice", [BiasMode.rel_pos, BiasMode.none, BiasMode.alibi]) @pytest.mark.parametrize("sm_scale", [None, 1]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") -def test_op(Z, H, N_CTX, D_HEAD, causal, bias_choice, sm_scale, dtype=torch.float16): +def test_flash_all(Z, H, N_CTX, D_HEAD, causal, bias_choice, sm_scale, dtype=torch.float16): torch.manual_seed(20) q = ( torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") @@ -57,21 +57,81 @@ def test_op(Z, H, N_CTX, D_HEAD, causal, bias_choice, sm_scale, dtype=torch.floa tri_dq, q.grad = q.grad.clone(), None # Check attn_bias equivalence if bias_choice != BiasMode.none: - torch.testing.assert_close(attn_bias, mask.half(), atol=4e-2, rtol=0) + BLOCK_M = 128 + BLOCK_N = 64 + mask = mask.half() + if N_CTX > BLOCK_M and causal: + # Since the kernel will not iterate over all seq_len_kv when causal + # We will only check the minimum rectangular block + attn_bias = attn_bias[:, :, :, :BLOCK_M] + mask = mask[:, :, :, :BLOCK_M] + torch.testing.assert_close(attn_bias, mask, atol=4e-2, rtol=0) # compare - torch.testing.assert_close(ref_out, tri_out, atol=5.5e-2, rtol=0) + torch.testing.assert_close(ref_out, tri_out, atol=5.8e-2, rtol=0) if bias_choice != BiasMode.none: fudge_factor = 6.1 else: fudge_factor = 1 - atol = 1e-2 * fudge_factor + atol = 2e-2 * fudge_factor if bias_choice == BiasMode.rel_pos and not causal: - atol *= 3 + atol *= 4.5 torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0) torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0) torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0) +def test_flash_masked_block(dtype=torch.float16): + torch.manual_seed(20) + Z, H, N_CTX, D_HEAD = (6, 8, 256, 16) + q = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + k = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + v = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + + sm_scale = 1 / (D_HEAD**0.5) + + temp_mask = torch.ones((Z, H, N_CTX, N_CTX)).tril_(-1).bool() + ref_mask = torch.zeros_like(temp_mask, dtype=torch.float32) + ref_mask.masked_fill_(temp_mask, float("-inf")) + ref_mask = ref_mask.to(q.device).to(q.dtype) + dout = torch.randn_like(q) + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False): + ref_out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, scale=sm_scale, is_causal=False, attn_mask=ref_mask + ) + 
+    ref_out.backward(dout)
+    ref_dv, v.grad = v.grad.clone(), None
+    ref_dk, k.grad = k.grad.clone(), None
+    ref_dq, q.grad = q.grad.clone(), None
+
+    tri_out, mask = attention(q, k, v, False, sm_scale, BiasMode.inverse_causal, True)  # type: ignore
+
+    tri_out.half()
+    tri_out.backward(dout)
+    tri_dv, v.grad = v.grad.clone(), None
+    tri_dk, k.grad = k.grad.clone(), None
+    tri_dq, q.grad = q.grad.clone(), None
+    # Check output, attn_bias, and gradient equivalence
+    atol = 2e-2 * 6
+    torch.testing.assert_close(ref_out, tri_out, atol=5.8e-2, rtol=0)
+    torch.testing.assert_close(ref_mask, mask.half(), atol=4e-2, rtol=0)
+
+    torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
+    torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
+    torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/transformer_nuggets/flash/__init__.py b/transformer_nuggets/flash/__init__.py
index 7714bb9..f41b1f5 100644
--- a/transformer_nuggets/flash/__init__.py
+++ b/transformer_nuggets/flash/__init__.py
@@ -1 +1,2 @@
-from transformer_nuggets.flash.flash_attention import *  # noqa: F403
+from transformer_nuggets.flash.flash_attention import *  # noqa: F403
+from transformer_nuggets.flash.masks import *  # noqa: F403
diff --git a/transformer_nuggets/flash/flash_attention.py b/transformer_nuggets/flash/flash_attention.py
index a7ca3ff..fef9d63 100644
--- a/transformer_nuggets/flash/flash_attention.py
+++ b/transformer_nuggets/flash/flash_attention.py
@@ -13,75 +13,31 @@
 import triton
 import triton.language as tl

+import torch
+import enum
+from transformer_nuggets.flash.masks import (
+    alibi_attention_triton, rel_attention_triton, inverse_causal_mask_triton
+)

 class BiasMode(enum.Enum):
     none = 0
     rel_pos = 1
     alibi = 2
-
-
-def build_causal_mask(seq_len_q, seq_len_kv):
-    temp_mask = torch.ones((seq_len_q, seq_len_kv)).tril_().bool()
-    mask = torch.zeros_like(temp_mask, dtype=torch.float32)
-    mask.masked_fill_(temp_mask.logical_not(), float("-inf"))
-    return mask
-
-
-def build_rel_mask(
-    n_queries: int,
-    n_keys: int,
-    n_heads: int,
-    mode: BiasMode,
-    causal=True,
-):
-    """Builds torch equivalent mask
-    Args:
-        n_queries: Number of queries.
-        n_keys: Number of keys.
-        n_heads: Number of attention heads.
-        mode: Bias mode for the attention mask.
-        causal: Whether to include causal mask. Defaults to True.
-
-    Returns:
-        torch.Tensor: The alibi attention mask.
-    """
-    if mode == BiasMode.alibi:
-        assert n_heads % 8 == 0
-        m_0 = 2.0 ** (-8.0 / n_heads)
-        slopes = torch.pow(m_0, torch.arange(1, 1 + n_heads))[:, None, None]
-    base = -1 * (torch.arange(n_queries)[:, None] - torch.arange(n_keys)[None, :])
-    mask = base
-    mask = mask * slopes if mode == BiasMode.alibi else mask
-    mask = mask.expand(n_heads, n_queries, n_keys)
-    if causal:
-        causal_mask = build_causal_mask(n_queries, n_keys)
-        causal_mask = causal_mask.expand(n_heads, n_queries, n_keys)
-        full_mask = mask + causal_mask
-    else:
-        full_mask = mask
-    return full_mask
-
-
-@triton.jit
-def rel_attention_triton(cur, m, n, head_num, num_heads):
-    bias = n - m
-    cur = cur + bias
-    return cur
-
-
-@triton.jit
-def alibi_attention_triton(cur, m, n, head_num, num_heads):
-    # 0 Indexing
-    alibi_scale = tl.math.exp2(-((head_num + 1) * 8.0 / num_heads))
-    bias = n - m
-    cur = cur + (alibi_scale * bias)
-    return cur
-
+    inverse_causal = 3

 @triton.jit
 def max_fn(x, y):
     return tl.math.max(x, y)

+@triton.jit
+def masked_row(rows):
+    """rows is a BLOCK_M slice of the QK scores.
+
+    Returns:
+        A BLOCK_M vector of booleans indicating whether each
+        query row is fully masked out (all scores are -inf).
+    """
+    return rows == float("-inf")

 @triton.jit
 def _fwd_kernel(
@@ -182,13 +138,12 @@ def _fwd_kernel(
         qk += tl.dot(q, k)
         # ~~~~~~~~~~~~~~~~~~~ This is all mask stuff ~~~~~~~~~~~~~~~~~~~
         if BIAS_CHOICE == 1:
-            qk = rel_attention_triton(
-                qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz % H, H
-            )
+            qk = rel_attention_triton(qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz % H, H)
         elif BIAS_CHOICE == 2:
-            qk = alibi_attention_triton(
-                qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz % H, H
-            )
+            qk = alibi_attention_triton(qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz % H, H)
+        elif BIAS_CHOICE == 3:
+            # This should only be used for debugging
+            qk = inverse_causal_mask_triton(qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz % H, H)
         if DEBUG_MASK and BIAS_CHOICE != BiasMode.none:
             mask = qk - tl.dot(q, k)
             if IS_CAUSAL:
@@ -200,12 +155,16 @@
         if IS_CAUSAL:
             qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
         # -- compute scaling constant ---
-        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
+        row_max = tl.max(qk, 1)
+        masked_out_rows = masked_row(row_max)
+        m_i_new = tl.maximum(m_i, row_max)
         # TODO FIX ME
         # alpha = tl.math.exp2(m_i - m_i_new)
         # p = tl.math.exp2(qk - m_i_new[:, None])
         alpha = tl.math.exp(m_i - m_i_new)
+        alpha = tl.where(masked_out_rows, 0, alpha)
         p = tl.math.exp(qk - m_i_new[:, None])
+        p = tl.where(masked_out_rows[:, None], 0, p)
         # -- scale and update acc --
         acc_scale = l_i * 0 + alpha  # workaround some compiler bug
         acc *= acc_scale[:, None]
@@ -345,18 +304,20 @@ def _bwd_kernel(
             qk *= qk_scale
             # ~~~~~~~~~~~~~~~~~~~ This is all mask stuff ~~~~~~~~~~~~~~~~~~~
             if BIAS_CHOICE == 1:
-                qk = rel_attention_triton(
-                    qk, offs_m_curr[:, None], (offs_n[None, :]), off_hz % H, H
-                )
+                qk = rel_attention_triton(qk, offs_m_curr[:, None], (offs_n[None, :]), off_hz % H, H)
             elif BIAS_CHOICE == 2:
-                qk = alibi_attention_triton(
-                    qk, offs_m_curr[:, None], (offs_n[None, :]), off_hz % H, H
-                )
+                qk = alibi_attention_triton(qk, offs_m_curr[:, None], (offs_n[None, :]), off_hz % H, H)
+            elif BIAS_CHOICE == 3:
+                # This should only be used for debugging
+                qk = inverse_causal_mask_triton(qk, offs_m_curr[:, None], (offs_n[None, :]), off_hz % H, H)
             # ~~~~~~~~~~~~~~~~~~~ This is the end of mask stuff ~~~~~~~~~~~~~~~~~~~
             l_i = tl.load(l_ptrs + offs_m_curr)
+            row_max = tl.max(qk, 1)
+            masked_out_rows = masked_row(row_max)
             # TODO fix me
             # p = tl.math.exp2(qk - l_i[:, None])
             p = tl.math.exp(qk - l_i[:, None])
+            p = tl.where(masked_out_rows[:, None], 0, p)
             # compute dv
             do = tl.load(do_ptrs)
             dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
@@ -394,10 +355,8 @@ def forward(ctx, q, k, v, causal, sm_scale, bias_choice: BiasMode, debug_mask=Fa
         o = torch.empty_like(q)
         BLOCK_M = 128
         BLOCK_N = 64
-        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
-        L = torch.empty(
-            (q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
-        )
+        grid = (triton.cdiv(seq_len_qv, BLOCK_M), batch_size * num_heads, 1)
+        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

         scratch_space = None
         if debug_mask:
diff --git a/transformer_nuggets/flash/masks.py b/transformer_nuggets/flash/masks.py
new file mode 100644
index 0000000..3f79dcf
--- /dev/null
+++ b/transformer_nuggets/flash/masks.py
@@ -0,0 +1,58 @@
+import torch
+import triton
+import triton.language as tl
+
+
+def build_causal_mask(seq_len_q, seq_len_kv):
+    temp_mask = torch.ones((seq_len_q, seq_len_kv)).tril_().bool()
+    mask = torch.zeros_like(temp_mask, dtype=torch.float32)
+    mask.masked_fill_(temp_mask.logical_not(), float("-inf"))
+    return mask
+
+
+def build_alibi_mask(n_queries, n_keys, n_heads, scale=None, causal=True):
+    if scale is None:
+        assert n_heads % 8 == 0
+    m_0 = 2.0 ** (-8.0 / n_heads)
+    slopes = torch.pow(m_0, torch.arange(1, 1 + n_heads))[:, None, None]
+    base = -1 * (torch.arange(n_queries)[:, None] - torch.arange(n_keys)[None, :])
+    if scale is not None:
+        alibi_base = base * scale
+    else:
+        alibi_base = base * slopes
+    alibi_base = alibi_base.expand(n_heads, n_queries, n_keys)
+    if causal:
+        causal_mask = build_causal_mask(n_queries, n_keys)
+        causal_mask = causal_mask.expand(n_heads, n_queries, n_keys)
+        full_mask = alibi_base + causal_mask
+    else:
+        full_mask = alibi_base
+    return full_mask
+
+
+@triton.jit
+def rel_attention_triton(cur, m, n, head_num, num_heads):
+    bias = n - m
+    cur = cur + bias
+    return cur
+
+
+@triton.jit
+def alibi_attention_triton(cur, m, n, head_num, num_heads):
+    # 0 Indexing
+    alibi_scale = tl.math.exp2(-((head_num + 1) * 8.0 / num_heads))
+    bias = n - m
+    cur = cur + (alibi_scale * bias)
+    return cur
+
+
+@triton.jit
+def causal_mask_triton(cur, m, n, head_num, num_heads):
+    cur = tl.where(m >= n, cur, float("-inf"))
+    return cur
+
+
+@triton.jit
+def inverse_causal_mask_triton(cur, m, n, head_num, num_heads):
+    cur = tl.where(m > n, float("-inf"), cur)
+    return cur
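
A minimal PyTorch sketch of the row-masking logic this patch adds (illustrative values only, not part of the patch): when every QK score in a query row is -inf, a plain softmax turns that row into NaNs, which is why the kernels zero alpha and p for rows flagged by masked_row instead of exponentiating them.

import torch

# Two query rows of QK scores; the second row is fully masked out.
qk = torch.tensor([[0.5, 1.0, -0.5],
                   [float("-inf"), float("-inf"), float("-inf")]])

naive = torch.softmax(qk, dim=-1)  # second row becomes all NaN

# Mirrors masked_row() plus the tl.where guards in the kernel.
row_max = qk.max(dim=-1).values
masked_out_rows = row_max == float("-inf")
p = torch.exp(qk - row_max[:, None])
p = torch.where(masked_out_rows[:, None], torch.zeros_like(p), p)
# The fully masked row now contributes zeros to the accumulated output
# instead of propagating NaNs.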