⚡️ Speed up function _assert_all_same by 90%
#336
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
📄 90% (0.90x) speedup for `_assert_all_same` in `python/sglang/srt/operations_strategy.py`
⏱️ Runtime: 422 microseconds → 222 microseconds (best of 214 runs)
📝 Explanation and details
The optimization replaces Python's built-in `all()` function with an explicit `for` loop, achieving an 89% speedup by eliminating the overhead of generator expressions and function calls.
Key Changes:
- The original `all(item == items[0] for item in items)` creates a generator that evaluates each comparison lazily.
- The optimized version stores `items[0]` in a variable to avoid repeated indexing.
Why This is Faster:
- `all()` is a built-in function call, while the `for` loop executes at bytecode level.
Performance Impact Analysis:
Based on the function reference, `_assert_all_same` is called in `OperationsStrategy.concat()` for validating configuration consistency across multiple strategy objects. The 89% speedup is particularly beneficial because:
Test Case Performance:
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
from typing import List
imports
import pytest # used for our unit tests
from sglang.srt.operations_strategy import _assert_all_same
unit tests
1. Basic Test Cases
def test_all_same_integers():
    # All integers are the same
    codeflash_output = _assert_all_same([1, 1, 1])  # 1.39μs -> 662ns (110% faster)


def test_all_same_strings():
    # All strings are the same
    codeflash_output = _assert_all_same(["foo", "foo", "foo"])  # 1.37μs -> 645ns (112% faster)


def test_all_same_floats():
    # All floats are the same
    codeflash_output = _assert_all_same([3.14, 3.14, 3.14])  # 1.36μs -> 652ns (108% faster)


def test_single_element_list():
    # Single element list should return that element
    codeflash_output = _assert_all_same([42])  # 1.25μs -> 552ns (127% faster)


def test_two_element_list_same():
    # Two elements, both the same
    codeflash_output = _assert_all_same(["bar", "bar"])  # 1.26μs -> 615ns (105% faster)
2. Edge Test Cases
def test_different_integers_raises():
    # Not all integers are the same
    with pytest.raises(AssertionError):
        _assert_all_same([1, 2, 1])  # 2.26μs -> 1.28μs (75.9% faster)


def test_different_strings_raises():
    # Not all strings are the same
    with pytest.raises(AssertionError):
        _assert_all_same(["foo", "bar", "foo"])  # 1.97μs -> 1.13μs (74.0% faster)


def test_different_types_raises():
    # Different types in the list
    with pytest.raises(AssertionError):
        _assert_all_same([1, "1", 1.0])  # 1.87μs -> 1.06μs (76.8% faster)


def test_none_values_same():
    # All values are None
    codeflash_output = _assert_all_same([None, None, None])  # 1.47μs -> 749ns (96.4% faster)


def test_none_and_non_none_raises():
    # Mixed None and non-None
    with pytest.raises(AssertionError):
        _assert_all_same([None, 0, None])  # 1.85μs -> 1.09μs (69.9% faster)


def test_all_same_tuple_objects():
    # All tuple objects are the same
    t = (1, 2)
    codeflash_output = _assert_all_same([t, t, t])  # 1.42μs -> 714ns (98.5% faster)


def test_list_of_lists_same():
    # All lists are the same object/value
    shared = [1, 2]
    codeflash_output = _assert_all_same([shared, shared, shared])  # 1.43μs -> 646ns (121% faster)


def test_list_of_lists_equal_but_not_same_object():
    # Lists with equal values but not same object
    codeflash_output = _assert_all_same([[1, 2], [1, 2], [1, 2]])  # 1.35μs -> 671ns (101% faster)


def test_list_with_falsey_values():
    # All values are False
    codeflash_output = _assert_all_same([False, False, False])  # 1.38μs -> 648ns (114% faster)


def test_list_with_truthy_values():
    # All values are True
    codeflash_output = _assert_all_same([True, True, True])  # 1.28μs -> 619ns (106% faster)


def test_list_with_mixed_truthy_falsey_raises():
    # Mixed True and False
    with pytest.raises(AssertionError):
        _assert_all_same([True, False, True])  # 1.84μs -> 1.13μs (63.4% faster)
def test_list_with_custom_objects_equal():
    # Custom objects with __eq__ defined so that any two Foo instances
    # compare equal, even though they are distinct objects.
    class Foo:
        def __eq__(self, other):
            return isinstance(other, Foo)

    a = Foo()
    b = Foo()
    c = Foo()
    # Without this call the test exercised nothing; distinct-but-equal
    # instances must be accepted without raising.
    codeflash_output = _assert_all_same([a, b, c])
def test_list_with_custom_objects_not_equal():
    # Custom objects whose __eq__ always returns False, so no element
    # (not even the first one compared with itself) is considered equal.
    class Foo:
        def __eq__(self, other):
            return False

    a = Foo()
    b = Foo()
    c = Foo()
    with pytest.raises(AssertionError):
        _assert_all_same([a, b, c])  # 2.17μs -> 1.35μs (61.1% faster)
def test_list_with_mutable_objects_equal():
    # Different list objects with same contents
    codeflash_output = _assert_all_same([[1], [1], [1]])  # 1.40μs -> 732ns (91.8% faster)


def test_list_with_mutable_objects_not_equal():
    # Different list objects with different contents
    with pytest.raises(AssertionError):
        _assert_all_same([[1], [2], [1]])  # 1.83μs -> 1.15μs (59.5% faster)
3. Large Scale Test Cases
def test_large_list_all_same():
    # Large list of same values
    large_list = [999] * 1000
    codeflash_output = _assert_all_same(large_list)  # 29.4μs -> 14.4μs (104% faster)


def test_large_list_all_same_strings():
    # Large list of same string
    large_list = ["big"] * 1000
    codeflash_output = _assert_all_same(large_list)  # 30.2μs -> 15.5μs (94.3% faster)


def test_large_list_one_difference_raises():
    # Large list where one value is different
    large_list = [10] * 999 + [11]
    with pytest.raises(AssertionError):
        _assert_all_same(large_list)  # 29.9μs -> 15.0μs (99.0% faster)


def test_large_list_all_same_none():
    # Large list of None
    large_list = [None] * 1000
    codeflash_output = _assert_all_same(large_list)  # 29.4μs -> 14.7μs (99.6% faster)
def test_large_list_with_custom_objects_equal():
    # Large list of custom objects that compare equal through __eq__
    class Foo:
        def __eq__(self, other):
            return isinstance(other, Foo)

    large_list = [Foo() for _ in range(1000)]
    # Without this call the test exercised nothing; 1000 distinct-but-equal
    # instances must be accepted without raising.
    codeflash_output = _assert_all_same(large_list)
def test_large_list_with_custom_objects_not_equal():
    # Large list of custom objects whose __eq__ always returns False
    class Foo:
        def __eq__(self, other):
            return False

    large_list = [Foo() for _ in range(1000)]
    with pytest.raises(AssertionError):
        _assert_all_same(large_list)  # 2.22μs -> 1.42μs (56.1% faster)
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from typing import List
imports
import pytest # used for our unit tests
from sglang.srt.operations_strategy import _assert_all_same
unit tests
-------------------------------
BASIC TEST CASES
-------------------------------
def test_all_integers_same():
    # All elements are the same integer
    codeflash_output = _assert_all_same([1, 1, 1])  # 1.51μs -> 733ns (105% faster)


def test_all_strings_same():
    # All elements are the same string
    codeflash_output = _assert_all_same(["a", "a", "a"])  # 1.39μs -> 676ns (106% faster)


def test_single_element():
    # List with a single element should return that element
    codeflash_output = _assert_all_same([42])  # 1.26μs -> 569ns (121% faster)


def test_all_floats_same():
    # All elements are the same float
    codeflash_output = _assert_all_same([2.5, 2.5, 2.5])  # 1.36μs -> 643ns (111% faster)


def test_all_booleans_same():
    # All elements are the same boolean
    codeflash_output = _assert_all_same([True, True, True])  # 1.33μs -> 651ns (105% faster)


def test_all_tuples_same():
    # All elements are the same tuple
    codeflash_output = _assert_all_same([(1, 2), (1, 2), (1, 2)])  # 1.32μs -> 655ns (102% faster)
-------------------------------
EDGE TEST CASES
-------------------------------
def test_different_integers_raises():
    # Not all elements are the same integer
    with pytest.raises(AssertionError):
        _assert_all_same([1, 2, 1])  # 2.28μs -> 1.28μs (78.2% faster)


def test_different_types_raises():
    # Elements of different types but same value
    with pytest.raises(AssertionError):
        _assert_all_same([1, "1", 1.0])  # 1.96μs -> 1.18μs (66.3% faster)


def test_first_element_different():
    # First element is different
    with pytest.raises(AssertionError):
        _assert_all_same([2, 1, 1])  # 1.80μs -> 1.10μs (64.2% faster)


def test_last_element_different():
    # Last element is different
    with pytest.raises(AssertionError):
        _assert_all_same([1, 1, 2])  # 1.70μs -> 1.12μs (52.0% faster)


def test_middle_element_different():
    # Middle element is different
    with pytest.raises(AssertionError):
        _assert_all_same([1, 2, 1])  # 1.69μs -> 1.00μs (68.3% faster)


def test_all_none_same():
    # All elements are None
    codeflash_output = _assert_all_same([None, None, None])  # 1.37μs -> 689ns (98.7% faster)


def test_none_and_value_raises():
    # List contains None and a value
    with pytest.raises(AssertionError):
        _assert_all_same([None, 0, None])  # 1.77μs -> 1.07μs (65.7% faster)


def test_all_empty_lists_same():
    # All elements are empty lists
    codeflash_output = _assert_all_same([[], [], []])  # 1.37μs -> 724ns (89.5% faster)


def test_empty_and_nonempty_list_raises():
    # List contains empty and non-empty lists
    with pytest.raises(AssertionError):
        _assert_all_same([[], [1], []])  # 1.80μs -> 1.04μs (73.7% faster)


def test_all_dicts_same():
    # All elements are the same dict
    d = {"a": 1}
    codeflash_output = _assert_all_same([d, d, d])  # 1.66μs -> 939ns (76.4% faster)


def test_dicts_equal_but_different_objects():
    # Dicts with same content but different objects
    codeflash_output = _assert_all_same([{"a": 1}, {"a": 1}, {"a": 1}])  # 1.50μs -> 757ns (98.2% faster)


def test_dicts_different_content_raises():
    # Dicts with different content
    with pytest.raises(AssertionError):
        _assert_all_same([{"a": 1}, {"a": 2}, {"a": 1}])  # 2.05μs -> 1.28μs (59.7% faster)


def test_all_sets_same():
    # All elements are the same set
    codeflash_output = _assert_all_same([{1, 2}, {1, 2}, {1, 2}])  # 1.79μs -> 1.02μs (74.9% faster)


def test_sets_different_order():
    # Sets with same elements in different order
    codeflash_output = _assert_all_same([{2, 1}, {1, 2}, {2, 1}])  # 1.40μs -> 812ns (72.0% faster)


def test_sets_different_content_raises():
    # Sets with different content
    with pytest.raises(AssertionError):
        _assert_all_same([{1, 2}, {2, 3}, {1, 2}])  # 1.98μs -> 1.29μs (53.1% faster)


def test_nested_lists_same():
    # All elements are the same nested list
    codeflash_output = _assert_all_same([[1, [2]], [1, [2]], [1, [2]]])  # 1.41μs -> 815ns (73.5% faster)


def test_nested_lists_different_raises():
    # Nested lists with different inner values
    with pytest.raises(AssertionError):
        _assert_all_same([[1, [2]], [1, [3]], [1, [2]]])  # 1.97μs -> 1.20μs (63.5% faster)
-------------------------------
LARGE SCALE TEST CASES
-------------------------------
def test_large_list_all_same():
    # Large list of same value
    large_list = [7] * 1000
    codeflash_output = _assert_all_same(large_list)  # 29.3μs -> 14.5μs (102% faster)


def test_large_list_all_strings_same():
    # Large list of same string
    large_list = ["test"] * 999
    codeflash_output = _assert_all_same(large_list)  # 30.1μs -> 15.5μs (94.8% faster)


def test_large_list_one_different_raises():
    # Large list, one element different
    large_list = [0] * 999
    large_list[500] = 1
    with pytest.raises(AssertionError):
        _assert_all_same(large_list)  # 16.0μs -> 8.09μs (97.6% faster)


def test_large_list_first_element_different_raises():
    # Large list, first element different
    large_list = [1] + [0] * 999
    with pytest.raises(AssertionError):
        _assert_all_same(large_list)  # 1.79μs -> 1.05μs (70.4% faster)


def test_large_list_last_element_different_raises():
    # Large list, last element different
    large_list = [0] * 999 + [1]
    with pytest.raises(AssertionError):
        _assert_all_same(large_list)  # 29.6μs -> 14.9μs (98.7% faster)


def test_large_list_all_none():
    # Large list of None
    large_list = [None] * 1000
    codeflash_output = _assert_all_same(large_list)  # 29.4μs -> 14.7μs (100% faster)


def test_large_list_all_empty_lists():
    # Large list of empty lists
    large_list = [[] for _ in range(1000)]
    codeflash_output = _assert_all_same(large_list)  # 31.0μs -> 16.3μs (89.4% faster)


def test_large_list_of_tuples():
    # Large list of same tuple
    large_list = [(1, 2)] * 1000
    codeflash_output = _assert_all_same(large_list)  # 29.8μs -> 17.3μs (72.7% faster)


def test_large_list_of_dicts_equal_but_different_objects():
    # Large list of dicts with same content but different objects
    large_list = [{"a": 1} for _ in range(1000)]
    codeflash_output = _assert_all_same(large_list)  # 36.0μs -> 21.5μs (67.3% faster)
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
To edit these changes, run `git checkout codeflash/optimize-_assert_all_same-mhtwf5l2` and push.