⚡️ Speed up method CompressedTensorsConfig._is_fp8_w8a16 by 26%
#339
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
📄 26% (0.26x) speedup for
CompressedTensorsConfig._is_fp8_w8a16inpython/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py⏱️ Runtime :
1.01 milliseconds→804 microseconds(best of103runs)📝 Explanation and details
The optimized code achieves a 25% speedup by restructuring the conditional logic and eliminating expensive operations in the
_is_fp8_w8a16method.Key optimizations applied:
Early return pattern: The optimized version replaces the complex nested boolean logic with a single conditional that returns
Truewhen all conditions are met, andFalseotherwise. This eliminates the need to store intermediate boolean variables (is_symmetric_weight,is_static_weight,is_per_tensor_or_channel_weight).Strategy comparison optimization: Instead of creating a list
[QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL]and using theinoperator (which requires list allocation and linear search), the code uses directiscomparisons withorlogic:strategy is QuantizationStrategy.TENSOR or strategy is QuantizationStrategy.CHANNEL.Variable caching: The
weight_quant.strategyis cached in a local variable to avoid repeated attribute access.Why this leads to speedup:
isare faster than list membership testsPerformance characteristics based on test results:
weight_quant is Noneor wrong type) show modest improvements (1-2%)This optimization is especially beneficial for quantization validation workflows where this method may be called frequently during model initialization or inference setup.
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
import pytest
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
CompressedTensorsConfig
Minimal stubs for dependencies
class QuantizationType:
FLOAT = "float"
INT = "int"
# Could add more types as needed
class QuantizationStrategy:
TENSOR = "tensor"
CHANNEL = "channel"
BLOCK = "block"
# Could add more strategies as needed
class BaseModel:
"""A minimal stub to simulate a quantization config for weights/inputs."""
def init(self, type, symmetric, dynamic, strategy):
self.type = type
self.symmetric = symmetric
self.dynamic = dynamic
self.strategy = strategy
class QuantizationConfig:
def init(self):
self.packed_modules_mapping = {}
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
CompressedTensorsConfig
-------------------- UNIT TESTS --------------------
@pytest.fixture
def config():
# Provide a minimal config instance for all tests
return CompressedTensorsConfig(
target_scheme_map={},
ignore=[],
quant_format="",
sparsity_scheme_map={},
sparsity_ignore_list=[]
)
-------------------- BASIC TEST CASES --------------------
def test_fp8_w8a16_true_tensor_strategy(config):
# All conditions met: FLOAT, symmetric, static, TENSOR strategy
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.TENSOR
)
input_quant = None # Not used
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 2.21μs -> 2.17μs (1.52% faster)
def test_fp8_w8a16_true_channel_strategy(config):
# All conditions met: FLOAT, symmetric, static, CHANNEL strategy
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.CHANNEL
)
input_quant = None
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 1.68μs -> 1.72μs (1.98% slower)
def test_fp8_w8a16_false_weight_quant_none(config):
# weight_quant is None
weight_quant = None
input_quant = None
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 419ns -> 414ns (1.21% faster)
def test_fp8_w8a16_false_type_not_float(config):
# type is not FLOAT
weight_quant = BaseModel(
type=QuantizationType.INT,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.TENSOR
)
input_quant = None
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 808ns -> 988ns (18.2% slower)
def test_fp8_w8a16_false_not_symmetric(config):
# symmetric is False
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=False,
dynamic=False,
strategy=QuantizationStrategy.TENSOR
)
input_quant = None
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 1.55μs -> 1.12μs (38.3% faster)
def test_fp8_w8a16_false_dynamic(config):
# dynamic is True (not static)
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=True,
strategy=QuantizationStrategy.TENSOR
)
input_quant = None
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 1.62μs -> 1.19μs (36.2% faster)
def test_fp8_w8a16_false_strategy_not_tensor_or_channel(config):
# strategy is BLOCK, which is not allowed
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.BLOCK
)
input_quant = None
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 1.58μs -> 1.77μs (10.8% slower)
-------------------- EDGE TEST CASES --------------------
def test_fp8_w8a16_false_strategy_none(config):
# strategy is None
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=False,
strategy=None
)
input_quant = None
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 1.48μs -> 1.55μs (4.52% slower)
def test_fp8_w8a16_false_type_none(config):
# type is None
weight_quant = BaseModel(
type=None,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.TENSOR
)
input_quant = None
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 839ns -> 999ns (16.0% slower)
def test_fp8_w8a16_false_symmetric_none(config):
# symmetric is None (should be treated as False)
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=None,
dynamic=False,
strategy=QuantizationStrategy.TENSOR
)
input_quant = None
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 1.62μs -> 1.08μs (50.2% faster)
def test_fp8_w8a16_false_dynamic_none(config):
# dynamic is None (should be treated as not False)
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=None,
strategy=QuantizationStrategy.TENSOR
)
input_quant = None
# not weight_quant.dynamic will be True, so is_static_weight = True
# But dynamic=None is not a valid config, so let's see behavior
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 1.58μs -> 1.75μs (9.88% slower)
def test_fp8_w8a16_false_strategy_unexpected_string(config):
# strategy is an unexpected string
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=False,
strategy="unexpected"
)
input_quant = None
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 1.48μs -> 1.63μs (9.26% slower)
def test_fp8_w8a16_false_all_false(config):
# All fields are set to falsy values
weight_quant = BaseModel(
type=None,
symmetric=False,
dynamic=True,
strategy=None
)
input_quant = None
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 751ns -> 920ns (18.4% slower)
def test_fp8_w8a16_input_quant_irrelevant(config):
# input_quant is set but should not affect result
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.TENSOR
)
input_quant = BaseModel(
type=QuantizationType.INT,
symmetric=False,
dynamic=True,
strategy=QuantizationStrategy.BLOCK
)
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 1.60μs -> 1.57μs (2.23% faster)
def test_fp8_w8a16_weight_quant_extra_attributes(config):
# weight_quant has extra attributes, should be ignored
class ExtendedBaseModel(BaseModel):
def init(self, *args, **kwargs):
super().init(*args, **kwargs)
self.extra = "extra"
weight_quant = ExtendedBaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.CHANNEL
)
input_quant = None
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 1.73μs -> 1.66μs (3.73% faster)
-------------------- LARGE SCALE TEST CASES --------------------
def test_fp8_w8a16_many_true(config):
# Test many valid configs in a row (scalability, deterministic)
for i in range(100):
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.TENSOR if i % 2 == 0 else QuantizationStrategy.CHANNEL
)
input_quant = None
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 46.2μs -> 40.6μs (13.9% faster)
def test_fp8_w8a16_many_false(config):
# Test many invalid configs in a row (scalability, deterministic)
for i in range(100):
# Vary which field is invalid
if i % 4 == 0:
weight_quant = BaseModel(
type=QuantizationType.INT,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.TENSOR
)
elif i % 4 == 1:
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=False,
dynamic=False,
strategy=QuantizationStrategy.TENSOR
)
elif i % 4 == 2:
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=True,
strategy=QuantizationStrategy.TENSOR
)
else:
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.BLOCK
)
input_quant = None
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 41.7μs -> 33.3μs (25.1% faster)
def test_fp8_w8a16_large_variety(config):
# Test a large variety of combinations (combinatorial coverage)
types = [QuantizationType.FLOAT, QuantizationType.INT, None]
symmetrics = [True, False, None]
dynamics = [False, True, None]
strategies = [QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL, QuantizationStrategy.BLOCK, None]
count = 0
for t in types:
for s in symmetrics:
for d in dynamics:
for st in strategies:
weight_quant = BaseModel(
type=t,
symmetric=s,
dynamic=d,
strategy=st
)
input_quant = None
# Only one combination should be True:
# FLOAT, True, False, TENSOR or CHANNEL
expected = (
t == QuantizationType.FLOAT and
s is True and
d is False and
st in (QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL)
)
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant)
count += 1
def test_fp8_w8a16_performance_large_batch(config):
# Simulate a batch of 500 configs, all should be True
weight_quants = [
BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.TENSOR if i % 2 == 0 else QuantizationStrategy.CHANNEL
)
for i in range(500)
]
input_quants = [None] * 500
results = [config._is_fp8_w8a16(wq, iq) for wq, iq in zip(weight_quants, input_quants)]
def test_fp8_w8a16_performance_large_batch_false(config):
# Simulate a batch of 500 configs, all should be False (invalid types)
weight_quants = [
BaseModel(
type=QuantizationType.INT,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.TENSOR
)
for _ in range(500)
]
input_quants = [None] * 500
results = [config._is_fp8_w8a16(wq, iq) for wq, iq in zip(weight_quants, input_quants)]
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import pytest
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
CompressedTensorsConfig
--- Minimal stubs for dependencies (to avoid using external packages) ---
QuantizationType enum stub
class QuantizationType:
FLOAT = "float"
INT8 = "int8"
INT4 = "int4"
# Add other types as needed
QuantizationStrategy enum stub
class QuantizationStrategy:
TENSOR = "tensor"
CHANNEL = "channel"
BLOCK = "block"
# Add other strategies as needed
BaseModel stub (since pydantic.BaseModel is not used for logic here)
class BaseModel:
def init(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
QuantizationConfig stub
class QuantizationConfig:
def init(self):
self.packed_modules_mapping = dict()
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
CompressedTensorsConfig
--- Fixtures and helpers ---
@pytest.fixture
def config():
# Provide a minimal config for CompressedTensorsConfig
return CompressedTensorsConfig(
target_scheme_map={},
ignore=[],
quant_format="test",
sparsity_scheme_map={},
sparsity_ignore_list=[],
)
--- Basic Test Cases ---
def test_fp8_w8a16_true_tensor_strategy(config):
# Basic case: all conditions satisfied with TENSOR strategy
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.TENSOR,
)
input_quant = BaseModel() # Not used in function
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 2.14μs -> 2.13μs (0.281% faster)
def test_fp8_w8a16_true_channel_strategy(config):
# Basic case: all conditions satisfied with CHANNEL strategy
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.CHANNEL,
)
input_quant = BaseModel()
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 1.68μs -> 1.68μs (0.000% faster)
def test_fp8_w8a16_false_weight_quant_none(config):
# weight_quant is None
weight_quant = None
input_quant = BaseModel()
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 429ns -> 434ns (1.15% slower)
def test_fp8_w8a16_false_type_not_float(config):
# type is not FLOAT
weight_quant = BaseModel(
type=QuantizationType.INT8,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.TENSOR,
)
input_quant = BaseModel()
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 787ns -> 1.03μs (23.7% slower)
def test_fp8_w8a16_false_symmetric_false(config):
# symmetric is False
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=False,
dynamic=False,
strategy=QuantizationStrategy.TENSOR,
)
input_quant = BaseModel()
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 1.64μs -> 1.16μs (41.5% faster)
def test_fp8_w8a16_false_dynamic_true(config):
# dynamic is True (should be static)
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=True,
strategy=QuantizationStrategy.TENSOR,
)
input_quant = BaseModel()
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 1.60μs -> 1.17μs (36.7% faster)
def test_fp8_w8a16_false_strategy_block(config):
# strategy is not TENSOR or CHANNEL
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.BLOCK,
)
input_quant = BaseModel()
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 1.52μs -> 1.69μs (9.93% slower)
--- Edge Test Cases ---
def test_fp8_w8a16_strategy_case_sensitivity(config):
# strategy string is lowercased and doesn't match enum (should fail)
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=False,
strategy="tensor", # not QuantizationStrategy.TENSOR
)
input_quant = BaseModel()
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 2.26μs -> 2.12μs (6.51% faster)
def test_fp8_w8a16_input_quant_irrelevant(config):
# input_quant is None, should not affect result
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.TENSOR,
)
input_quant = None
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 1.66μs -> 1.70μs (1.89% slower)
def test_fp8_w8a16_extra_attributes(config):
# weight_quant has extra attributes, should not affect result
weight_quant = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.CHANNEL,
extra1=123,
extra2="abc",
)
input_quant = BaseModel()
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 1.55μs -> 1.58μs (2.03% slower)
def test_fp8_w8a16_unexpected_types(config):
# weight_quant fields are wrong types (should fail gracefully)
weight_quant = BaseModel(
type=123, # not a string
symmetric="yes", # not a bool
dynamic="no", # not a bool
strategy=456, # not a valid strategy
)
input_quant = BaseModel()
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 801ns -> 985ns (18.7% slower)
def test_fp8_w8a16_all_false(config):
# All relevant fields are set to "falsy" values
weight_quant = BaseModel(
type=None,
symmetric=False,
dynamic=True,
strategy=None,
)
input_quant = BaseModel()
codeflash_output = config._is_fp8_w8a16(weight_quant, input_quant) # 832ns -> 954ns (12.8% slower)
--- Large Scale Test Cases ---
def test_fp8_w8a16_large_batch_true(config):
# Test with a large batch of valid configs
weight_quants = [
BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.TENSOR if i % 2 == 0 else QuantizationStrategy.CHANNEL,
)
for i in range(500)
]
input_quant = BaseModel()
for wq in weight_quants:
codeflash_output = config._is_fp8_w8a16(wq, input_quant) # 225μs -> 198μs (13.6% faster)
def test_fp8_w8a16_large_batch_false(config):
# Test with a large batch of invalid configs (dynamic=True)
weight_quants = [
BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=True,
strategy=QuantizationStrategy.TENSOR,
)
for _ in range(500)
]
input_quant = BaseModel()
for wq in weight_quants:
codeflash_output = config._is_fp8_w8a16(wq, input_quant) # 216μs -> 144μs (49.2% faster)
def test_fp8_w8a16_mixed_large_batch(config):
# Test with a mix of valid and invalid configs
weight_quants = []
expected = []
for i in range(1000):
if i % 4 == 0:
# Valid
wq = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.CHANNEL,
)
weight_quants.append(wq)
expected.append(True)
elif i % 4 == 1:
# Invalid: type wrong
wq = BaseModel(
type=QuantizationType.INT8,
symmetric=True,
dynamic=False,
strategy=QuantizationStrategy.CHANNEL,
)
weight_quants.append(wq)
expected.append(False)
elif i % 4 == 2:
# Invalid: symmetric False
wq = BaseModel(
type=QuantizationType.FLOAT,
symmetric=False,
dynamic=False,
strategy=QuantizationStrategy.CHANNEL,
)
weight_quants.append(wq)
expected.append(False)
else:
# Invalid: dynamic True
wq = BaseModel(
type=QuantizationType.FLOAT,
symmetric=True,
dynamic=True,
strategy=QuantizationStrategy.CHANNEL,
)
weight_quants.append(wq)
expected.append(False)
input_quant = BaseModel()
for wq, exp in zip(weight_quants, expected):
codeflash_output = config._is_fp8_w8a16(wq, input_quant) # 407μs -> 318μs (28.0% faster)
To edit these changes
git checkout codeflash/optimize-CompressedTensorsConfig._is_fp8_w8a16-mhtyambjand push.