⚡️ Speed up function get_max_seq_length by 19%
#334
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
📄 19% (0.19x) speedup for
get_max_seq_lengthinpython/sglang/srt/sampling/sampling_params.py⏱️ Runtime :
338 milliseconds→285 milliseconds(best of30runs)📝 Explanation and details
The optimization achieves an 18% speedup by eliminating repeated set construction and streamlining conditional logic in the tight recursive loop.
Key Optimizations:
Moved set construction outside the loop:
TOKEN_CHAR1andTOKEN_REPEATsets are now created once per function call rather than being reconstructed on every loop iteration. The profiler shows the original code spent significant time (38.1% + 20.7% + 10.4% + 9.2% + 9.7% = ~88% of function time) on set operations and membership checks.Direct tuple indexing: Replaced tuple unpacking (
_, _, _, inner_subpattern = value) with direct indexing (value[3]) for SUBPATTERN, BRANCH, and REPEAT cases, reducing allocation overhead.Function localization: Cached the recursive function reference as
_max_length_from_subpattern_localto reduce attribute lookup overhead in recursive calls.Eliminated no-op operation: Changed
total += 0topassfor AT tokens (zero-width assertions).Performance Impact:
The optimization is particularly effective for large-scale regex patterns, as seen in test results where
test_large_branch()improved by 19.3% andtest_large_concat()by 20.6%. Small regex patterns show minor slowdowns (5-12%) due to the overhead of creating constant sets, but this is outweighed by gains in realistic workloads.Context Relevance:
Based on the function reference, this optimization is called during sampling parameter normalization for stop regex processing. Since this normalization likely occurs during model initialization or configuration updates, the improved performance for complex regex patterns (which are common in structured generation tasks) provides meaningful benefits without impacting the hot inference path.
✅ Correctness verification report:
⚙️ Existing Unit Tests and Runtime
🌀 Generated Regression Tests and Runtime
import logging
import sre_parse
imports
import pytest # used for our unit tests
from sglang.srt.sampling.sampling_params import get_max_seq_length
logger = logging.getLogger(name)
from sglang.srt.sampling.sampling_params import get_max_seq_length
MAX_LEN = 2**30
unit tests
----------------------------
Basic Test Cases
----------------------------
def test_single_literal():
# Single character regex should return 1
codeflash_output = get_max_seq_length("a") # 11.6μs -> 12.5μs (7.20% slower)
def test_multiple_literals():
# Multiple literals concatenated should sum up
codeflash_output = get_max_seq_length("abc") # 13.0μs -> 13.9μs (6.56% slower)
def test_dot_any_character():
# '.' matches any character, counts as 1
codeflash_output = get_max_seq_length(".") # 11.2μs -> 12.5μs (10.5% slower)
def test_character_class():
# '[abc]' matches one character, so length is 1
codeflash_output = get_max_seq_length("[abc]") # 16.6μs -> 17.6μs (6.11% slower)
def test_simple_group():
# '(a)' is the same as 'a'
codeflash_output = get_max_seq_length("(a)") # 20.9μs -> 22.3μs (6.38% slower)
def test_simple_branch():
# 'a|b' matches either 'a' or 'b', both length 1
codeflash_output = get_max_seq_length("a|b") # 18.2μs -> 19.2μs (5.07% slower)
def test_concat_and_branch():
# 'ab|cd' matches 'ab' or 'cd', both length 2
codeflash_output = get_max_seq_length("ab|cd") # 20.9μs -> 22.3μs (5.95% slower)
def test_group_with_branch():
# '(ab|c)' matches 'ab' (2) or 'c' (1), max is 2
codeflash_output = get_max_seq_length("(ab|c)") # 30.7μs -> 32.1μs (4.32% slower)
def test_simple_repeat():
# 'a{3}' matches exactly 3 'a's
codeflash_output = get_max_seq_length("a{3}") # 18.4μs -> 19.4μs (5.18% slower)
def test_mixed_literals_and_repeats():
# 'ab{2}c' matches 'a', 'bb', 'c' => 1 + 2 + 1 = 4
codeflash_output = get_max_seq_length("ab{2}c") # 19.6μs -> 20.2μs (2.90% slower)
----------------------------
Edge Test Cases
----------------------------
def test_empty_regex():
# Empty string matches empty sequence, so 0
codeflash_output = get_max_seq_length("") # 9.13μs -> 10.5μs (12.7% slower)
def test_zero_repeat():
# 'a{0}' matches zero 'a's, so 0
codeflash_output = get_max_seq_length("a{0}") # 18.3μs -> 19.0μs (3.58% slower)
def test_zero_width_assertions():
# '^a should only count 'a'
codeflash_output = get_max_seq_length("^a$") # 15.5μs -> 15.2μs (1.90% faster)
def test_optional_character():
# 'a?' matches '' or 'a', so max is 1
codeflash_output = get_max_seq_length("a?") # 16.6μs -> 17.4μs (4.78% slower)
def test_optional_group():
# '(ab)?' matches '' or 'ab', so max is 2
codeflash_output = get_max_seq_length("(ab)?") # 27.3μs -> 27.8μs (1.99% slower)
def test_nested_groups():
# '((a)b)' should be 2
codeflash_output = get_max_seq_length("((a)b)") # 28.4μs -> 29.0μs (2.10% slower)
def test_nested_branch():
# 'a|(bc|d)' matches 'a' (1), 'bc' (2), or 'd' (1), max is 2
codeflash_output = get_max_seq_length("a|(bc|d)") # 36.1μs -> 36.6μs (1.41% slower)
def test_repeat_of_group():
# '(ab){2}' matches 'ab' twice: 2 * 2 = 4
codeflash_output = get_max_seq_length("(ab){2}") # 27.2μs -> 27.4μs (0.551% slower)
def test_repeat_of_branch():
# '(a|bc){3}' could be 3*2=6 if 'bc' chosen each time
codeflash_output = get_max_seq_length("(a|bc){3}") # 35.7μs -> 36.3μs (1.68% slower)
def test_repeat_zero_to_n():
# 'a{0,5}' matches up to 5 'a's, so max is 5
codeflash_output = get_max_seq_length("a{0,5}") # 17.8μs -> 18.7μs (4.87% slower)
def test_repeat_within_character_class():
# '[ab]{4}' matches 4 characters, each one of 'a' or 'b'
codeflash_output = get_max_seq_length("[ab]{4}") # 21.5μs -> 22.7μs (5.63% slower)
def test_min_repeat():
# 'a{2,4}' should return 4
codeflash_output = get_max_seq_length("a{2,4}") # 18.1μs -> 18.3μs (1.17% slower)
def test_repeat_with_dot():
# '.{3}' matches any 3 characters
codeflash_output = get_max_seq_length(".{3}") # 17.8μs -> 17.8μs (0.275% slower)
def test_repeat_with_subpattern_and_branch():
# '((ab|c){2})' max of (ab|c) is 2, repeat 2: 2*2=4
codeflash_output = get_max_seq_length("((ab|c){2})") # 44.1μs -> 44.2μs (0.195% slower)
def test_branch_with_different_lengths():
# 'a|bc|def' max is 3
codeflash_output = get_max_seq_length("a|bc|def") # 23.9μs -> 25.1μs (4.95% slower)
def test_branch_with_empty():
# 'a|' matches 'a' or '', so max is 1
codeflash_output = get_max_seq_length("a|") # 17.9μs -> 19.0μs (5.84% slower)
def test_empty_group():
# '()' is an empty group, so 0
codeflash_output = get_max_seq_length("()") # 19.3μs -> 20.2μs (4.37% slower)
def test_unhandled_token_warning(monkeypatch):
# Test that an unhandled token triggers the MAX_LEN fallback
called = {}
def fake_warning(msg):
called['called'] = True
monkeypatch.setattr(logger, "warning", fake_warning)
# '\Z' is an end-of-string assertion, which is handled as AT, so let's use an unsupported token
# We'll simulate this by patching sre_parse to return a fake token
class FakeSubPattern(list):
def init(self):
super().init([("FAKE_TOKEN", None)])
----------------------------
Large Scale Test Cases
----------------------------
def test_large_repeat():
# 'a{1000}' should return 1000
codeflash_output = get_max_seq_length("a{1000}") # 21.6μs -> 22.3μs (3.30% slower)
def test_large_concat():
# 'a'*1000 should return 1000
regex = "a" * 1000
codeflash_output = get_max_seq_length(regex) # 684μs -> 567μs (20.6% faster)
def test_large_branch():
# 'a|aa|aaa|...|a'*1000, max is 1000
regex = "|".join("a"*i for i in range(1, 1001))
codeflash_output = get_max_seq_length(regex) # 329ms -> 276ms (19.3% faster)
def test_large_nested_repeats():
# '((a{10}){10})' = 1010 = 100
regex = "(" * 2 + "a{10}" + ")" * 2 + "{10}"
# But this is not valid, so instead:
# '(a{10}){10}' = 1010 = 100
regex = "(a{10}){10}"
codeflash_output = get_max_seq_length(regex) # 41.0μs -> 42.8μs (4.16% slower)
def test_large_nested_branches():
# '((a|b|c|d|e){5})' max per branch is 1, repeated 5 times = 5
regex = "(" + "|".join("a b c d e".split()) + "){5}"
codeflash_output = get_max_seq_length(regex) # 40.3μs -> 42.1μs (4.25% slower)
def test_large_complex_structure():
# '((ab|cd|ef){10}g){5}' max of branch is 2, repeat 10 = 20, plus 'g' = 21, repeat 5 = 105
regex = "((ab|cd|ef){10}g){5}"
# ab|cd|ef: max=2, {10}=20, +g=21, {5}=105
codeflash_output = get_max_seq_length(regex) # 53.2μs -> 54.7μs (2.77% slower)
def test_max_repeat_infinite():
# 'a*' is MAX_REPEAT, so should return MAX_LEN
codeflash_output = get_max_seq_length("a*") # 15.6μs -> 16.7μs (6.49% slower)
def test_max_repeat_infinite_group():
# '(ab)' is MAX_REPEAT, ab=2, so should return MAX_LEN
codeflash_output = get_max_seq_length("(ab)") # 24.4μs -> 24.8μs (1.50% slower)
def test_large_min_repeat():
# 'a{100,200}' should return 200
codeflash_output = get_max_seq_length("a{100,200}") # 19.5μs -> 21.0μs (7.27% slower)
def test_large_branch_with_repeats():
# '(a{100}|b{200})' max is 200
codeflash_output = get_max_seq_length("(a{100}|b{200})") # 41.8μs -> 43.2μs (3.15% slower)
def test_large_mixed():
# '(a{100}|bc{50}|def{30})' max is max(100, 1+50, 2+30) = 100
codeflash_output = get_max_seq_length("(a{100}|bc{50}|def{30})") # 51.0μs -> 52.4μs (2.77% slower)
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import logging
import sre_parse
imports
import pytest
from sglang.srt.sampling.sampling_params import get_max_seq_length
MAX_LEN = 2**30
unit tests
-------------------- BASIC TEST CASES --------------------
def test_single_literal():
# Single character literal should have max seq length 1
codeflash_output = get_max_seq_length("a") # 12.4μs -> 13.2μs (5.67% slower)
def test_multiple_literals():
# Multiple literals should sum their lengths
codeflash_output = get_max_seq_length("abc") # 13.7μs -> 14.7μs (6.43% slower)
def test_dot_any_character():
# '.' matches any character, so length is 1
codeflash_output = get_max_seq_length(".") # 11.8μs -> 12.4μs (4.86% slower)
# Multiple dots
codeflash_output = get_max_seq_length("...") # 9.20μs -> 9.52μs (3.32% slower)
def test_character_class():
# Character class [abc] matches one character
codeflash_output = get_max_seq_length("[abc]") # 16.4μs -> 17.1μs (3.93% slower)
# Multiple character classes
codeflash_output = get_max_seq_length("[abc][123]") # 13.4μs -> 13.6μs (1.86% slower)
def test_simple_group():
# Grouping does not affect length
codeflash_output = get_max_seq_length("(a)") # 21.0μs -> 21.9μs (4.41% slower)
codeflash_output = get_max_seq_length("(ab)") # 13.6μs -> 14.4μs (5.34% slower)
def test_branch_simple():
# Branch: a|bc, should take the longer branch
codeflash_output = get_max_seq_length("a|bc") # 20.4μs -> 20.9μs (2.39% slower)
# Branch: abc|d, should take the longer branch
codeflash_output = get_max_seq_length("abc|d") # 13.7μs -> 13.8μs (0.986% slower)
def test_repetition_fixed():
# Fixed repetition: a{3} should be 3
codeflash_output = get_max_seq_length("a{3}") # 17.8μs -> 18.3μs (2.34% slower)
# Multiple fixed repetitions
codeflash_output = get_max_seq_length("a{2}b{4}") # 16.7μs -> 16.4μs (1.82% faster)
def test_repetition_range():
# Range repetition: a{2,4} should take the max, i.e., 4
codeflash_output = get_max_seq_length("a{2,4}") # 17.5μs -> 17.8μs (1.56% slower)
def test_repetition_zero():
# Zero repetition: a{0} should be 0
codeflash_output = get_max_seq_length("a{0}") # 17.0μs -> 17.6μs (3.53% slower)
def test_repetition_min_repeat():
# ? is MIN_REPEAT (0 or 1)
codeflash_output = get_max_seq_length("a?") # 15.2μs -> 15.8μs (3.71% slower)
def test_repetition_plus():
# + is 1 or more, but unbounded, so should be MAX_LEN
codeflash_output = get_max_seq_length("a+") # 14.7μs -> 14.8μs (0.717% slower)
def test_repetition_star():
# * is 0 or more, so unbounded, should be MAX_LEN
codeflash_output = get_max_seq_length("a*") # 14.4μs -> 14.9μs (3.74% slower)
def test_sequence_with_repetition():
# ab{3}c should be 1 + 3 + 1 = 5
codeflash_output = get_max_seq_length("ab{3}c") # 19.6μs -> 20.6μs (4.62% slower)
def test_nested_group_with_repetition():
# (ab){2} should be 2 * 2 = 4
codeflash_output = get_max_seq_length("(ab){2}") # 27.3μs -> 29.6μs (7.59% slower)
def test_branch_with_repetition():
# (a|bc){3} should be 3 * max(1,2) = 6
codeflash_output = get_max_seq_length("(a|bc){3}") # 37.4μs -> 38.6μs (3.22% slower)
-------------------- EDGE TEST CASES --------------------
def test_empty_string():
# Empty regex should have length 0
codeflash_output = get_max_seq_length("") # 8.98μs -> 10.2μs (12.1% slower)
def test_only_assertions():
# Only zero-width assertions (start/end)
codeflash_output = get_max_seq_length("^$") # 14.1μs -> 14.2μs (0.763% slower)
# Word boundary
codeflash_output = get_max_seq_length(r"\b") # 9.21μs -> 8.91μs (3.39% faster)
def test_nested_empty_group():
# Nested empty group
codeflash_output = get_max_seq_length("()") # 18.0μs -> 19.5μs (7.72% slower)
def test_repetition_of_zero_length():
# (?:){5} is a group of zero-length repeated 5 times, should be 0
codeflash_output = get_max_seq_length("(){5}") # 24.9μs -> 26.5μs (6.32% slower)
def test_branch_with_empty():
# Empty branch: a| should be max(1,0) = 1
codeflash_output = get_max_seq_length("a|") # 18.2μs -> 19.3μs (5.62% slower)
# Both branches empty
codeflash_output = get_max_seq_length("|") # 9.23μs -> 9.68μs (4.65% slower)
def test_nested_repetition_unbounded():
# (ab*)* is unbounded, so MAX_LEN
codeflash_output = get_max_seq_length("(ab*)*") # 28.6μs -> 30.0μs (4.49% slower)
def test_group_with_zero_width_and_literal():
# (?:^a$) should be 1 (only 'a' is counted)
codeflash_output = get_max_seq_length("(^a$)") # 23.1μs -> 24.1μs (3.94% slower)
def test_character_class_with_range():
# [a-z] is still one character
codeflash_output = get_max_seq_length("[a-z]") # 16.1μs -> 17.7μs (9.09% slower)
def test_branch_with_zero_and_nonzero():
# (|abc) should be max(0,3) = 3
codeflash_output = get_max_seq_length("(|abc)") # 29.1μs -> 31.4μs (7.60% slower)
def test_dot_with_repetition():
# .{5} should be 5
codeflash_output = get_max_seq_length(".{5}") # 18.0μs -> 19.1μs (5.72% slower)
def test_nested_branches():
# (a|(bc|d)) should be max(1, max(2,1)) = 2
codeflash_output = get_max_seq_length("a|(bc|d)") # 37.4μs -> 38.5μs (2.74% slower)
def test_repetition_of_group_with_branch():
# (a|bc){2} should be 2*max(1,2) = 4
codeflash_output = get_max_seq_length("(a|bc){2}") # 36.7μs -> 37.7μs (2.80% slower)
def test_deeply_nested_groups():
# (((a))) should be 1
codeflash_output = get_max_seq_length("(((a)))") # 28.9μs -> 29.9μs (3.50% slower)
def test_incomplete_regex_raises():
# Invalid regex should raise an error
with pytest.raises(Exception):
get_max_seq_length("(") # 18.4μs -> 19.2μs (4.06% slower)
-------------------- LARGE SCALE TEST CASES --------------------
def test_long_literal_sequence():
# 1000-character literal string
long_str = "a" * 1000
codeflash_output = get_max_seq_length(long_str) # 665μs -> 552μs (20.5% faster)
def test_large_fixed_repetition():
# a{1000} should be 1000
codeflash_output = get_max_seq_length("a{1000}") # 19.9μs -> 21.4μs (7.09% slower)
def test_large_branch():
# Branch with two long branches
regex = "a" * 500 + "|" + "b" * 1000
codeflash_output = get_max_seq_length(regex) # 999μs -> 847μs (18.0% faster)
def test_large_nested_groups():
# Nested groups: (((...((a))...)))
nested = "a"
for _ in range(999):
nested = f"({nested})"
codeflash_output = get_max_seq_length(nested) # 3.89ms -> 4.08ms (4.61% slower)
def test_large_repetition_of_group():
# (ab){500} should be 2*500 = 1000
codeflash_output = get_max_seq_length("(ab){500}") # 35.0μs -> 36.5μs (4.23% slower)
def test_large_branch_with_repetition():
# (a{1000}|b{999}) should be 1000
codeflash_output = get_max_seq_length("(a{1000}|b{999})") # 43.6μs -> 45.2μs (3.64% slower)
def test_large_branch_with_empty():
# (|a{1000}) should be 1000
codeflash_output = get_max_seq_length("(|a{1000})") # 34.1μs -> 36.8μs (7.25% slower)
def test_large_repetition_with_dot():
# .{1000} should be 1000
codeflash_output = get_max_seq_length(".{1000}") # 18.3μs -> 19.4μs (5.77% slower)
def test_large_repetition_with_character_class():
# [ab]{1000} should be 1000
codeflash_output = get_max_seq_length("[ab]{1000}") # 22.3μs -> 24.2μs (7.92% slower)
def test_large_alternation_of_literals():
# a|b|c|...|z (26 branches, each 1 char)
regex = "|".join(chr(ord('a') + i) for i in range(26))
codeflash_output = get_max_seq_length(regex) # 60.3μs -> 63.0μs (4.30% slower)
def test_large_branch_with_mixed_lengths():
# (a{1000}|b{500}|c{10}) should be 1000
regex = "(a{1000}|b{500}|c{10})"
codeflash_output = get_max_seq_length(regex) # 50.1μs -> 51.9μs (3.52% slower)
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from sglang.srt.sampling.sampling_params import get_max_seq_length
def test_get_max_seq_length():
get_max_seq_length('')
🔎 Concolic Coverage Tests and Runtime
To edit these changes
git checkout codeflash/optimize-get_max_seq_length-mhtrea3cand push.