Commit 187ab7b: "less broken"
drisspg committed Mar 28, 2024 (1 parent: b46956a)
Showing 7 changed files with 102 additions and 50 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/python-app.yml
@@ -27,9 +27,9 @@ jobs:
python -m pip install --upgrade pip
pip install -e .
pip install -e .'[dev]'
- name: Lint with flake8
- name: Lint with ruff
run: |
flake8
ruff check .
- name: Test with pytest
run: |
pytest
9 changes: 5 additions & 4 deletions .pre-commit-config.yaml
@@ -19,8 +19,9 @@ repos:
- ufmt == 2.1.0
- libcst == 1.0.1

- repo: https://github.com/pycqa/flake8
rev: 7.0.0
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.0
hooks:
- id: flake8
additional_dependencies: [flake8-pyproject]
# Run the linter.
- id: ruff
43 changes: 38 additions & 5 deletions pyproject.toml
@@ -52,16 +52,49 @@ llama = [
]

# ---------- TOOL CONFIGURATIONS ------------
[tool.flake8]
max-line-length = 99
ignore = ['E231', 'E241', 'E501', 'C408', 'E261', 'E731', 'G004', 'W503', 'E203']
per-file-ignores = [
'__init__.py:F401',

# ---------- RUFF ------------
[tool.ruff]
ignore = ['E231', 'E731']
# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".ipynb_checkpoints",
".mypy_cache",
".nox",
".pants.d",
".pyenv",
".pytest_cache",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
".vscode",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"site-packages",
"venv",
]

[tool.ruff.per-file-ignores]
"__init__.py" = ["F401", "F403"]

# ---------- UFMT ------------

[tool.usort]
first_party_detection = false

# ---------- Black ------------
[tool.black]
target-version = ["py38"]
line-length = 99
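
A note on the new [tool.ruff.per-file-ignores] table above: F401 flags imports that are never used and F403 flags star imports, both of which are routine in a package __init__.py that only re-exports its public API. A minimal sketch of the pattern the ignore keeps ruff from flagging (module and symbol names here are illustrative, not taken from the repo):

# __init__.py of a hypothetical package: re-exports only, nothing used directly.
from .flash_attention import attention  # would trigger F401 without the ignore
from .masks import *                    # would trigger F403 without the ignore

__all__ = ["attention"]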
64 changes: 42 additions & 22 deletions test/test_flash.py
@@ -2,7 +2,7 @@
import torch

from torch.nn.attention import sdpa_kernel, SDPBackend
from transformer_nuggets.flash import attention, BiasMode, build_rel_mask
from transformer_nuggets.flash import attention, BiasMode, build_causal_mask, build_rel_mask


def clone_grad_and_reset(tensor):
@@ -15,13 +15,31 @@ def clone_grad_and_reset_all(*tensors):
return (clone_grad_and_reset(tensor) for tensor in tensors)


@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [(6, 8, 256, 16)])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("bias_choice", [BiasMode.rel_pos, BiasMode.none, BiasMode.alibi])
def maybe_grab_upper_section(mask, N_CTX, causal):
BLOCK_M = 128
if N_CTX > BLOCK_M and causal:
# Since the kernel will not iterate over all seq_len_kv when causal
# We will only check the minimum rectangular block
mask = mask[:, :, :, :BLOCK_M]
return mask


def check_bias(bias_choice, causal, attn_bias, mask, N_CTX):
if bias_choice != BiasMode.none:
mask = maybe_grab_upper_section(mask, N_CTX, causal)
attn_bias = maybe_grab_upper_section(attn_bias, N_CTX, causal)
torch.testing.assert_close(attn_bias, mask.to(attn_bias.dtype), atol=4e-2, rtol=0)


@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [(1, 8, 128, 16)])
@pytest.mark.parametrize("is_causal", [False])
@pytest.mark.parametrize(
"bias_choice", [BiasMode.rel_pos, BiasMode.none, BiasMode.alibi, BiasMode.causal]
)
@pytest.mark.parametrize("sm_scale", [None, 1])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
def test_flash_specific_masks(
Z, H, N_CTX, D_HEAD, causal, bias_choice, sm_scale, dtype=torch.float16
Z, H, N_CTX, D_HEAD, is_causal, bias_choice, sm_scale, dtype=torch.float16
):
torch.manual_seed(20)
q = (
@@ -46,12 +64,18 @@ def test_flash_specific_masks(

# reference implementation
is_causal = False
if bias_choice in {BiasMode.none, BiasMode.causal}:
attn_bias = None
is_causal = causal
attn_bias = None
if bias_choice in {BiasMode.causal}:
attn_bias = (
build_causal_mask(N_CTX, N_CTX)
.to(device=q.device, dtype=q.dtype)
.expand(Z, H, N_CTX, N_CTX)
)
elif bias_choice in {BiasMode.rel_pos, BiasMode.alibi}:
attn_bias = build_rel_mask(N_CTX, N_CTX, H, bias_choice, causal=causal)
attn_bias = build_rel_mask(N_CTX, N_CTX, H, bias_choice, causal=is_causal)
attn_bias = attn_bias.expand(Z, H, N_CTX, N_CTX).to(q.device).to(q.dtype)
elif bias_choice == BiasMode.none:
pass
else:
raise ValueError(f"Invalid bias_choice: {bias_choice}")

@@ -62,35 +86,29 @@ def test_flash_specific_masks(
ref_out.backward(dout)
ref_dq, ref_dk, ref_dv = clone_grad_and_reset_all(q, k, v)
# triton implementation
tri_out, mask = attention(q, k, v, causal, sm_scale, bias_choice, True)
tri_out, mask = attention(q, k, v, is_causal, sm_scale, bias_choice, True)
tri_out.half()
tri_out.backward(dout)
tri_dq, tri_dk, tri_dv = clone_grad_and_reset_all(q, k, v)
# Check attn_bias equivalence
if bias_choice != BiasMode.none:
BLOCK_M = 128
mask = mask.half()
if N_CTX > BLOCK_M and causal:
# Since the kernel will not iterate over all seq_len_kv when causal
# We will only check the minimum rectangular block
attn_bias = attn_bias[:, :, :, :BLOCK_M]
mask = mask[:, :, :, :BLOCK_M]
torch.testing.assert_close(attn_bias, mask, atol=4e-2, rtol=0)

# compare
check_bias(bias_choice, is_causal, attn_bias, mask, N_CTX)

torch.testing.assert_close(ref_out, tri_out, atol=5.8e-2, rtol=0)
if bias_choice != BiasMode.none:
fudge_factor = 6.1
else:
fudge_factor = 1
atol = 2e-2 * fudge_factor
if bias_choice == BiasMode.rel_pos and not causal:
if bias_choice == BiasMode.rel_pos and not is_causal:
atol *= 4.5
torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)


@pytest.mark.xfail(reason="This test is failing due to a bug in the implementation")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
def test_flash_masked_block(dtype=torch.float16):
torch.manual_seed(20)
Z, H, N_CTX, D_HEAD = (6, 8, 256, 16)
@@ -132,8 +150,10 @@ def test_flash_masked_block(dtype=torch.float16):
tri_dq, tri_dk, tri_dv = clone_grad_and_reset_all(q, k, v)
# Check attn_bias equivalence
atol = 2e-2 * 6
# compare
check_bias(BiasMode.inverse_causal, False, ref_mask, mask, N_CTX)

torch.testing.assert_close(ref_out, tri_out, atol=5.8e-2, rtol=0)
torch.testing.assert_close(ref_mask, mask.half(), atol=4e-2, rtol=0)

torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
Expand Down
14 changes: 6 additions & 8 deletions transformer_nuggets/flash/flash_attention.py
@@ -120,14 +120,13 @@ def _fwd_kernel(
lo = 0
hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
for start_n in range(lo, hi, BLOCK_N):
# -- load k, v --
# -- load k --
k = tl.load(K_block_ptr)
v = tl.load(V_block_ptr)
# -- compute qk ---
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
# ~~~~~~~~~~~~~~~~~~~ Do Score Modification ~~~~~~~~~~~~~~~~~~~
score_modification(
qk = score_modification(
qk,
offs_m,
start_n,
@@ -146,8 +145,8 @@ def _fwd_kernel(
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# -- compute scaling constant ---
row_max = tl.max(qk, 1)
masked_out_rows = masked_row(row_max)
m_i_new = tl.maximum(m_i, row_max)
masked_out_rows = masked_row(m_i_new)
# TODO FIX ME
# alpha = tl.math.exp2(m_i - m_i_new)
# p = tl.math.exp2(qk - m_i_new[:, None])
@@ -156,6 +155,7 @@ def _fwd_kernel(
p = tl.math.exp(qk - m_i_new[:, None])
p = tl.where(masked_out_rows[:, None], 0, p)
# -- scale and update acc --
v = tl.load(V_block_ptr)
acc_scale = l_i * 0 + alpha # workaround some compiler bug
acc *= acc_scale[:, None]
acc += tl.dot(p.to(tl.float16), v)
@@ -293,7 +293,7 @@ def _bwd_kernel(
qk += tl.dot(q, tl.trans(k))
qk *= qk_scale
# ~~~~~~~~~~~~~~~~~~~ Do Score Modification ~~~~~~~~~~~~~~~~~~~
score_modification(
qk = score_modification(
qk,
offs_m,
start_n,
@@ -310,12 +310,9 @@ def _bwd_kernel(
)

l_i = tl.load(l_ptrs + offs_m_curr)
row_max = tl.max(qk, 1)
masked_out_rows = masked_row(row_max)
# TODO fix me
# p = tl.math.exp2(qk - l_i[:, None])
p = tl.math.exp(qk - l_i[:, None])
p = tl.where(masked_out_rows[:, None], 0, p)
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
@@ -375,6 +372,7 @@ def forward(
dtype=torch.float32,
)

print(f"Using a bias choice of {bias_choice} with value {bias_choice.value}")
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel[grid](
q,
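
For readers following the _fwd_kernel edits above (score_modification now returns the modified scores, masked_row is evaluated on the updated running max m_i_new, and probabilities for fully masked rows are zeroed): a rough PyTorch sketch of one block of that streaming-softmax update, with the tile shapes and exp2 details of the Triton kernel simplified, is:

import torch

def online_softmax_step(acc, l_i, m_i, qk, v):
    # acc: (M, D) running output, l_i: (M,) running denominator,
    # m_i: (M,) running row max, qk: (M, N) scores for this key block
    # (already passed through score_modification), v: (N, D) value block.
    m_i_new = torch.maximum(m_i, qk.max(dim=1).values)
    # Rows whose running max is still -inf have seen only masked-out keys.
    masked_out_rows = torch.isneginf(m_i_new)
    # Guard alpha for those rows to avoid exp(-inf - -inf) = NaN; their
    # accumulators are zero anyway, so leaving them unchanged is safe.
    alpha = torch.where(masked_out_rows, torch.ones_like(m_i), torch.exp(m_i - m_i_new))
    p = torch.exp(qk - m_i_new[:, None])
    p = torch.where(masked_out_rows[:, None], torch.zeros_like(p), p)
    acc = acc * alpha[:, None] + p @ v
    l_i = l_i * alpha + p.sum(dim=1)
    return acc, l_i, m_i_new

With m_i initialized to -inf and acc, l_i to zero, rows that never see an unmasked key come out as all zeros instead of NaN, which appears to be the behavior the masked_row change targets.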
16 changes: 8 additions & 8 deletions transformer_nuggets/flash/masks.py
@@ -26,7 +26,7 @@ def build_rel_mask(
n_keys: int,
n_heads: int,
mode: BiasMode,
causal=True,
causal: bool,
):
"""Builds torch equivalent mask
Args:
@@ -104,19 +104,19 @@ def score_modification(
head = off_hz % num_heads
seq_len_q = offs_m[:, None]
seq_len_kv = start_n + offs_n[None, :]
if BIAS_CHOICE == 1:
if BIAS_CHOICE == BiasMode.rel_pos.value:
score = rel_attention_triton(score, batch, head, seq_len_q, seq_len_kv)
elif BIAS_CHOICE == 2:
elif BIAS_CHOICE == BiasMode.alibi.value:
score = alibi_attention_triton(score, batch, head, seq_len_q, seq_len_kv, num_heads)
elif BIAS_CHOICE == 3:
elif BIAS_CHOICE == BiasMode.inverse_causal.value:
score = inverse_causal_mask_triton(score, batch, head, seq_len_q, seq_len_kv)
elif BIAS_CHOICE == 4:
elif BIAS_CHOICE == BiasMode.causal.value:
# CAUSAL MASK
score = causal_mask_triton(score, batch, head, seq_len_q, seq_len_kv)
if DEBUG_MASK and BIAS_CHOICE != BiasMode.none:
mask = score - tl.dot(q, k)
if IS_CAUSAL:
mask = tl.where(seq_len_q >= seq_len_kv, mask, float("-inf"))
mask = score - tl.dot(q.to(MATMUL_PRECISION), k.to(MATMUL_PRECISION))
# if IS_CAUSAL:
# mask = tl.where(seq_len_q >= seq_len_kv, mask, float("-inf"))
tl.store(mask_block_ptr, mask)

return score
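
The dispatch in score_modification now compares BIAS_CHOICE against BiasMode.*.value instead of bare integers. The actual enum lives in transformer_nuggets.flash and may differ; the sketch below only reflects the integer branches replaced in this hunk (with none assumed to be 0):

import enum

class BiasMode(enum.Enum):
    # Values inferred from the integer comparisons this commit replaces.
    none = 0
    rel_pos = 1
    alibi = 2
    inverse_causal = 3
    causal = 4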
2 changes: 1 addition & 1 deletion transformer_nuggets/llama/tokenizer.py
@@ -46,7 +46,7 @@ def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
Returns:
List[int]: A list of token IDs.
"""
assert type(s) is str
assert isinstance(s, str)
t = self.sp_model.encode(s)
if bos:
t = [self.bos_id] + t
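
The tokenizer change swaps an exact type check for isinstance, which also accepts str subclasses. A tiny illustration (the Prompt subclass is hypothetical):

class Prompt(str):
    """Hypothetical str subclass, e.g. a tagged prompt."""

s = Prompt("hello")
print(isinstance(s, str))  # True: subclasses pass, which is what encode() wants
print(type(s) is str)      # False: the old exact check rejected subclasses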
