diff --git a/python/xgrammar/kernels/__init__.py b/python/xgrammar/kernels/__init__.py
index 6954b3f..4bec490 100644
--- a/python/xgrammar/kernels/__init__.py
+++ b/python/xgrammar/kernels/__init__.py
@@ -2,3 +2,4 @@
 from .apply_token_bitmask_inplace_cpu import apply_token_bitmask_inplace_cpu
 from .apply_token_bitmask_inplace_cuda import apply_token_bitmask_inplace_cuda
+from .apply_token_bitmask_inplace_triton import apply_token_bitmask_inplace_triton
diff --git a/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py b/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
new file mode 100644
index 0000000..85758c1
--- /dev/null
+++ b/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
@@ -0,0 +1,77 @@
+import torch
+import triton
+import triton.language as tl
+
+from typing import List, Optional, Union
+
+
+@triton.jit
+def apply_token_bitmask_inplace_kernel(
+    logits_ptr,
+    bitmask_ptr,
+    indices_ptr,
+    num_rows,
+    vocab_size,
+    bitmask_size,
+    NUM_SMS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
+    for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
+        block_offset = (work_id % num_blocks) * BLOCK_SIZE
+        row_id = work_id // num_blocks
+        batch_id = tl.load(indices_ptr + row_id)
+        offsets = block_offset + tl.arange(0, BLOCK_SIZE)
+        bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
+        vocab_mask = offsets < vocab_size
+        packed_bitmask_mask = bitmask_offsets < bitmask_size
+        logits = tl.load(logits_ptr + batch_id * vocab_size + offsets, vocab_mask)
+        packed_bitmask = tl.load(
+            bitmask_ptr + row_id * bitmask_size + bitmask_offsets, packed_bitmask_mask
+        )
+        bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
+        bitmask = bitmask.reshape(BLOCK_SIZE)
+
+        logits = tl.where(bitmask, -float("inf"), logits)
+        tl.store(logits_ptr + batch_id * vocab_size + offsets, logits, vocab_mask)
+
+
+def apply_token_bitmask_inplace_triton(
+    logits: torch.Tensor,
+    bitmask: torch.Tensor,
+    indices: Optional[Union[List[int], torch.Tensor]] = None,
+):
+    def ceil_div(a, b):
+        return (a + b - 1) // b
+
+    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    BLOCK_SIZE = 4096
+    # Check input tensor shapes.
+    if logits.ndim == 2:
+        batch_size, vocab_size = logits.shape
+    elif logits.ndim == 1:
+        batch_size = 1
+        (vocab_size,) = logits.shape
+    else:
+        raise ValueError(f"Invalid logits tensor shape {logits.shape}")
+
+    if indices is None:
+        indices = torch.arange(batch_size, dtype=torch.int32, device=logits.device)
+    elif isinstance(indices, list):
+        indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
+
+    grid = lambda meta: (NUM_SMS,)
+
+    apply_token_bitmask_inplace_kernel[grid](
+        logits,
+        bitmask,
+        indices,
+        indices.shape[0],
+        vocab_size,
+        ceil_div(vocab_size, 32),
+        NUM_SMS,
+        BLOCK_SIZE,
+        num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
+        num_stages=3,
+    )
diff --git a/python/xgrammar/matcher.py b/python/xgrammar/matcher.py
index 74da739..c24d5d7 100644
--- a/python/xgrammar/matcher.py
+++ b/python/xgrammar/matcher.py
@@ -8,7 +8,7 @@
 from .base import XGRObject, _core
 from .compiler import CompiledGrammar
 
-from .kernels import apply_token_bitmask_inplace_cpu, apply_token_bitmask_inplace_cuda
+from .kernels import apply_token_bitmask_inplace_cpu, apply_token_bitmask_inplace_triton, apply_token_bitmask_inplace_cuda
 
 """The dtype of the bitmask: int32."""
 bitmask_dtype = torch.int32
@@ -107,7 +107,7 @@ def apply_token_bitmask_inplace(
     )
 
     if logits.device.type == "cuda":
-        apply_token_bitmask_inplace_cuda(logits, bitmask, indices)
+        apply_token_bitmask_inplace_triton(logits, bitmask, indices)
     elif logits.device.type == "cpu":
         apply_token_bitmask_inplace_cpu(logits, bitmask, indices)
     else:
diff --git a/tests/python/test_grammar_matcher.py b/tests/python/test_grammar_matcher.py
index f4c6587..4adfbd3 100644
--- a/tests/python/test_grammar_matcher.py
+++ b/tests/python/test_grammar_matcher.py
@@ -6,6 +6,7 @@
 import pytest
 import torch
+from triton.testing import do_bench
 from transformers import AutoTokenizer
 
 import xgrammar as xgr
@@ -170,11 +171,11 @@ def test_apply_token_bitmask_inplace(is_cuda: bool):
         bitmask = torch.tensor([0b1010101010], dtype=torch.int32).to("cuda")
         xgr.apply_token_bitmask_inplace(logits_gpu, bitmask)
         torch.cuda.synchronize()
-        assert torch.all(logits_gpu == expected.to("cuda"))
+        torch.testing.assert_allclose(logits_gpu, expected.to("cuda"))
     else:
         bitmask = torch.tensor([0b1010101010], dtype=torch.int32)
         xgr.apply_token_bitmask_inplace(logits, bitmask)
-        assert torch.all(logits == expected)
+        torch.testing.assert_allclose(logits, expected)
 
 
 batch_size_vocab_size_masked_cnt_stride = [
@@ -221,26 +222,29 @@ def bool_mask_to_bitmask(bool_mask: torch.Tensor) -> torch.Tensor:
         logits_gpu = logits.to("cuda")
         bitmask_gpu = bitmask.to("cuda")
         torch.cuda.synchronize()
-        time_start = time.monotonic_ns()
         if stride == 1:
             # Test logic without indices
-            xgr.apply_token_bitmask_inplace(logits_gpu, bitmask_gpu)
+            f = lambda: xgr.apply_token_bitmask_inplace(logits_gpu, bitmask_gpu)
         else:
-            xgr.apply_token_bitmask_inplace(logits_gpu, bitmask_gpu, indices=masked_batch_ids)
-        torch.cuda.synchronize()
-        time_end = time.monotonic_ns()
-        print(f"Time taken: {(time_end - time_start) / 1e3} us")
-        assert torch.all(logits_gpu == logits_expected.to("cuda"))
+            f = lambda: xgr.apply_token_bitmask_inplace(
+                logits_gpu, bitmask_gpu, indices=masked_batch_ids
+            )
+        f()
+        torch.testing.assert_allclose(logits_gpu, logits_expected.to("cuda"))
+
+        dur = do_bench(f, warmup=100, rep=1000)
+        print(f"Time taken: {(dur) * 1e3} us")
     else:
         time_start = time.monotonic_ns()
         if stride == 1:
             # Test logic without indices
             xgr.apply_token_bitmask_inplace(logits, bitmask)
+        else:
             xgr.apply_token_bitmask_inplace(logits, bitmask, indices=masked_batch_ids)
         time_end = time.monotonic_ns()
         print(f"Time taken: {(time_end - time_start) / 1e3} us")
-        assert torch.all(logits == logits_expected)
+        torch.testing.assert_allclose(logits, logits_expected)
 
 
 def test_rollback():
@@ -282,7 +286,8 @@
         matcher.fill_next_token_bitmask(new_token_bitmask2)
         result_after_rollback.append(new_token_bitmask2)
         assert matcher.accept_token(i_2)
-    assert all(torch.all(l == r) for l, r in zip(orig_result, result_after_rollback))
+    for l, r in zip(orig_result, result_after_rollback):
+        torch.testing.assert_allclose(l, r)
 
 
 def test_reset():
@@ -315,7 +320,8 @@
         result_after_reset.append(token_bitmask)
         assert matcher.accept_token(i)
-    assert all(torch.all(l == r) for l, r in zip(orig_result, result_after_reset))
+    for l, r in zip(orig_result, result_after_reset):
+        torch.testing.assert_allclose(l, r)
 
 
 def test_termination():
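
Note (not part of the patch): the Triton kernel above reads the bitmask as packed int32 words, where bit j of word i controls token i * 32 + j of the corresponding row; a 0 bit forces that logit to -inf, a 1 bit leaves it unchanged. The sketch below is a minimal, out-of-place PyTorch reference for that semantics, useful as a CPU sanity check. The function name reference_apply_token_bitmask and the tensor shapes in the example are illustrative only, not identifiers from the xgrammar codebase.

import torch

def reference_apply_token_bitmask(logits: torch.Tensor, bitmask: torch.Tensor) -> torch.Tensor:
    # logits: (batch, vocab_size) float tensor; bitmask: (batch, ceil(vocab_size / 32)) int32 tensor.
    vocab_size = logits.shape[-1]
    shifts = torch.arange(32, dtype=torch.int32, device=bitmask.device)
    # Extract bit j of each word: (x >> j) & 1 is sign-safe even for negative int32 values.
    bits = (bitmask.unsqueeze(-1) >> shifts) & 1      # (batch, num_words, 32)
    keep = bits.flatten(-2)[..., :vocab_size].bool()  # (batch, vocab_size)
    # A cleared bit means "masked": overwrite that logit with -inf, mirroring tl.where in the kernel.
    return logits.masked_fill(~keep, float("-inf"))

logits = torch.randn(2, 128)
bitmask = torch.full((2, 4), -1, dtype=torch.int32)  # all bits set -> every token allowed
bitmask[0, 0] &= ~1                                  # clear bit 0 -> mask token 0 of row 0
masked = reference_apply_token_bitmask(logits, bitmask)
assert masked[0, 0] == float("-inf")
assert torch.equal(masked[1], logits[1])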