Conversation
@codeflash-ai codeflash-ai bot commented Nov 11, 2025

📄 14% (0.14x) speedup for create_grouped_scores in python/sglang/test/test_deepep_utils.py

⏱️ Runtime : 1.72 milliseconds → 1.52 milliseconds (best of 161 runs)

📝 Explanation and details

The optimization eliminates unnecessary tensor operations by restructuring how the mask tensor is created and used.

Key changes:

  1. Direct 3D mask creation: The mask is initialized directly as (num_tokens, num_groups, 1) instead of (num_tokens, num_groups), eliminating the need for unsqueeze(-1).expand_as(scores) later.
  2. In-place scatter operation: Uses mask.scatter_() instead of mask = mask.scatter_(), avoiding creating a new tensor object.
  3. Minimal dimension adjustment: Only group_idx needs unsqueeze(-1) to match the 3D mask dimensions for the scatter operation.
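
For concreteness, here is a minimal before/after sketch reconstructed from the description above. The actual function bodies are not shown in this view, so the bool mask dtype, the `_original`/`_optimized` names, and the exact signature are assumptions:

import torch

def create_grouped_scores_original(scores, group_idx, num_groups):
    # scores: (num_tokens, num_experts); group_idx: (num_tokens, 1)
    num_tokens, num_experts = scores.shape
    scores = scores.view(num_tokens, num_groups, -1)
    # 2D mask, then unsqueeze/expand to match the 3D scores view
    mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)
    mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)
    return (scores * mask).view(num_tokens, -1)

def create_grouped_scores_optimized(scores, group_idx, num_groups):
    num_tokens, num_experts = scores.shape
    scores = scores.view(num_tokens, num_groups, -1)
    # Mask is created 3D up front; scatter_ runs in place; broadcasting
    # against (num_tokens, num_groups, 1) replaces the expand_as call
    mask = torch.zeros((num_tokens, num_groups, 1), dtype=torch.bool, device=scores.device)
    mask.scatter_(1, group_idx.unsqueeze(-1), True)
    return (scores * mask).view(num_tokens, -1)

The behavioral contract is identical in both variants: out-of-group scores are zeroed, in-group scores pass through, and invalid num_groups or group indices surface as RuntimeErrors from view/scatter_.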

Performance benefits:

  • Reduced memory operations: Eliminates the expand_as() call which creates a view with expanded dimensions, reducing memory allocation overhead.
  • Fewer tensor operations: The chain of scatter_() -> unsqueeze() -> expand_as() is reduced to just scatter_() with proper initial dimensioning.
  • Better memory locality: Direct 3D initialization likely has better cache performance than reshaping operations.
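
A rough way to reproduce the comparison locally, assuming the two hypothetical variants sketched above (absolute timings vary by machine and PyTorch build):

import time
import torch

scores = torch.rand(1000, 100)  # 1000 tokens, 10 groups x 10 experts
group_idx = torch.randint(0, 10, (1000, 1))

for fn in (create_grouped_scores_original, create_grouped_scores_optimized):
    t0 = time.perf_counter()
    for _ in range(1000):
        fn(scores, group_idx, 10)
    print(fn.__name__, f"{time.perf_counter() - t0:.4f}s")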

Test case analysis:
The optimization shows consistent 8-17% speedups across most test cases, with particularly strong gains (30-50%) on large-scale tests with many tokens/experts. This indicates the optimization scales well with tensor size, making it valuable for production workloads involving large expert models or batch processing scenarios.

The preserved correctness across all test cases (including edge cases with zero experts, out-of-bounds indices, and various tensor shapes) confirms the optimization maintains identical functionality while improving performance.
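
The same equivalence can be spot-checked in a few lines with randomized shapes, again using the hypothetical variant names from the sketch above:

import torch

torch.manual_seed(0)
for _ in range(100):
    num_groups = int(torch.randint(1, 8, (1,)))
    experts_per_group = int(torch.randint(1, 8, (1,)))
    num_tokens = int(torch.randint(1, 64, (1,)))
    scores = torch.randn(num_tokens, num_groups * experts_per_group)
    group_idx = torch.randint(0, num_groups, (num_tokens, 1))
    assert torch.equal(
        create_grouped_scores_original(scores, group_idx, num_groups),
        create_grouped_scores_optimized(scores, group_idx, num_groups),
    )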

Correctness verification report:

Test                           Status
⚙️ Existing Unit Tests         🔘 None Found
🌀 Generated Regression Tests  29 Passed
⏪ Replay Tests                🔘 None Found
🔎 Concolic Coverage Tests     🔘 None Found
📊 Tests Coverage              100.0%
🌀 Generated Regression Tests and Runtime

import pytest  # used for our unit tests
import torch
from sglang.test.test_deepep_utils import create_grouped_scores

# unit tests

# ---- Basic Test Cases ----

def test_basic_single_token_single_group():
    # One token, one group, one expert
    scores = torch.tensor([[1.0]])
    group_idx = torch.tensor([[0]])
    num_groups = 1
    # Only one expert, so output should be unchanged
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); out = codeflash_output  # 48.3μs -> 43.0μs (12.1% faster)
    assert torch.equal(out, scores)

def test_basic_multiple_tokens_groups_experts():
    # Two tokens, two groups, two experts per group (4 experts total)
    scores = torch.tensor([
        [1.0, 2.0, 3.0, 4.0],  # token 0
        [5.0, 6.0, 7.0, 8.0],  # token 1
    ])
    group_idx = torch.tensor([
        [0],  # token 0 assigned to group 0
        [1],  # token 1 assigned to group 1
    ])
    num_groups = 2
    # After masking, token 0 should keep experts 0,1; token 1 should keep experts 2,3
    expected = torch.tensor([
        [1.0, 2.0, 0.0, 0.0],
        [0.0, 0.0, 7.0, 8.0],
    ])
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); out = codeflash_output  # 37.8μs -> 34.3μs (10.2% faster)
    assert torch.equal(out, expected)

def test_basic_all_tokens_same_group():
    # Three tokens, two groups, two experts per group (4 experts)
    scores = torch.tensor([
        [1.0, 2.0, 3.0, 4.0],
        [5.0, 6.0, 7.0, 8.0],
        [9.0, 10.0, 11.0, 12.0],
    ])
    group_idx = torch.tensor([
        [1],
        [1],
        [1],
    ])
    num_groups = 2
    # All tokens assigned to group 1, so only experts 2,3 are kept
    expected = torch.tensor([
        [0.0, 0.0, 3.0, 4.0],
        [0.0, 0.0, 7.0, 8.0],
        [0.0, 0.0, 11.0, 12.0],
    ])
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); out = codeflash_output  # 37.5μs -> 33.7μs (11.1% faster)
    assert torch.equal(out, expected)

def test_basic_different_group_assignments():
    # Four tokens, four groups, one expert per group (4 experts)
    scores = torch.tensor([
        [1.0, 2.0, 3.0, 4.0],
        [5.0, 6.0, 7.0, 8.0],
        [9.0, 10.0, 11.0, 12.0],
        [13.0, 14.0, 15.0, 16.0],
    ])
    group_idx = torch.tensor([
        [0],
        [1],
        [2],
        [3],
    ])
    num_groups = 4
    # Each token assigned to a different group, only one expert per token is kept
    expected = torch.tensor([
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 6.0, 0.0, 0.0],
        [0.0, 0.0, 11.0, 0.0],
        [0.0, 0.0, 0.0, 16.0],
    ])
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); out = codeflash_output  # 36.4μs -> 33.0μs (10.1% faster)
    assert torch.equal(out, expected)

# ---- Edge Test Cases ----

def test_edge_single_expert_per_group():
    # 2 tokens, 3 groups, 1 expert per group (3 experts)
    scores = torch.tensor([
        [1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0],
    ])
    group_idx = torch.tensor([
        [2],
        [0],
    ])
    num_groups = 3
    # Only the expert corresponding to the group index is kept
    expected = torch.tensor([
        [0.0, 0.0, 3.0],
        [4.0, 0.0, 0.0],
    ])
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); out = codeflash_output  # 60.1μs -> 55.2μs (8.98% faster)
    assert torch.equal(out, expected)

def test_edge_group_idx_out_of_bounds():
    # Should raise an error if group_idx contains out-of-bounds indices
    scores = torch.tensor([[1.0, 2.0]])
    group_idx = torch.tensor([[2]])  # num_groups = 2, so index 2 is invalid
    num_groups = 2
    with pytest.raises(RuntimeError):
        create_grouped_scores(scores, group_idx, num_groups)  # 81.2μs -> 85.9μs (5.41% slower)

def test_edge_non_contiguous_scores():
    # Non-contiguous tensor input should still work
    scores = torch.arange(12, dtype=torch.float32).reshape(3, 4)
    scores_t = scores.t().t()  # Make non-contiguous
    group_idx = torch.tensor([[0], [1], [0]])
    num_groups = 2
    codeflash_output = create_grouped_scores(scores_t, group_idx, num_groups); out = codeflash_output  # 41.9μs -> 36.0μs (16.3% faster)
    # Should behave identically to contiguous
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); expected = codeflash_output  # 13.4μs -> 11.7μs (15.0% faster)
    assert torch.equal(out, expected)

def test_edge_zero_groups():
    # Should raise an error if num_groups is zero
    scores = torch.tensor([[1.0, 2.0]])
    group_idx = torch.tensor([[0]])
    num_groups = 0
    with pytest.raises(RuntimeError):
        create_grouped_scores(scores, group_idx, num_groups)  # 53.4μs -> 54.3μs (1.65% slower)

def test_edge_zero_experts():
    # Zero experts (scores shape [N, 0]) should return zeros
    scores = torch.empty((2, 0))
    group_idx = torch.tensor([[0], [0]])
    num_groups = 1
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); out = codeflash_output  # 49.4μs -> 46.9μs (5.32% faster)
    assert out.shape == (2, 0)

def test_edge_negative_group_idx():
    # Negative group_idx should raise error
    scores = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
    group_idx = torch.tensor([[-1]])
    num_groups = 2
    with pytest.raises(RuntimeError):
        create_grouped_scores(scores, group_idx, num_groups)  # 87.7μs -> 97.2μs (9.71% slower)

# ---- Large Scale Test Cases ----

def test_large_scale_many_tokens_and_experts():
    # 1000 tokens, 10 groups, 10 experts per group (100 experts)
    num_tokens = 1000
    num_groups = 10
    experts_per_group = 10
    num_experts = num_groups * experts_per_group
    # Random scores
    scores = torch.rand((num_tokens, num_experts))
    # Random group assignments for each token
    group_idx = torch.randint(0, num_groups, (num_tokens, 1))
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); out = codeflash_output  # 145μs -> 100μs (44.6% faster)
    # For each token, only experts in the assigned group should be nonzero
    for i in range(num_tokens):
        group = group_idx[i, 0].item()
        # The indices for this group
        start = group * experts_per_group
        end = start + experts_per_group
        assert torch.equal(out[i, start:end], scores[i, start:end])
        assert torch.all(out[i, :start] == 0) and torch.all(out[i, end:] == 0)

def test_large_scale_all_tokens_same_group():
    # 500 tokens, 5 groups, 20 experts per group (100 experts)
    num_tokens = 500
    num_groups = 5
    experts_per_group = 20
    num_experts = num_groups * experts_per_group
    scores = torch.rand((num_tokens, num_experts))
    group_idx = torch.zeros((num_tokens, 1), dtype=torch.long)  # All assigned to group 0
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); out = codeflash_output  # 94.2μs -> 67.0μs (40.6% faster)
    assert torch.equal(out[:, :experts_per_group], scores[:, :experts_per_group])
    assert torch.all(out[:, experts_per_group:] == 0)

def test_large_scale_all_tokens_different_groups():
    # 100 tokens, 100 groups, 1 expert per group (100 experts)
    num_tokens = 100
    num_groups = 100
    experts_per_group = 1
    num_experts = num_groups * experts_per_group
    scores = torch.rand((num_tokens, num_experts))
    group_idx = torch.arange(0, num_groups).unsqueeze(1)[:num_tokens]
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); out = codeflash_output  # 48.3μs -> 41.4μs (16.6% faster)
    for i in range(num_tokens):
        # Token i keeps only the single expert of group i
        assert out[i, i] == scores[i, i]
        assert torch.all(out[i, :i] == 0) and torch.all(out[i, i + 1:] == 0)

def test_large_scale_performance():
    # Check that the function runs efficiently for large input
    num_tokens = 1000
    num_groups = 10
    experts_per_group = 10
    num_experts = num_groups * experts_per_group
    scores = torch.rand((num_tokens, num_experts))
    group_idx = torch.randint(0, num_groups, (num_tokens, 1))
    import time
    start = time.time()
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); out = codeflash_output  # 144μs -> 95.9μs (50.6% faster)
    duration = time.time() - start
    assert duration < 1.0  # generous bound; the call itself takes well under a millisecond
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

#------------------------------------------------
import pytest  # used for our unit tests
import torch
from sglang.test.test_deepep_utils import create_grouped_scores

# unit tests

# ----------- Basic Test Cases -----------

def test_basic_single_token_single_group():
    # One token, one group, one expert per group
    scores = torch.tensor([[1.0]])
    group_idx = torch.tensor([[0]])
    num_groups = 1
    # Only one expert, so output should be unchanged
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); output = codeflash_output  # 52.1μs -> 47.5μs (9.71% faster)
    assert torch.equal(output, scores)

def test_basic_multiple_tokens_groups_experts():
    # Two tokens, two groups, two experts per group (4 experts total)
    scores = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])
    group_idx = torch.tensor([[0], [1]])
    num_groups = 2
    # For token 0, group_idx=0, so only first two experts kept; for token 1, group_idx=1, so last two experts kept
    expected = torch.tensor([[1.0, 2.0, 0.0, 0.0], [0.0, 0.0, 7.0, 8.0]])
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); output = codeflash_output  # 38.4μs -> 34.8μs (10.3% faster)
    assert torch.equal(output, expected)

def test_basic_all_groups_selected():
    # Both tokens select the same group
    scores = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
    group_idx = torch.tensor([[0], [0]])
    num_groups = 2
    # Both tokens select group 0, so experts 0,1 are kept
    expected = torch.tensor([[1, 2, 0, 0], [5, 6, 0, 0]])
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); output = codeflash_output  # 37.8μs -> 35.4μs (6.83% faster)
    assert torch.equal(output, expected)

def test_basic_different_group_per_token():
    # Three tokens, two groups, two experts per group
    scores = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
    group_idx = torch.tensor([[0], [1], [0]])
    num_groups = 2
    # Token 0 and 2 select group 0, token 1 selects group 1
    expected = torch.tensor([[1, 2, 0, 0], [0, 0, 7, 8], [9, 10, 0, 0]])
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); output = codeflash_output  # 37.2μs -> 34.3μs (8.49% faster)
    assert torch.equal(output, expected)

# ----------- Edge Test Cases -----------

def test_edge_zero_experts_per_group():
    # Zero experts per group (invalid case)
    scores = torch.empty((2, 0))
    group_idx = torch.tensor([[0], [0]])
    num_groups = 2
    # Should not crash; output should be shape (2,0)
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); output = codeflash_output  # 34.6μs -> 33.5μs (3.43% faster)
    assert output.shape == (2, 0)

def test_edge_group_idx_out_of_bounds():
    # group_idx contains invalid index
    scores = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
    group_idx = torch.tensor([[2]])  # num_groups=2, so 2 is out of bounds
    num_groups = 2
    # Should raise an error
    with pytest.raises(RuntimeError):
        create_grouped_scores(scores, group_idx, num_groups)  # 80.1μs -> 82.8μs (3.34% slower)

def test_edge_single_token_multiple_groups():
    # Single token, multiple groups, multiple experts per group
    scores = torch.tensor([[1, 2, 3, 4, 5, 6]])
    group_idx = torch.tensor([[1]])
    num_groups = 3
    # Only experts 2,3 are kept
    expected = torch.tensor([[0, 0, 3, 4, 0, 0]])
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); output = codeflash_output  # 61.7μs -> 57.0μs (8.21% faster)
    assert torch.equal(output, expected)

def test_edge_cuda_tensor():
    # If CUDA is available, test with CUDA tensors
    if torch.cuda.is_available():
        scores = torch.tensor([[1., 2., 3., 4.]], device='cuda')
        group_idx = torch.tensor([[1]], device='cuda')
        num_groups = 2
        expected = torch.tensor([[0., 0., 3., 4.]], device='cuda')
        codeflash_output = create_grouped_scores(scores, group_idx, num_groups); output = codeflash_output
        assert torch.equal(output, expected)

def test_edge_scores_dtype_int():
    # Scores tensor of integer dtype
    scores = torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)
    group_idx = torch.tensor([[1]])
    num_groups = 2
    expected = torch.tensor([[0, 0, 3, 4]], dtype=torch.int32)
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); output = codeflash_output  # 60.0μs -> 56.2μs (6.66% faster)
    assert torch.equal(output, expected)

def test_edge_group_idx_shape_mismatch():
    # group_idx shape mismatch
    scores = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
    group_idx = torch.tensor([0, 1])  # missing extra dimension
    num_groups = 2
    with pytest.raises(RuntimeError):
        create_grouped_scores(scores, group_idx, num_groups)  # 62.7μs -> 69.4μs (9.73% slower)

# ----------- Large Scale Test Cases -----------

def test_large_scale_many_tokens():
    # 1000 tokens, 4 groups, 2 experts per group (8 experts)
    num_tokens = 1000
    num_groups = 4
    experts_per_group = 2
    num_experts = num_groups * experts_per_group
    scores = torch.arange(num_tokens * num_experts, dtype=torch.float32).reshape(num_tokens, num_experts)
    group_idx = torch.randint(0, num_groups, (num_tokens, 1))
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); output = codeflash_output  # 63.5μs -> 54.0μs (17.5% faster)
    # Each token should only have nonzero values in the selected group
    for i in range(0, num_tokens, 100):  # test every 100th token for speed
        group = group_idx[i, 0].item()
        start = group * experts_per_group
        end = start + experts_per_group
        assert torch.equal(output[i, start:end], scores[i, start:end])
        assert torch.all(output[i, :start] == 0) and torch.all(output[i, end:] == 0)

def test_large_scale_many_experts():
    # 10 tokens, 10 groups, 10 experts per group (100 experts)
    num_tokens = 10
    num_groups = 10
    experts_per_group = 10
    num_experts = num_groups * experts_per_group
    scores = torch.arange(num_tokens * num_experts, dtype=torch.float32).reshape(num_tokens, num_experts)
    group_idx = torch.randint(0, num_groups, (num_tokens, 1))
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); output = codeflash_output  # 40.5μs -> 36.3μs (11.4% faster)
    for i in range(num_tokens):
        group = group_idx[i, 0].item()
        start = group * experts_per_group
        end = start + experts_per_group
        assert torch.equal(output[i, start:end], scores[i, start:end])
        assert torch.all(output[i, :start] == 0) and torch.all(output[i, end:] == 0)

def test_large_scale_randomized():
    # Randomized test: 50 tokens, 5 groups, 3 experts per group
    num_tokens = 50
    num_groups = 5
    experts_per_group = 3
    num_experts = num_groups * experts_per_group
    scores = torch.randn(num_tokens, num_experts)
    group_idx = torch.randint(0, num_groups, (num_tokens, 1))
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); output = codeflash_output  # 40.1μs -> 36.3μs (10.2% faster)
    for i in range(0, num_tokens, 10):  # check every 10th token
        group = group_idx[i, 0].item()
        start = group * experts_per_group
        end = start + experts_per_group
        assert torch.equal(output[i, start:end], scores[i, start:end])
        assert torch.all(output[i, :start] == 0) and torch.all(output[i, end:] == 0)

def test_large_scale_max_size_tensor():
    # Larger tensor: 500 tokens, 25 groups, 8 experts per group = 200 experts
    num_tokens = 500
    num_groups = 25
    experts_per_group = 8
    num_experts = num_groups * experts_per_group
    scores = torch.randn(num_tokens, num_experts)
    group_idx = torch.randint(0, num_groups, (num_tokens, 1))
    codeflash_output = create_grouped_scores(scores, group_idx, num_groups); output = codeflash_output  # 136μs -> 104μs (30.5% faster)
    for i in range(0, num_tokens, 100):  # check every 100th token for speed
        group = group_idx[i, 0].item()
        start = group * experts_per_group
        end = start + experts_per_group
        assert torch.equal(output[i, start:end], scores[i, start:end])
        assert torch.all(output[i, :start] == 0) and torch.all(output[i, end:] == 0)
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run git checkout codeflash/optimize-create_grouped_scores-mhtuvu1i and push.


@codeflash-ai codeflash-ai bot requested a review from mashraf-222 November 11, 2025 00:51
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash (Optimization PR opened by Codeflash AI) and 🎯 Quality: High (Optimization Quality according to Codeflash) labels Nov 11, 2025