diff --git a/Dockerfile b/Dockerfile index 0a562253c537b..2c72c6e0ec033 100644 --- a/Dockerfile +++ b/Dockerfile @@ -191,6 +191,9 @@ ADD . /vllm-workspace/ RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-dev.txt +# Copy in the v1 package for testing (it isn't distributed yet) +COPY vllm/v1 /usr/local/lib/python3.12/dist-packages/vllm/v1 + # doc requires source code # we hide them inside `test_docs/` , so that this source code # will not be imported by other tests diff --git a/pyproject.toml b/pyproject.toml index e78f5652f486b..533ab84331371 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,4 +97,5 @@ markers = [ "skip_global_cleanup", "core_model: run this model test in each PR instead of just daily", "distributed_2_gpus: run this test only in distributed tests for 2 GPUs", + "skip_v1: do not run this test with v1", ] diff --git a/tests/conftest.py b/tests/conftest.py index bdc6ffb148602..7aeb1905701f9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ from enum import Enum from typing import (Any, Callable, Dict, List, Optional, Tuple, Type, TypedDict, TypeVar, Union) +from unittest.mock import patch import numpy as np import pytest @@ -121,6 +122,23 @@ def prompts(self, prompts: _VideoAssetPrompts) -> List[str]: """Singleton instance of :class:`_VideoAssets`.""" +@pytest.fixture(params=[True, False]) +def run_with_both_engines(request): + # Automatically runs tests twice, once with V1 and once without + use_v1 = request.param + # Tests decorated with `@skip_v1` are only run without v1 + skip_v1 = request.node.get_closest_marker("skip_v1") + + if use_v1: + if skip_v1: + pytest.skip("Skipping test on vllm V1") + with patch('vllm.envs.VLLM_USE_V1', True): + yield + else: + with patch('vllm.envs.VLLM_USE_V1', False): + yield + + @pytest.fixture(autouse=True) def init_test_http_connection(): # pytest_asyncio may use a different event loop per test diff --git a/tests/entrypoints/llm/test_prompt_validation.py b/tests/entrypoints/llm/test_prompt_validation.py index 675a980ab3f3f..ee7010a238114 100644 --- a/tests/entrypoints/llm/test_prompt_validation.py +++ b/tests/entrypoints/llm/test_prompt_validation.py @@ -3,12 +3,21 @@ from vllm import LLM +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + def test_empty_prompt(): llm = LLM(model="gpt2", enforce_eager=True) with pytest.raises(ValueError, match='Prompt cannot be empty'): llm.generate([""]) +@pytest.mark.skip_v1 def test_out_of_vocab_token(): llm = LLM(model="gpt2", enforce_eager=True) with pytest.raises(ValueError, match='out of vocabulary'): diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 376b3136f0fb8..0323d3bcde226 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -88,7 +88,6 @@ def get_global_forced_attn_backend() -> Optional[_Backend]: return forced_attn_backend -@lru_cache(maxsize=None) def get_attn_backend( head_size: int, dtype: torch.dtype, @@ -98,7 +97,31 @@ def get_attn_backend( is_blocksparse: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" + # Accessing envs.* behind an @lru_cache decorator can cause the wrong + # value to be returned from the cache if the value changes between calls. + # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the + # private function. + return _cached_get_attn_backend( + head_size=head_size, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + block_size=block_size, + is_attention_free=is_attention_free, + is_blocksparse=is_blocksparse, + use_v1=envs.VLLM_USE_V1, + ) + +@lru_cache(maxsize=None) +def _cached_get_attn_backend( + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + is_attention_free: bool, + is_blocksparse: bool = False, + use_v1: bool = False, +) -> Type[AttentionBackend]: if is_blocksparse: logger.info("Using BlocksparseFlashAttention backend.") from vllm.attention.backends.blocksparse_attn import ( @@ -106,7 +129,7 @@ def get_attn_backend( return BlocksparseFlashAttentionBackend backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size, - is_attention_free) + is_attention_free, use_v1) if backend == _Backend.FLASH_ATTN: from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) @@ -157,13 +180,9 @@ def get_attn_backend( raise ValueError("Invalid attention backend.") -def which_attn_to_use( - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: Optional[str], - block_size: int, - is_attention_free: bool, -) -> _Backend: +def which_attn_to_use(head_size: int, dtype: torch.dtype, + kv_cache_dtype: Optional[str], block_size: int, + is_attention_free: bool, use_v1: bool) -> _Backend: """Returns which flash attention backend to use.""" # Default case. selected_backend = _Backend.FLASH_ATTN @@ -220,7 +239,7 @@ def which_attn_to_use( logger.info("%s is not supported in AMD GPUs.", selected_backend) return _Backend.ROCM_FLASH - if envs.VLLM_USE_V1: + if use_v1: return _Backend.FLASH_ATTN_VLLM_V1 # FlashAttn in NVIDIA GPUs. diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index a73b4c825b11c..d1d4199a85ec2 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -7,7 +7,9 @@ import zmq from ray.exceptions import RayTaskError +import vllm.envs from vllm import AsyncEngineArgs, SamplingParams +from vllm.engine.llm_engine import LLMEngine # yapf conflicts with isort for this block # yapf: disable from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, @@ -18,17 +20,11 @@ RPCStartupRequest, RPCStartupResponse, RPCUProfileRequest) # yapf: enable -from vllm.envs import VLLM_USE_V1 from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext -if VLLM_USE_V1: - from vllm.v1.engine.llm_engine import LLMEngine -else: - from vllm.engine.llm_engine import LLMEngine - logger = init_logger(__name__) POLLING_TIMEOUT_MS = 10000 @@ -118,11 +114,17 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs, load_general_plugins() engine_config = engine_args.create_engine_config() + if vllm.envs.VLLM_USE_V1: + # Lazy import: the v1 package isn't distributed + from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine + engine_class = V1LLMEngine + else: + engine_class = LLMEngine - executor_class = LLMEngine._get_executor_cls(engine_config) + executor_class = engine_class._get_executor_cls(engine_config) use_async_sockets = (engine_config.model_config.use_async_output_proc - and not VLLM_USE_V1) + and not vllm.envs.VLLM_USE_V1) return cls(ipc_path=ipc_path, use_async_sockets=use_async_sockets, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 3d62cb3598477..22362ae7d7197 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,7 +1,7 @@ import itertools import warnings from contextlib import contextmanager -from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, +from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, Union, cast, overload) from tqdm import tqdm @@ -10,6 +10,7 @@ from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, BeamSearchSequence, get_beam_search_score) from vllm.engine.arg_utils import EngineArgs, TaskOption +from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, apply_hf_chat_template, apply_mistral_chat_template, @@ -31,11 +32,6 @@ from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of -if envs.VLLM_USE_V1: - from vllm.v1.engine.llm_engine import LLMEngine # type: ignore -else: - from vllm.engine.llm_engine import LLMEngine # type: ignore - logger = init_logger(__name__) @@ -200,10 +196,21 @@ def __init__( pooling_returned_token_ids=pooling_returned_token_ids, **kwargs, ) - self.llm_engine = LLMEngine.from_engine_args( + # Logic to switch between engines is done at runtime instead of import + # to avoid import order issues + self.engine_class = self.get_engine_class() + self.llm_engine = self.engine_class.from_engine_args( engine_args, usage_context=UsageContext.LLM_CLASS) self.request_counter = Counter() + @staticmethod + def get_engine_class() -> Type[LLMEngine]: + if envs.VLLM_USE_V1: + # Lazy import: the v1 package isn't distributed + from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine + return V1LLMEngine # type: ignore + return LLMEngine + def get_tokenizer(self) -> AnyTokenizer: return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer @@ -388,7 +395,7 @@ def generate( priority=priority) outputs = self._run_engine(use_tqdm=use_tqdm) - return LLMEngine.validate_outputs(outputs, RequestOutput) + return self.engine_class.validate_outputs(outputs, RequestOutput) def beam_search( self, @@ -763,7 +770,8 @@ def encode( ) outputs = self._run_engine(use_tqdm=use_tqdm) - return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput) + return self.engine_class.validate_outputs(outputs, + EmbeddingRequestOutput) def start_profile(self) -> None: self.llm_engine.start_profile() diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index f86c6ec362ebe..c10efefea5471 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -30,6 +30,15 @@ else: flashinfer_top_k_top_p_sampling = None + +def get_sampler() -> torch.nn.Module: + if envs.VLLM_USE_V1: + # Lazy import: the v1 package isn't distributed + from vllm.v1.sample.sampler import Sampler as V1Sampler + return V1Sampler() + return Sampler() + + # (num_token_ids, num_parent_ids) per sequence group. SampleResultType = List[Tuple[List[int], List[int]]] diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index fd29d4ccc59d8..11e902d22874f 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.deepspeedfp import ( DeepSpeedFPConfig, DeepSpeedFPParameter) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -436,7 +436,7 @@ def __init__(self, self.unpadded_vocab_size = config.vocab_size self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index f2cfdf8ffd30a..8720931d75c3c 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -38,7 +38,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -353,7 +353,7 @@ def __init__( if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index cbdacf779b089..f6357667eecb7 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -34,7 +34,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -842,7 +842,7 @@ def __init__(self, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() def forward( self, diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index c3b3cc8a4ddb6..e4fa4c6e4b588 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -13,7 +13,7 @@ token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors, SequenceData @@ -520,7 +520,7 @@ def sampler(self): if hasattr(self.language_model, "sampler"): return self.language_model.sampler - return Sampler() + return get_sampler() def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 77ab7de6165fb..d3ffb942a38aa 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -34,7 +34,7 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -299,7 +299,7 @@ def __init__( self.config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index aaf559ca386cc..78744d04779b7 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -21,7 +21,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -936,7 +936,7 @@ def __init__( logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index ca90d10e9f9fb..2acb962875510 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -25,7 +25,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -618,7 +618,7 @@ def __init__( self.transformer.embedding.weight) self.lm_head = self.transformer.output_layer self.logits_processor = LogitsProcessor(config.padded_vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() def forward(self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 348e6d20f3297..fd6c3537c8f65 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -38,7 +38,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -356,7 +356,7 @@ def __init__( cache_config, quant_config, lora_config=lora_config) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index aae7ab7370b74..2024d18d26463 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -15,7 +15,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -374,7 +374,7 @@ def __init__( ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 5b4db8f258711..cdcf15a0c1d46 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -42,7 +42,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -400,7 +400,7 @@ def __init__( if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index d4ad0c6b5c99e..1843a89850b58 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -43,7 +43,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -497,7 +497,7 @@ def __init__( config.hidden_size, quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 22f194c776b69..386f10f96352c 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -43,7 +43,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -479,7 +479,7 @@ def __init__( self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = Sampler() + self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 36c85e37783ab..eced1760c28dc 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -39,7 +39,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -429,7 +429,7 @@ def __init__( quant_config=quant_config, ) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 6840ac8b9e303..184bee5f65671 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -10,7 +10,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.bart import (BartDecoder, BartEncoder, BartParallelLMHead, @@ -112,7 +112,7 @@ def __init__(self, self.logits_processor = LogitsProcessor(self.vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() def forward( self, diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 57b2b43c82f89..9d6ddd741f666 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -34,7 +34,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -397,7 +397,7 @@ def __init__( quant_config, prefix=maybe_prefix(prefix, "model")) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 693f32160a289..acfdf97cb2e08 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -34,7 +34,7 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -417,7 +417,7 @@ def __init__( self.model = Gemma2Model(config, cache_config, quant_config) self.logits_processor = LogitsProcessor( config.vocab_size, soft_cap=config.final_logit_softcapping) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 3330d84021368..cbbbba68b85e9 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -34,7 +34,7 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -261,7 +261,7 @@ def __init__( self.lm_head = ParallelLMHead(self.config.vocab_size, self.config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 24c79a8855475..34f7cf8ee62f0 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -34,7 +34,7 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -287,7 +287,7 @@ def __init__( self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 0451d16b6c738..c0d0b7b39e693 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -33,7 +33,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -248,7 +248,7 @@ def __init__( quant_config=quant_config, ) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 1bccef7a5f173..3690addaa544a 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -33,7 +33,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -262,7 +262,7 @@ def __init__( if self.config.tie_word_embeddings: self.embed_out.weight = self.gpt_neox.embed_in.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.gpt_neox.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index c968817747754..92037a133a097 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -43,7 +43,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -412,7 +412,7 @@ def __init__( self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, scale=logit_scale) - self.sampler = Sampler() + self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 5307bb21adb96..bf9e9d99ccc5e 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -40,7 +40,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -372,7 +372,7 @@ def __init__( scale=1 / self.config.logits_scaling) - self.sampler = Sampler() + self.sampler = get_sampler() def forward( self, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 313d98b649b48..c8423fd60b06d 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -21,7 +21,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -339,7 +339,7 @@ def __init__( if self.config.tie_word_embeddings: self.output.weight = self.model.tok_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 1c1fde5b30983..4ab7d702cfc47 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -21,7 +21,7 @@ token_inputs) from vllm.model_executor.layers.quantization import (AWQConfig, QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.models.intern_vit import (InternVisionModel, InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -467,7 +467,7 @@ def sampler(self): if hasattr(self.language_model, "sampler"): return self.language_model.sampler - return Sampler() + return get_sampler() def _init_vision_model( self, diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index b947f24a693b5..f33de900cb364 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -35,7 +35,7 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -309,7 +309,7 @@ def __init__( config.mup_width_scale) self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size, scale=self.output_logits_scale) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index fddd39fb8c85b..dc3a7d9f4f512 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -23,7 +23,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -555,7 +555,7 @@ def __init__( self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() def forward(self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 8a9e5203972be..13c245229ae52 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -43,7 +43,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -539,7 +539,7 @@ def __init__( self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = Sampler() + self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 27055e7ced865..f11382ba3df52 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -13,7 +13,7 @@ from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors @@ -301,7 +301,7 @@ def sampler(self): if hasattr(self.language_model, "sampler"): return self.language_model.sampler - return Sampler() + return get_sampler() def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index e8540d85ff565..ee57c4ec7deae 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -15,7 +15,7 @@ from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -326,7 +326,7 @@ def sampler(self): if hasattr(self.language_model, "sampler"): return self.language_model.sampler - return Sampler() + return get_sampler() def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: expected_dims = (2, ) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index b8051d5fc6ae2..23777fad1ca13 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -15,7 +15,7 @@ token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -280,7 +280,7 @@ def sampler(self): if hasattr(self.language_model, "sampler"): return self.language_model.sampler - return Sampler() + return get_sampler() def _validate_video_pixel_values( self, data: Union[torch.Tensor, List[torch.Tensor]] diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index a0cf208a65f36..f3fcd0f46d5f8 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -19,7 +19,7 @@ token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import (cached_get_tokenizer, @@ -436,7 +436,7 @@ def sampler(self): if hasattr(self.language_model, "sampler"): return self.language_model.sampler - return Sampler() + return get_sampler() def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: expected_dims = (2, ) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 9f4f391a6682e..f198a4679f295 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -20,7 +20,7 @@ selective_scan_fn, selective_state_update) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -346,7 +346,7 @@ def __init__( self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() def forward(self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 03fb036020f2f..61b2f6e4dd580 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -44,7 +44,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -497,7 +497,7 @@ def __init__( self.logits_processor = LogitsProcessor(unpadded_vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 4917c33136069..c6694c4fa76d5 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -42,7 +42,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, get_2d_sincos_pos_embed) -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -412,7 +412,7 @@ def __init__( quant_config=quant_config, prefix="llm.lm_head") self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.llm.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 1514243ad59c9..73efa8952255e 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -39,7 +39,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -367,7 +367,7 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 63e2c60a84271..2fac7387360cc 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -41,7 +41,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -367,7 +367,7 @@ def __init__( if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 5cf5272cae878..4d51883e455ae 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -45,7 +45,7 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -1100,7 +1100,7 @@ def __init__(self, ) self.logits_processor = LogitsProcessor(config.output_hidden_states, config.text_config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() def compute_logits( self, diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index 42ccd01298169..8044f2aa5344d 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -6,7 +6,7 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -137,7 +137,7 @@ def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None: self.config = config self.logits_processor = LogitsProcessor(config.vocab_size, config.vocab_size, 1.0) - self.sampler = Sampler() + self.sampler = get_sampler() def generate_proposals( self, diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 3c34227767e05..ca0a763894af0 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -32,7 +32,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -1049,7 +1049,7 @@ def __init__( self.logits_processor = LogitsProcessor(config.embedding_size or config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index ee802030a5ef3..8c9e3d7d0fdcc 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -17,7 +17,7 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -282,7 +282,7 @@ def __init__( self.transformer = MPTModel(config, cache_config, quant_config) self.lm_head = self.transformer.wte self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 72a09129fed63..dc393cb416714 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -37,7 +37,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -442,7 +442,7 @@ def __init__( self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = Sampler() + self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 90ab8abcb84b4..ca5c5ad43cee9 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -38,7 +38,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -310,7 +310,7 @@ def __init__(self, quant_config=quant_config, ) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 374cbb8df1fcd..de30b5270e7e8 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -28,7 +28,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -323,7 +323,7 @@ def __init__( config.hidden_size, quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 10cca8b56268a..6c849f0bd8f32 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -34,7 +34,7 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -366,7 +366,7 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.word_embed_proj_dim) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 055407587c598..465f693f123a5 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -21,7 +21,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -285,7 +285,7 @@ def __init__( if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index fc9ef15db26c0..e43984101491d 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -37,7 +37,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -280,7 +280,7 @@ def __init__(self, config.hidden_size, bias=False) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 497eae4e8905b..b676b8b11befc 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -52,7 +52,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -303,7 +303,7 @@ def __init__( bias=True, quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 3a7afc606bb9a..92bf0e61448e5 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -15,7 +15,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -386,7 +386,7 @@ def __init__( if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 4928e447d5b9e..d10ce7b8f0886 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -33,7 +33,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.models.clip import CLIPVisionModel @@ -571,7 +571,7 @@ def sampler(self): if hasattr(self.language_model, "sampler"): return self.language_model.sampler - return Sampler() + return get_sampler() def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: expected_dims = (2, ) diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index bb8a9327b4ac8..bbee0567a62db 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -39,7 +39,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -563,7 +563,7 @@ def __init__( ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 6b53bf5660096..b2cc5144656fe 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -22,7 +22,7 @@ from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -181,7 +181,7 @@ def sampler(self): if hasattr(self.language_model, "sampler"): return self.language_model.sampler - return Sampler() + return get_sampler() def forward( self, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 61665768eacf5..eac2818514820 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -37,7 +37,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -885,7 +885,7 @@ def __init__( if self.config.tie_word_embeddings: self.lm_head.weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index db7556b3b5f4b..709af897ac143 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -40,7 +40,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -448,7 +448,7 @@ def __init__( prefix, "lm_head")) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 3d049eeb920b7..ed617abb30861 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -37,7 +37,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) @@ -289,7 +289,7 @@ def __init__(self, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.text_config.vocab_size, logit_scale) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index dac85e35d369d..90f9b3fdad7d3 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -45,7 +45,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -394,7 +394,7 @@ def __init__( if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 1e12c2332b65e..4bccf065b06e7 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -53,7 +53,7 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen2 import Qwen2Model @@ -967,7 +967,7 @@ def __init__(self, self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index e3e7ccb5cf179..aedc8d24a7b03 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -43,7 +43,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -450,7 +450,7 @@ def __init__( self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = Sampler() + self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 083a48588d01a..2ced87ddacc79 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -35,7 +35,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -262,7 +262,7 @@ def __init__( if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 8f0644bca3e2e..027493d472e35 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -35,7 +35,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -271,7 +271,7 @@ def __init__(self, ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index f08e4aa355086..211bb3123029f 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -23,7 +23,7 @@ from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -375,7 +375,7 @@ def sampler(self): if hasattr(self.language_model, "sampler"): return self.language_model.sampler - return Sampler() + return get_sampler() def _audio_features_to_embeddings( self, input_features: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 036789642d3c4..0ed277a912bf4 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -38,7 +38,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -335,7 +335,7 @@ def __init__( if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index b2af89ebf854a..562afe276db68 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -134,7 +134,7 @@ def forward( assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") - output = torch.ops.vllm.unified_flash_attention( + output = torch.ops.vllm.unified_v1_flash_attention( query, key, value, @@ -153,7 +153,7 @@ def forward( return output -def unified_flash_attention( +def unified_v1_flash_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -216,7 +216,7 @@ def unified_flash_attention( return output.view(num_tokens, hidden_size) -def unified_flash_attention_fake( +def unified_v1_flash_attention_fake( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -236,8 +236,8 @@ def unified_flash_attention_fake( direct_register_custom_op( - op_name="unified_flash_attention", - op_func=unified_flash_attention, + op_name="unified_v1_flash_attention", + op_func=unified_v1_flash_attention, mutates_args=["kv_cache"], - fake_impl=unified_flash_attention_fake, + fake_impl=unified_v1_flash_attention_fake, ) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index febabd2f31036..5e0ebad8e040e 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -154,6 +154,12 @@ def __init__( # GPU and CPU blocks, which are profiled in the distributed executor. self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) + def __del__(self): + # Small hack- implicit clean up of resources on garbage collect + # TODO: this should probably be explicitly invoked when we're done with + # the engine + self.terminate_detokenizer() + def _initialize_kv_caches(self) -> None: num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks( ) diff --git a/vllm/v1/tokenizer/detokenizer.py b/vllm/v1/tokenizer/detokenizer.py index 4bbcf4717981e..e485fcc3522d9 100644 --- a/vllm/v1/tokenizer/detokenizer.py +++ b/vllm/v1/tokenizer/detokenizer.py @@ -73,7 +73,7 @@ def recv(self) -> Optional[DetokenizerOutputs]: return None def terminate(self) -> None: - self.push_socket.send(b"", flags=zmq.NOBLOCK) + self.detokenizer.kill() self.detokenizer.join() @@ -108,10 +108,10 @@ def run(self): self.push_socket.bind(f"tcp://*:{self.push_port}") while True: + if self.pull_socket.poll(timeout=1000) == 0: + # Nothing to read + continue message = self.pull_socket.recv() - if message == b"": - # Terminate signal. - break inputs = self.msgpack_decoder.decode(message) for req_id in inputs.free_req_ids: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e84645ac7a4ae..0cf72e8858817 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,6 +1,5 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Set -from unittest.mock import patch import numpy as np import torch @@ -21,7 +20,6 @@ FlashAttentionMetadata) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.sample.sampler import Sampler if TYPE_CHECKING: from vllm.v1.core.scheduler import SchedulerOutput @@ -371,14 +369,13 @@ def execute_model( def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 - with patch("vllm.model_executor.layers.sampler.Sampler", Sampler): - self.model = get_model(model_config=self.model_config, - device_config=self.device_config, - load_config=self.load_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config) + self.model = get_model(model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB",