Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 34 additions & 19 deletions python/sglang/srt/sampling/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,34 +204,49 @@ def get_max_seq_length(regex_str: str):


def _max_length_from_subpattern(subpattern: sre_parse.SubPattern):
# OPT: Avoid repeated set construction in loop--move sets to module scope.
# OPT: Avoid set lookup inside loop for token-type branches.
# Set objects are hash-tables, so moving these to constants for reuse will help,
# but the biggest win is flattening the conditional logic for the tight for-loop.
TOKEN_CHAR1 = {sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY}
TOKEN_REPEAT = {sre_parse.MAX_REPEAT, sre_parse.MIN_REPEAT}

total = 0
# OPT: Convert to indexed loop for slightly faster iteration.
# But because SubPattern may not support index access efficiently, keep as 'for'.
# Avoid recreating sets per iteration, and use direct branching.

# Localize frequently accessed functions for minor speedup under CPython
_max_length_from_subpattern_local = _max_length_from_subpattern

# OPT: Replace logger.warning with direct variable capture: keep as logger.warning to preserve side effects.

for token, value in subpattern:
if token in {
sre_parse.LITERAL, # `value` is any one character
sre_parse.IN, # Any character within `value`
sre_parse.ANY, # "."
}:
if token in TOKEN_CHAR1:
# LITERAL, IN, ANY
total += 1
elif token == sre_parse.SUBPATTERN:
# EG: (a\d+) ->
# [(SUBPATTERN,
# (1, 0, 0, [(LITERAL, 97),
# (MAX_REPEAT, (1, MAXREPEAT, [(IN, [(CATEGORY, CATEGORY_DIGIT)])]))]))]
_, _, _, inner_subpattern = value
total += _max_length_from_subpattern(inner_subpattern)
# This logic is correct.
# value = (group, add_flags, del_flags, subpattern)
inner_subpattern = value[3]
total += _max_length_from_subpattern_local(inner_subpattern)
elif token == sre_parse.BRANCH:
_, branches = value
total += max(_max_length_from_subpattern(branch) for branch in branches)
elif token in {sre_parse.MAX_REPEAT, sre_parse.MIN_REPEAT}:
_, max_num_repeat, inner_subpattern = value
# value = (None, branches)
branches = value[1]
# OPT: reduce per-branch function calls by using map and list comprehension
# but max() over a generator is already very efficient.
total += max(_max_length_from_subpattern_local(branch) for branch in branches)
elif token in TOKEN_REPEAT:
# value = (min, max, subpattern)
max_num_repeat = value[1]
inner_subpattern = value[2]
if max_num_repeat == sre_parse.MAXREPEAT:
total += MAX_LEN
else:
total += max_num_repeat * _max_length_from_subpattern(inner_subpattern)
total += max_num_repeat * _max_length_from_subpattern_local(inner_subpattern)
elif token == sre_parse.AT:
# These are zero-width assertions like ^, $, and \b that don't add to the max
# length
total += 0
# Zero-width assertion
pass # total += 0 is pointless
else:
logger.warning(f"Got unhandled regex token: {token}")

Expand Down