add triton mask kernel implementations (#100)
This PR adds a Triton implementation of the mask kernels, since Triton is easier and friendlier to maintain.

This is just a proof of concept; performance has not been tuned yet and is left for future work.
yzh119 authored Nov 27, 2024
1 parent 590eace commit 6e73e1d
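For context on the semantics the new kernel implements: the bitmask packs one bit per vocabulary token into int32 words (32 tokens per word), and a cleared bit means the token is disallowed, so its logit is overwritten with -inf in place. A minimal pure-PyTorch sketch of the same operation, for reference only (the function name and shapes here are illustrative, not part of this PR):

import torch

def apply_token_bitmask_reference(logits: torch.Tensor, bitmask: torch.Tensor) -> None:
    # logits:  (batch, vocab_size) float tensor, modified in place
    # bitmask: (batch, ceil(vocab_size / 32)) int32 tensor; bit j of word i guards token 32*i + j
    batch, vocab_size = logits.shape
    # Unpack every int32 word into 32 boolean flags.
    bits = (bitmask.unsqueeze(-1) >> torch.arange(32, device=bitmask.device)) & 1
    allowed = bits.reshape(batch, -1)[:, :vocab_size].bool()
    # Cleared bit -> disallowed token -> logit becomes -inf.
    logits.masked_fill_(~allowed, float("-inf"))

The Triton kernel added below fuses the unpacking with the masked store, so the full boolean mask is never materialized in global memory.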
Showing 4 changed files with 98 additions and 14 deletions.
1 change: 1 addition & 0 deletions python/xgrammar/kernels/__init__.py
@@ -2,3 +2,4 @@

from .apply_token_bitmask_inplace_cpu import apply_token_bitmask_inplace_cpu
from .apply_token_bitmask_inplace_cuda import apply_token_bitmask_inplace_cuda
+from .apply_token_bitmask_inplace_triton import apply_token_bitmask_inplace_triton
77 changes: 77 additions & 0 deletions python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
@@ -0,0 +1,77 @@
import torch
import triton
import triton.language as tl

from typing import List, Optional, Union


@triton.jit
def apply_token_bitmask_inplace_kernel(
    logits_ptr,
    bitmask_ptr,
    indices_ptr,
    num_rows,
    vocab_size,
    bitmask_size,
    NUM_SMS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
    # Persistent kernel: each program strides over the (row, block) work items.
    for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
        block_offset = (work_id % num_blocks) * BLOCK_SIZE
        row_id = work_id // num_blocks
        batch_id = tl.load(indices_ptr + row_id)
        offsets = block_offset + tl.arange(0, BLOCK_SIZE)
        bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
        vocab_mask = offsets < vocab_size
        packed_bitmask_mask = bitmask_offsets < bitmask_size
        logits = tl.load(logits_ptr + batch_id * vocab_size + offsets, vocab_mask)
        packed_bitmask = tl.load(
            bitmask_ptr + row_id * bitmask_size + bitmask_offsets, packed_bitmask_mask
        )
        # Unpack 32 tokens per int32 word; a cleared bit marks the token as disallowed.
        bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
        bitmask = bitmask.reshape(BLOCK_SIZE)

        # Set disallowed tokens to -inf and write the block back in place.
        logits = tl.where(bitmask, -float("inf"), logits)
        tl.store(logits_ptr + batch_id * vocab_size + offsets, logits, vocab_mask)


def apply_token_bitmask_inplace_triton(
    logits: torch.Tensor,
    bitmask: torch.Tensor,
    indices: Optional[Union[List[int], torch.Tensor]] = None,
):
    def ceil_div(a, b):
        return (a + b - 1) // b

    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    BLOCK_SIZE = 4096
    # Check input tensor shapes.
    if logits.ndim == 2:
        batch_size, vocab_size = logits.shape
    elif logits.ndim == 1:
        batch_size = 1
        (vocab_size,) = logits.shape
    else:
        raise ValueError(f"Invalid logits tensor shape {logits.shape}")

    if indices is None:
        indices = torch.arange(batch_size, dtype=torch.int32, device=logits.device)
    elif isinstance(indices, list):
        indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)

    grid = lambda meta: (NUM_SMS,)

    apply_token_bitmask_inplace_kernel[grid](
        logits,
        bitmask,
        indices,
        indices.shape[0],
        vocab_size,
        ceil_div(vocab_size, 32),
        NUM_SMS,
        BLOCK_SIZE,
        num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
        num_stages=3,
    )
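A quick smoke test of the new entry point might look like the following; the tensor sizes, the all-ones bitmask initialization, and the assertions are illustrative assumptions rather than part of this diff, and it needs a CUDA device with Triton installed:

import torch
from xgrammar.kernels import apply_token_bitmask_inplace_triton

batch_size, vocab_size = 2, 128
logits = torch.randn(batch_size, vocab_size, device="cuda")

# -1 has all 32 bits set (every token allowed); then restrict the first word to tokens 0-3.
bitmask = torch.full((batch_size, (vocab_size + 31) // 32), -1, dtype=torch.int32, device="cuda")
bitmask[:, 0] = 0b1111

apply_token_bitmask_inplace_triton(logits, bitmask)
assert torch.isinf(logits[:, 4:32]).all()    # tokens 4..31 were masked to -inf
assert torch.isfinite(logits[:, 32:]).all()  # tokens 32..127 are untouched

Passing indices applies the i-th bitmask row to logits[indices[i]], matching how the kernel indexes the bitmask by row_id and the logits by batch_id.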
4 changes: 2 additions & 2 deletions python/xgrammar/matcher.py
@@ -8,7 +8,7 @@

from .base import XGRObject, _core
from .compiler import CompiledGrammar
-from .kernels import apply_token_bitmask_inplace_cpu, apply_token_bitmask_inplace_cuda
+from .kernels import apply_token_bitmask_inplace_cpu, apply_token_bitmask_inplace_triton, apply_token_bitmask_inplace_cuda

"""The dtype of the bitmask: int32."""
bitmask_dtype = torch.int32
@@ -107,7 +107,7 @@ def apply_token_bitmask_inplace(
    )

    if logits.device.type == "cuda":
-        apply_token_bitmask_inplace_cuda(logits, bitmask, indices)
+        apply_token_bitmask_inplace_triton(logits, bitmask, indices)
    elif logits.device.type == "cpu":
        apply_token_bitmask_inplace_cpu(logits, bitmask, indices)
    else:
30 changes: 18 additions & 12 deletions tests/python/test_grammar_matcher.py
@@ -6,6 +6,7 @@

import pytest
import torch
+from triton.testing import do_bench
from transformers import AutoTokenizer

import xgrammar as xgr
@@ -170,11 +171,11 @@ def test_apply_token_bitmask_inplace(is_cuda: bool):
        bitmask = torch.tensor([0b1010101010], dtype=torch.int32).to("cuda")
        xgr.apply_token_bitmask_inplace(logits_gpu, bitmask)
        torch.cuda.synchronize()
-        assert torch.all(logits_gpu == expected.to("cuda"))
+        torch.testing.assert_allclose(logits_gpu, expected.to("cuda"))
    else:
        bitmask = torch.tensor([0b1010101010], dtype=torch.int32)
        xgr.apply_token_bitmask_inplace(logits, bitmask)
-        assert torch.all(logits == expected)
+        torch.testing.assert_allclose(logits, expected)


batch_size_vocab_size_masked_cnt_stride = [
@@ -221,26 +222,29 @@ def bool_mask_to_bitmask(bool_mask: torch.Tensor) -> torch.Tensor:
        logits_gpu = logits.to("cuda")
        bitmask_gpu = bitmask.to("cuda")
        torch.cuda.synchronize()
-        time_start = time.monotonic_ns()
        if stride == 1:
            # Test logic without indices
-            xgr.apply_token_bitmask_inplace(logits_gpu, bitmask_gpu)
+            f = lambda: xgr.apply_token_bitmask_inplace(logits_gpu, bitmask_gpu)
        else:
-            xgr.apply_token_bitmask_inplace(logits_gpu, bitmask_gpu, indices=masked_batch_ids)
-        torch.cuda.synchronize()
-        time_end = time.monotonic_ns()
-        print(f"Time taken: {(time_end - time_start) / 1e3} us")
-        assert torch.all(logits_gpu == logits_expected.to("cuda"))
+            f = lambda: xgr.apply_token_bitmask_inplace(
+                logits_gpu, bitmask_gpu, indices=masked_batch_ids
+            )
+        f()
+        torch.testing.assert_allclose(logits_gpu, logits_expected.to("cuda"))
+
+        dur = do_bench(f, warmup=100, rep=1000)
+        print(f"Time taken: {(dur) * 1e3} us")
    else:
        time_start = time.monotonic_ns()
        if stride == 1:
            # Test logic without indices
            xgr.apply_token_bitmask_inplace(logits, bitmask)

        else:
            xgr.apply_token_bitmask_inplace(logits, bitmask, indices=masked_batch_ids)
        time_end = time.monotonic_ns()
        print(f"Time taken: {(time_end - time_start) / 1e3} us")
-        assert torch.all(logits == logits_expected)
+        torch.testing.assert_allclose(logits, logits_expected)


def test_rollback():
@@ -282,7 +286,8 @@ def test_rollback():
        matcher.fill_next_token_bitmask(new_token_bitmask2)
        result_after_rollback.append(new_token_bitmask2)
        assert matcher.accept_token(i_2)
-    assert all(torch.all(l == r) for l, r in zip(orig_result, result_after_rollback))
+    for l, r in zip(orig_result, result_after_rollback):
+        torch.testing.assert_allclose(l, r)


def test_reset():
@@ -315,7 +320,8 @@ def test_reset():
        result_after_reset.append(token_bitmask)
        assert matcher.accept_token(i)

-    assert all(torch.all(l == r) for l, r in zip(orig_result, result_after_reset))
+    for l, r in zip(orig_result, result_after_reset):
+        torch.testing.assert_allclose(l, r)


def test_termination():
