Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@ def matmul_kernel_persistent(
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
num_tiles = num_pid_m * num_pid_n

tile_id_c = start_pid - NUM_SMS

offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n

Expand Down Expand Up @@ -120,10 +118,6 @@ def matmul_kernel_persistent(
)
accumulator = tl.dot(a, b, accumulator)

tile_id_c += NUM_SMS
pid_m, pid_n = _compute_pid(
tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
if C_LARGE:
Expand All @@ -137,6 +131,10 @@ def matmul_kernel_persistent(
accumulator += bias
if c_ptr.dtype.element_ty == tl.float8e4nv:
c = accumulator.to(tl.float8e4nv)
elif c_ptr.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif c_ptr.dtype.element_ty == tl.float32:
c = accumulator.to(tl.float32)
else:
c = accumulator.to(tl.float16)
tl.store(c_ptrs, c, mask=c_mask)
Expand Down
163 changes: 163 additions & 0 deletions test/srt/batch_invariant/test_batch_invariant_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/test_batch_invariance.py
import math
import unittest

import torch

from sglang.srt.batch_invariant_ops.batch_invariant_ops import set_batch_invariant_mode
from sglang.test.test_utils import CustomTestCase

device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu")
torch.set_default_device(device_type)

# Enter and immediately leave batch-invariant mode once at import time.
# Per the original note, this is only done to get its logging out of the way
# before any test runs.
with set_batch_invariant_mode(True):
    pass


class TestBatchInvariantOps(CustomTestCase):
    """Check that ``torch.mm`` results do not depend on the batch size.

    Each case multiplies an ``(M, K)`` matrix by a non-contiguous ``(K, N)``
    matrix in two ways -- using only the first row, and using the full batch
    and slicing the first row afterwards -- and compares the two results.
    """

    # dtypes exercised by every size-parameterised test.
    _DTYPES = (torch.float32, torch.bfloat16)
    # Number of repetitions per (shape, dtype) combination.
    _ITERS = 5

    def _test_batch_invariance(self, M, K, N, dtype):
        """Return the max abs difference between the two computation paths.

        - Method 1: matrix multiplication on the first row only (batch size 1).
        - Method 2: matrix multiplication on the full batch, then slice.

        Args:
            M: Number of rows of the left operand (the batch dimension).
            K: Shared inner dimension.
            N: Number of columns of the right operand.
            dtype: Torch dtype used for both operands.

        Returns:
            The largest absolute element-wise difference between the two
            results, as returned by ``Tensor.item()``; ``0.0`` means the two
            paths agree exactly.
        """
        a = torch.linspace(-100, 100, M * K, dtype=dtype).reshape(M, K)

        # Lay ``b`` out as (N, K) and transpose it so that the (K, N) operand
        # handed to ``torch.mm`` is non-contiguous.
        b = torch.linspace(-100, 100, K * N, dtype=dtype).reshape(N, K)
        b = b.transpose(0, 1)

        # Method 1: batch size 1.
        out1 = torch.mm(a[:1], b)

        # Method 2: full batch, then keep only the first row.
        out2 = torch.mm(a, b)[:1]

        return (out1 - out2).abs().max().item()

    def _run_multiple_iterations(self, iters, M, K, N, dtype):
        """Run the comparison ``iters`` times and return the list of diffs."""
        return [self._test_batch_invariance(M, K, N, dtype) for _ in range(iters)]

    def _assert_batch_invariant_results(self, difflist, dtype, test_name):
        """Assert that every collected diff is a non-NaN exact zero.

        Checks, in order:
        1. At least one diff was collected.
        2. No diff is NaN.
        3. Max, min, and range of the diffs are all exactly 0.

        Args:
            difflist: Diffs returned by ``_run_multiple_iterations``.
            dtype: Torch dtype under test (used in failure messages only).
            test_name: Case label (used in failure messages only).
        """
        # Fail with a readable message instead of the bare ValueError that
        # max()/min() raise on an empty sequence.
        self.assertTrue(
            difflist, f"{test_name}: no diffs were collected for {dtype}"
        )

        # NaN must be checked per element: max()/min() keep the first item
        # whenever a comparison with NaN is false, so [0.0, nan] would yield
        # max == min == 0.0 and hide the NaN.
        nan_iterations = [i for i, d in enumerate(difflist) if math.isnan(d)]
        self.assertFalse(
            nan_iterations,
            f"{test_name}: diff is NaN for {dtype} at iteration(s) {nan_iterations}",
        )

        max_diff = max(difflist)
        min_diff = min(difflist)
        diff_range = max_diff - min_diff

        # Check that all diffs are exactly 0.
        self.assertEqual(
            max_diff,
            0.0,
            f"{test_name}: max_diff must be 0 in batch-invariant mode, got {max_diff} for {dtype}",
        )
        self.assertEqual(
            min_diff,
            0.0,
            f"{test_name}: min_diff must be 0 in batch-invariant mode, got {min_diff} for {dtype}",
        )
        self.assertEqual(
            diff_range,
            0.0,
            f"{test_name}: diff_range must be 0 in batch-invariant mode, got {diff_range} for {dtype}",
        )

    def _check_batch_invariance_for_sizes(self, test_cases):
        """Run the batch-invariance check for each ``(name, M, K, N)`` case.

        Every case is run for each dtype in ``_DTYPES`` inside its own
        ``subTest`` so that one failing combination does not hide the others.
        """
        for name, M, K, N in test_cases:
            with self.subTest(name=name, M=M, K=K, N=N):
                for dtype in self._DTYPES:
                    with self.subTest(dtype=dtype):
                        # Run with batch-invariant mode.
                        with set_batch_invariant_mode(True):
                            difflist = self._run_multiple_iterations(
                                iters=self._ITERS, M=M, K=K, N=N, dtype=dtype
                            )
                        # Assert only after the mode has been exited, so an
                        # assertion failure never propagates through the
                        # context manager.
                        self._assert_batch_invariant_results(difflist, dtype, name)

    def test_small_matrices(self):
        """Test batch invariance with small matrix sizes"""
        self._check_batch_invariance_for_sizes(
            [
                ("Small-1", 8, 64, 128),
                ("Small-2", 16, 128, 256),
                ("Small-3", 4, 32, 64),
            ]
        )

    def test_medium_matrices(self):
        """Test batch invariance with medium matrix sizes"""
        self._check_batch_invariance_for_sizes(
            [
                ("Medium-1", 32, 128, 1024),
                ("Medium-2", 64, 512, 2048),
                ("Medium-3", 24, 192, 768),
            ]
        )

    def test_large_matrices(self):
        """Test batch invariance with large matrix sizes"""
        self._check_batch_invariance_for_sizes(
            [
                ("Large-1", 128, 1024, 4096),
                ("Large-2", 256, 2048, 8192),
                ("Large-3", 96, 768, 3072),
            ]
        )

    def test_without_batch_invariant_mode(self):
        """
        Test that without batch-invariant mode, results may differ.
        This test demonstrates the difference batch-invariant mode makes;
        it only prints the observed diffs and asserts nothing about them.
        """
        M, K, N = 32, 128, 1024
        dtype = torch.float32

        # Run without batch-invariant mode.
        with set_batch_invariant_mode(False):
            difflist = self._run_multiple_iterations(
                iters=self._ITERS, M=M, K=K, N=N, dtype=dtype
            )
        print(f"Without batch-invariant mode, we get diffs: {difflist}")


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
1 change: 1 addition & 0 deletions test/srt/run_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class TestFile:
# TestFile("models/test_gme_qwen_models.py", 45),
# TestFile("models/test_grok_models.py", 60), # Disabled due to illegal memory access
TestFile("models/test_qwen_models.py", 82),
TestFile("batch_invariant/test_batch_invariant_ops.py", 10),
TestFile("models/test_reward_models.py", 132),
TestFile("models/test_vlm_models.py", 741),
TestFile("models/test_transformers_models.py", 320),
Expand Down
Loading