Skip to content

Commit 13698db

Browse files
authored
Improve configs - ModelConfig (#17130)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
1 parent 2c4f59a commit 13698db

36 files changed

+492
-650
lines changed

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,7 @@ class VllmRunner:
738738
- `block_size`: Set to `16` instead of `None` to reduce memory usage.
739739
- `enable_chunked_prefill`: Set to `False` instead of `None` for
740740
test reproducibility.
741-
- `enforce_eager`: Set to `False` instead of `None` to test CUDA graph.
741+
- `enforce_eager`: Set to `False` to test CUDA graph.
742742
"""
743743

744744
def __init__(

tests/engine/test_arg_utils.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import pytest
1010

11-
from vllm.config import PoolerConfig, config
11+
from vllm.config import config
1212
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
1313
get_type, is_not_builtin, is_type,
1414
literal_to_kwargs, nullable_kvs,
@@ -222,17 +222,6 @@ def test_prefix_cache_default():
222222
assert not engine_args.enable_prefix_caching
223223

224224

225-
def test_valid_pooling_config():
226-
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
227-
args = parser.parse_args([
228-
'--override-pooler-config',
229-
'{"pooling_type": "MEAN"}',
230-
])
231-
engine_args = EngineArgs.from_cli_args(args=args)
232-
assert engine_args.override_pooler_config == PoolerConfig(
233-
pooling_type="MEAN", )
234-
235-
236225
@pytest.mark.parametrize(
237226
("arg"),
238227
[

tests/quantization/test_register_quantization_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from vllm.model_executor.layers.linear import LinearBase # noqa: E501
1515
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
1616
from vllm.model_executor.layers.quantization import (
17-
get_quantization_config, register_quantization_config)
17+
QuantizationMethods, get_quantization_config, register_quantization_config)
1818
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
1919
QuantizationConfig)
2020

@@ -54,7 +54,7 @@ def __init__(self, num_bits: int = 8) -> None:
5454
"""Initialize the quantization config."""
5555
self.num_bits = num_bits
5656

57-
def get_name(self) -> str:
57+
def get_name(self) -> QuantizationMethods:
5858
"""Name of the quantization method."""
5959
return "custom_quant"
6060

tests/test_config.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def test_get_pooling_config():
185185
revision=None,
186186
)
187187

188-
pooling_config = model_config._init_pooler_config(None)
188+
pooling_config = model_config._init_pooler_config()
189189
assert pooling_config is not None
190190

191191
assert pooling_config.normalize
@@ -205,11 +205,12 @@ def test_get_pooling_config_from_args():
205205
dtype="float16",
206206
revision=None)
207207

208-
override_config = PoolerConfig(pooling_type='CLS', normalize=True)
208+
override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
209+
model_config.override_pooler_config = override_pooler_config
209210

210-
pooling_config = model_config._init_pooler_config(override_config)
211+
pooling_config = model_config._init_pooler_config()
211212
assert pooling_config is not None
212-
assert asdict(pooling_config) == asdict(override_config)
213+
assert asdict(pooling_config) == asdict(override_pooler_config)
213214

214215

215216
@pytest.mark.skipif(current_platform.is_rocm(),

vllm/config.py

Lines changed: 258 additions & 255 deletions
Large diffs are not rendered by default.

vllm/engine/arg_utils.py

Lines changed: 137 additions & 314 deletions
Large diffs are not rendered by default.

vllm/entrypoints/llm.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
1515
BeamSearchSequence, get_beam_search_score)
16-
from vllm.config import CompilationConfig
16+
from vllm.config import CompilationConfig, ModelDType, TokenizerMode
1717
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
1818
TaskOption)
1919
from vllm.engine.llm_engine import LLMEngine
@@ -32,6 +32,7 @@
3232
from vllm.lora.request import LoRARequest
3333
from vllm.model_executor.guided_decoding.guided_fields import (
3434
GuidedDecodingRequest, LLMGuidedOptions)
35+
from vllm.model_executor.layers.quantization import QuantizationMethods
3536
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
3637
PoolingRequestOutput, RequestOutput,
3738
ScoringRequestOutput)
@@ -163,20 +164,20 @@ def __init__(
163164
self,
164165
model: str,
165166
tokenizer: Optional[str] = None,
166-
tokenizer_mode: str = "auto",
167+
tokenizer_mode: TokenizerMode = "auto",
167168
skip_tokenizer_init: bool = False,
168169
trust_remote_code: bool = False,
169170
allowed_local_media_path: str = "",
170171
tensor_parallel_size: int = 1,
171-
dtype: str = "auto",
172-
quantization: Optional[str] = None,
172+
dtype: ModelDType = "auto",
173+
quantization: Optional[QuantizationMethods] = None,
173174
revision: Optional[str] = None,
174175
tokenizer_revision: Optional[str] = None,
175176
seed: Optional[int] = None,
176177
gpu_memory_utilization: float = 0.9,
177178
swap_space: float = 4,
178179
cpu_offload_gb: float = 0,
179-
enforce_eager: Optional[bool] = None,
180+
enforce_eager: bool = False,
180181
max_seq_len_to_capture: int = 8192,
181182
disable_custom_all_reduce: bool = False,
182183
disable_async_output_proc: bool = False,
@@ -189,12 +190,7 @@ def __init__(
189190
compilation_config: Optional[Union[int, dict[str, Any]]] = None,
190191
**kwargs,
191192
) -> None:
192-
'''
193-
LLM constructor.
194-
195-
Note: if enforce_eager is unset (enforce_eager is None)
196-
it defaults to False.
197-
'''
193+
"""LLM constructor."""
198194

199195
if "disable_log_stats" not in kwargs:
200196
kwargs["disable_log_stats"] = True

vllm/model_executor/layers/quantization/aqlm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from vllm import _custom_ops as ops
1414
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
15+
from vllm.model_executor.layers.quantization import QuantizationMethods
1516
from vllm.model_executor.layers.quantization.base_config import (
1617
QuantizationConfig)
1718
from vllm.model_executor.utils import set_weight_attrs
@@ -186,7 +187,7 @@ def __repr__(self) -> str:
186187
f"out_group_size={self.out_group_size})")
187188

188189
@classmethod
189-
def get_name(cls) -> str:
190+
def get_name(cls) -> QuantizationMethods:
190191
return "aqlm"
191192

192193
@classmethod

vllm/model_executor/layers/quantization/awq.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from vllm import _custom_ops as ops
88
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
99
UnquantizedLinearMethod)
10+
from vllm.model_executor.layers.quantization import QuantizationMethods
1011
from vllm.model_executor.layers.quantization.base_config import (
1112
QuantizationConfig)
1213
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
@@ -44,7 +45,7 @@ def __repr__(self) -> str:
4445
f"zero_point={self.zero_point}, "
4546
f"modules_to_not_convert={self.modules_to_not_convert})")
4647

47-
def get_name(self) -> str:
48+
def get_name(self) -> QuantizationMethods:
4849
return "awq"
4950

5051
def get_supported_act_dtypes(self) -> List[torch.dtype]:

vllm/model_executor/layers/quantization/awq_marlin.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
1414
UnquantizedLinearMethod,
1515
set_weight_attrs)
16+
from vllm.model_executor.layers.quantization import QuantizationMethods
1617
from vllm.model_executor.layers.quantization.awq import (AWQConfig,
1718
is_layer_skipped_awq)
1819
from vllm.model_executor.layers.quantization.base_config import (
@@ -73,7 +74,7 @@ def __repr__(self) -> str:
7374
f"modules_to_not_convert={self.modules_to_not_convert})")
7475

7576
@classmethod
76-
def get_name(cls) -> str:
77+
def get_name(cls) -> QuantizationMethods:
7778
return "awq_marlin"
7879

7980
@classmethod
@@ -101,8 +102,8 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
101102
modules_to_not_convert, config)
102103

103104
@classmethod
104-
def override_quantization_method(cls, hf_quant_cfg,
105-
user_quant) -> Optional[str]:
105+
def override_quantization_method(
106+
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
106107
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
107108
is_valid_user_quant = (user_quant is None or user_quant == "marlin"
108109
or user_quant == "awq_marlin")

0 commit comments

Comments
 (0)