-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Fix batch invariant ops #11368
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Fix batch invariant ops #11368
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
e489d59
Fix batch invariant ops
ba882d5
Fix batch invariant kernel
hebiao064 ad20b27
Merge branch 'main' into bhe/fix_batch_invariant_ops
hebiao064 1a46798
lint
hebiao064 656fdcf
fix
hebiao064 3ee3205
Merge branch 'main' into bhe/fix_batch_invariant_ops
Fridge003 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,163 @@ | ||
| # Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/test_batch_invariance.py | ||
| import math | ||
| import unittest | ||
|
|
||
| import torch | ||
|
|
||
| from sglang.srt.batch_invariant_ops.batch_invariant_ops import set_batch_invariant_mode | ||
| from sglang.test.test_utils import CustomTestCase | ||
|
|
||
| device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") | ||
| torch.set_default_device(device_type) | ||
|
|
||
| # Just to get the logging out of the way | ||
| with set_batch_invariant_mode(True): | ||
| pass | ||
|
|
||
|
|
||
| class TestBatchInvariantOps(CustomTestCase): | ||
| def _test_batch_invariance(self, M, K, N, dtype): | ||
| """ | ||
| Test that matrix operations produce identical results for: | ||
| - Method 1: Matrix-vector multiplication (batch size 1) | ||
| - Method 2: Matrix-matrix multiplication, then slice (full batch) | ||
| """ | ||
| a = torch.linspace(-100, 100, M * K, dtype=dtype).reshape(M, K) | ||
|
|
||
| # Create non-contiguous tensor | ||
| b = torch.linspace(-100, 100, K * N, dtype=dtype).reshape(N, K) | ||
| b = b.transpose(0, 1) | ||
|
|
||
| # Method 1: Matrix-vector multiplication (batch size 1) | ||
| out1 = torch.mm(a[:1], b) | ||
|
|
||
| # Method 2: Matrix-matrix multiplication, then slice (full batch) | ||
| out2_pre = torch.mm(a, b) | ||
| out2 = out2_pre[:1] | ||
|
|
||
| # Check if results are identical | ||
| diff = (out1 - out2).abs().max() | ||
| return diff.item() | ||
|
|
||
| def _run_multiple_iterations(self, iters, M, K, N, dtype): | ||
| """Run multiple iterations and collect diff statistics""" | ||
| difflist = [] | ||
| for _ in range(iters): | ||
| diff = self._test_batch_invariance(M, K, N, dtype) | ||
| difflist.append(diff) | ||
| return difflist | ||
|
|
||
| def _assert_batch_invariant_results(self, difflist, dtype, test_name): | ||
| """ | ||
| Assert that in batch-invariant mode: | ||
| 1. All diffs must not be NaN | ||
| 2. All diffs must be exactly 0 | ||
| 3. Max, min, and diff of diffs must all be 0 | ||
| """ | ||
| max_diff = max(difflist) | ||
| min_diff = min(difflist) | ||
| diff_range = max_diff - min_diff | ||
|
|
||
| # Check for NaN values | ||
| self.assertFalse( | ||
| math.isnan(max_diff), f"{test_name}: max_diff is NaN for {dtype}" | ||
| ) | ||
| self.assertFalse( | ||
| math.isnan(min_diff), f"{test_name}: min_diff is NaN for {dtype}" | ||
| ) | ||
| self.assertFalse( | ||
| math.isnan(diff_range), f"{test_name}: diff_range is NaN for {dtype}" | ||
| ) | ||
|
|
||
| # Check that all diffs are exactly 0 | ||
| self.assertEqual( | ||
| max_diff, | ||
| 0.0, | ||
| f"{test_name}: max_diff must be 0 in batch-invariant mode, got {max_diff} for {dtype}", | ||
| ) | ||
| self.assertEqual( | ||
| min_diff, | ||
| 0.0, | ||
| f"{test_name}: min_diff must be 0 in batch-invariant mode, got {min_diff} for {dtype}", | ||
| ) | ||
| self.assertEqual( | ||
| diff_range, | ||
| 0.0, | ||
| f"{test_name}: diff_range must be 0 in batch-invariant mode, got {diff_range} for {dtype}", | ||
| ) | ||
|
|
||
| def test_small_matrices(self): | ||
| """Test batch invariance with small matrix sizes""" | ||
| test_cases = [ | ||
| ("Small-1", 8, 64, 128), | ||
| ("Small-2", 16, 128, 256), | ||
| ("Small-3", 4, 32, 64), | ||
| ] | ||
|
|
||
| for name, M, K, N in test_cases: | ||
| with self.subTest(name=name, M=M, K=K, N=N): | ||
| for dtype in [torch.float32, torch.bfloat16]: | ||
| with self.subTest(dtype=dtype): | ||
| # Run with batch-invariant mode | ||
| with set_batch_invariant_mode(True): | ||
| difflist = self._run_multiple_iterations( | ||
| iters=5, M=M, K=K, N=N, dtype=dtype | ||
| ) | ||
| self._assert_batch_invariant_results(difflist, dtype, name) | ||
|
|
||
| def test_medium_matrices(self): | ||
| """Test batch invariance with medium matrix sizes""" | ||
| test_cases = [ | ||
| ("Medium-1", 32, 128, 1024), | ||
| ("Medium-2", 64, 512, 2048), | ||
| ("Medium-3", 24, 192, 768), | ||
| ] | ||
|
|
||
| for name, M, K, N in test_cases: | ||
| with self.subTest(name=name, M=M, K=K, N=N): | ||
| for dtype in [torch.float32, torch.bfloat16]: | ||
| with self.subTest(dtype=dtype): | ||
| # Run with batch-invariant mode | ||
| with set_batch_invariant_mode(True): | ||
| difflist = self._run_multiple_iterations( | ||
| iters=5, M=M, K=K, N=N, dtype=dtype | ||
| ) | ||
| self._assert_batch_invariant_results(difflist, dtype, name) | ||
|
|
||
| def test_large_matrices(self): | ||
| """Test batch invariance with large matrix sizes""" | ||
| test_cases = [ | ||
| ("Large-1", 128, 1024, 4096), | ||
| ("Large-2", 256, 2048, 8192), | ||
| ("Large-3", 96, 768, 3072), | ||
| ] | ||
|
|
||
| for name, M, K, N in test_cases: | ||
| with self.subTest(name=name, M=M, K=K, N=N): | ||
| for dtype in [torch.float32, torch.bfloat16]: | ||
| with self.subTest(dtype=dtype): | ||
| # Run with batch-invariant mode | ||
| with set_batch_invariant_mode(True): | ||
| difflist = self._run_multiple_iterations( | ||
| iters=5, M=M, K=K, N=N, dtype=dtype | ||
| ) | ||
| self._assert_batch_invariant_results(difflist, dtype, name) | ||
|
|
||
| def test_without_batch_invariant_mode(self): | ||
| """ | ||
| Test that without batch-invariant mode, results may differ. | ||
| This test demonstrates the difference batch-invariant mode makes. | ||
| """ | ||
| M, K, N = 32, 128, 1024 | ||
| dtype = torch.float32 | ||
|
|
||
| # Run without batch-invariant mode | ||
| with set_batch_invariant_mode(False): | ||
| difflist = self._run_multiple_iterations( | ||
| iters=5, M=M, K=K, N=N, dtype=dtype | ||
| ) | ||
| print(f"Without batch-invariant mode, we get diffs: {difflist}") | ||
|
|
||
Fridge003 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.