diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index c47547cb0ea7..8d8a9e0f5080 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -6,8 +6,6 @@ V1 is now enabled by default for all supported use cases, and we will gradually enable it for every use case we plan to support. Please share any feedback on [GitHub](https://github.com/vllm-project/vllm) or in the [vLLM Slack](https://inviter.co/vllm-slack). -To disable V1, please set the environment variable as: `VLLM_USE_V1=0`, and send us a GitHub issue sharing the reason! - ## Why vLLM V1? vLLM V0 successfully supported a wide range of models and hardware, but as new features were developed independently, the system grew increasingly complex. This complexity made it harder to integrate new capabilities and introduced technical debt, revealing the need for a more streamlined and unified design. diff --git a/tests/conftest.py b/tests/conftest.py index 41fda04a6c92..5e127e4e939e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -154,26 +154,6 @@ def prompts(self, prompts: AudioAssetPrompts) -> list[str]: """Singleton instance of {class}`AudioTestAssets`.""" -@pytest.fixture(scope="function", autouse=True) -def cleanup_VLLM_USE_V1(monkeypatch): - """ - The V1 oracle sets "VLLM_USE_V1" during loading. This means - that each invocation of a test change the env variable. - - If we touch "VLLM_USE_V1" with monkeypatch, then any changes - made during the test run by vLLM will be cleaned up. - - This fixture is used by every test. - """ - - # If VLLM_USE_V1 is not set, set then delete. This will - # cause monkeypatch to clean up VLLM_USE_V1 upon exit - # if VLLM modifies the value of envs.VLLM_USE_V1. - if "VLLM_USE_V1" not in os.environ: - monkeypatch.setenv("VLLM_USE_V1", "") - monkeypatch.delenv("VLLM_USE_V1") - - @pytest.fixture(autouse=True) def init_test_http_connection(): # pytest_asyncio may use a different event loop per test diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index c9605ea1b07c..25af55baa91f 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -424,15 +424,12 @@ async def test_customize_loggers(monkeypatch): @pytest.mark.asyncio -async def test_customize_aggregated_loggers(monkeypatch): +async def test_customize_aggregated_loggers(): """Test that we can customize the aggregated loggers. If a customized logger is provided at the init, it should be added to the default loggers. """ - - with monkeypatch.context() as m, ExitStack() as after: - m.setenv("VLLM_USE_V1", "1") - + with ExitStack() as after: with set_default_torch_num_threads(1): engine = AsyncLLM.from_engine_args( TEXT_ENGINE_ARGS, diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 014e6eca2e02..676423f2ca91 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -868,11 +868,8 @@ def test_structured_output_batched_with_non_structured_outputs_requests( @pytest.mark.parametrize("guided_decoding_backend", ["xgrammar"]) def test_structured_output_with_structural_tag( - monkeypatch: pytest.MonkeyPatch, guided_decoding_backend: str, ): - monkeypatch.setenv("VLLM_USE_V1", "1") - llm = LLM( model="Qwen/Qwen2.5-1.5B-Instruct", guided_decoding_backend=guided_decoding_backend, diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index 6d4a1ecf78c8..354fff22dc2a 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -530,7 +530,6 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode): def test_spec_decode_logprobs( logprobs_mode: LogprobsMode, model_setup: tuple[str, str, str], - monkeypatch: pytest.MonkeyPatch, ): """Spec decode logprobs should match those of the base model. @@ -541,64 +540,62 @@ def test_spec_decode_logprobs( """ from vllm import LLM - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - prompt = "Hello world" - sampling_params = SamplingParams( - temperature=0, logprobs=3, max_tokens=10, ignore_eos=False - ) - method, model_name, spec_model_name = model_setup - max_model_len = 256 - - # Run base LLM. - ref_llm = LLM( - model=model_name, - max_logprobs=5, - max_model_len=max_model_len, - seed=42, - logprobs_mode=logprobs_mode, - gpu_memory_utilization=0.4, - ) - ref_results = ref_llm.generate([prompt], sampling_params) - # Collect logprobs outputs from reference LLM. - ref_logprobs = [] - for output in ref_results[0].outputs: - for logprobs in output.logprobs: - for token_id in logprobs: - ref_logprobs.append(logprobs[token_id]) - del ref_llm - torch.cuda.empty_cache() - cleanup_dist_env_and_memory() - - # Run spec decode LLM. - spec_llm = LLM( - model_name, - speculative_config={ - "method": method, - "model": spec_model_name, - "num_speculative_tokens": 3, - "max_model_len": max_model_len, - }, - max_logprobs=5, - max_model_len=max_model_len, - seed=42, - logprobs_mode=logprobs_mode, - gpu_memory_utilization=0.4, - ) - spec_results = spec_llm.generate([prompt], sampling_params) - # Collect logprobs outputs from spec decode LLM. - spec_logprobs = [] - for output in spec_results[0].outputs: - for logprobs in output.logprobs: - for token_id in logprobs: - spec_logprobs.append(logprobs[token_id]) - del spec_llm - torch.cuda.empty_cache() - cleanup_dist_env_and_memory() - - # Per-token logprobs are expected to be the same. - assert len(ref_logprobs) == len(spec_logprobs) - for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs): - assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3) - assert ref_logprob.rank == spec_logprob.rank - assert ref_logprob.decoded_token == spec_logprob.decoded_token + prompt = "Hello world" + sampling_params = SamplingParams( + temperature=0, logprobs=3, max_tokens=10, ignore_eos=False + ) + method, model_name, spec_model_name = model_setup + max_model_len = 256 + + # Run base LLM. + ref_llm = LLM( + model=model_name, + max_logprobs=5, + max_model_len=max_model_len, + seed=42, + logprobs_mode=logprobs_mode, + gpu_memory_utilization=0.4, + ) + ref_results = ref_llm.generate([prompt], sampling_params) + # Collect logprobs outputs from reference LLM. + ref_logprobs = [] + for output in ref_results[0].outputs: + for logprobs in output.logprobs: + for token_id in logprobs: + ref_logprobs.append(logprobs[token_id]) + del ref_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + # Run spec decode LLM. + spec_llm = LLM( + model_name, + speculative_config={ + "method": method, + "model": spec_model_name, + "num_speculative_tokens": 3, + "max_model_len": max_model_len, + }, + max_logprobs=5, + max_model_len=max_model_len, + seed=42, + logprobs_mode=logprobs_mode, + gpu_memory_utilization=0.4, + ) + spec_results = spec_llm.generate([prompt], sampling_params) + # Collect logprobs outputs from spec decode LLM. + spec_logprobs = [] + for output in spec_results[0].outputs: + for logprobs in output.logprobs: + for token_id in logprobs: + spec_logprobs.append(logprobs[token_id]) + del spec_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + # Per-token logprobs are expected to be the same. + assert len(ref_logprobs) == len(spec_logprobs) + for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs): + assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3) + assert ref_logprob.rank == spec_logprob.rank + assert ref_logprob.decoded_token == spec_logprob.decoded_token diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 18422404d08f..5532ce80d7f1 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -5,7 +5,6 @@ import torch -from vllm import envs from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig @@ -78,17 +77,12 @@ def __init__( kv_cache_dtype = "auto" block_size = 16 - if envs.VLLM_USE_V1: - underlying_attn_backend = get_attn_backend( - head_size, dtype, kv_cache_dtype, block_size - ) - - attn_backend = create_chunked_local_attention_backend( - underlying_attn_backend, attention_chunk_size, block_size - ) - else: - # in v0 the local attention is handled inside the backends - attn_backend = None + underlying_attn_backend = get_attn_backend( + head_size, dtype, kv_cache_dtype, block_size + ) + attn_backend = create_chunked_local_attention_backend( + underlying_attn_backend, attention_chunk_size, block_size + ) super().__init__( num_heads=num_heads, diff --git a/vllm/attention/layers/cross_attention.py b/vllm/attention/layers/cross_attention.py index 4b89c28f0ca6..5b44c7e3e7ec 100644 --- a/vllm/attention/layers/cross_attention.py +++ b/vllm/attention/layers/cross_attention.py @@ -6,7 +6,6 @@ import numpy as np import torch -from vllm import envs from vllm.attention.backends.abstract import ( AttentionBackend, AttentionMetadata, @@ -150,15 +149,10 @@ def __init__( kv_cache_dtype = "auto" block_size = 16 - if envs.VLLM_USE_V1: - underlying_attn_backend = get_attn_backend( - head_size, dtype, kv_cache_dtype, block_size - ) - - attn_backend = create_cross_attention_backend(underlying_attn_backend) - else: - # in v0 cross attention is handled inside the backends - attn_backend = None + underlying_attn_backend = get_attn_backend( + head_size, dtype, kv_cache_dtype, block_size + ) + attn_backend = create_cross_attention_backend(underlying_attn_backend) if attn_type is not None: assert attn_type == AttentionType.ENCODER_DECODER, ( diff --git a/vllm/attention/layers/encoder_only_attention.py b/vllm/attention/layers/encoder_only_attention.py index 8d2a046757fe..4929bbf5efc7 100644 --- a/vllm/attention/layers/encoder_only_attention.py +++ b/vllm/attention/layers/encoder_only_attention.py @@ -5,7 +5,6 @@ import torch -from vllm import envs from vllm.attention.backends.abstract import ( AttentionBackend, AttentionMetadata, @@ -74,17 +73,11 @@ def __init__( kv_cache_dtype = "auto" block_size = 16 - if envs.VLLM_USE_V1: - underlying_attn_backend = get_attn_backend( - head_size, dtype, kv_cache_dtype, block_size - ) + underlying_attn_backend = get_attn_backend( + head_size, dtype, kv_cache_dtype, block_size + ) - attn_backend = create_encoder_only_attention_backend( - underlying_attn_backend - ) - else: - # in v0 encoder only attention is handled inside the backends - attn_backend = None + attn_backend = create_encoder_only_attention_backend(underlying_attn_backend) if attn_type is not None: assert attn_type == AttentionType.ENCODER_ONLY, ( diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 9890d8d80cba..9c26a8d40eda 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -134,16 +134,11 @@ def get_attn_backend( use_sparse: bool = False, ) -> type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" - # Accessing envs.* behind an @lru_cache decorator can cause the wrong - # value to be returned from the cache if the value changes between calls. - # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the - # private function. return _cached_get_attn_backend( head_size=head_size, dtype=dtype, kv_cache_dtype=kv_cache_dtype, block_size=block_size, - use_v1=envs.VLLM_USE_V1, use_mla=use_mla, has_sink=has_sink, use_sparse=use_sparse, @@ -156,7 +151,6 @@ def _cached_get_attn_backend( dtype: torch.dtype, kv_cache_dtype: str | None, block_size: int, - use_v1: bool = False, use_mla: bool = False, has_sink: bool = False, use_sparse: bool = False, @@ -199,7 +193,7 @@ def _cached_get_attn_backend( dtype, kv_cache_dtype, block_size, - use_v1, + True, use_mla, has_sink, use_sparse, diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 8d14200c5240..494a4d3c33aa 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -5,7 +5,6 @@ from collections.abc import Callable from typing import TYPE_CHECKING, Optional, cast -import vllm.envs as envs from vllm.distributed.kv_transfer.kv_connector.base import ( KVConnectorBase, KVConnectorBaseType, @@ -47,12 +46,6 @@ def create_connector( role: KVConnectorRole, kv_cache_config: Optional["KVCacheConfig"] = None, ) -> KVConnectorBase: - if not envs.VLLM_USE_V1: - raise ValueError( - "Attempting to initialize a V1 Connector, " - f"but found {envs.VLLM_USE_V1=}" - ) - kv_transfer_config = config.kv_transfer_config if kv_transfer_config is None: raise ValueError("kv_transfer_config must be set to create a connector") diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py index 7501f0b373d4..54b46d98870a 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_state.py +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import TYPE_CHECKING, Optional -from vllm import envs from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.v1 import ( @@ -65,14 +64,11 @@ def ensure_kv_transfer_initialized( vllm_config.kv_transfer_config.is_kv_transfer_instance and _KV_CONNECTOR_AGENT is None ): - if envs.VLLM_USE_V1: - _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector( - config=vllm_config, - role=KVConnectorRole.WORKER, - kv_cache_config=kv_cache_config, - ) - else: - raise ValueError("V0 is no longer supported") + _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector( + config=vllm_config, + role=KVConnectorRole.WORKER, + kv_cache_config=kv_cache_config, + ) def ensure_kv_transfer_shutdown() -> None: diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index dc6f3df5a68e..2678658dd126 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -88,9 +88,6 @@ def run_headless(args: argparse.Namespace): usage_context=usage_context, headless=True ) - if not envs.VLLM_USE_V1: - raise ValueError("Headless mode is only supported for V1") - if engine_args.data_parallel_hybrid_lb: raise ValueError("data_parallel_hybrid_lb is not applicable in headless mode") @@ -156,15 +153,10 @@ def run_multi_api_server(args: argparse.Namespace): usage_context = UsageContext.OPENAI_API_SERVER vllm_config = engine_args.create_engine_config(usage_context=usage_context) - if num_api_servers > 1: - if not envs.VLLM_USE_V1: - raise ValueError("api_server_count > 1 is only supported for V1") - - if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: - raise ValueError( - "VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used " - "with api_server_count > 1" - ) + if num_api_servers > 1 and envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + raise ValueError( + "VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used with api_server_count > 1" + ) executor_class = Executor.get_class(vllm_config) log_stats = not engine_args.disable_log_stats diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index e184f22f3630..e77a6ad86277 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -220,14 +220,8 @@ async def build_async_engine_client_from_engine_args( # Create the EngineConfig (determines if we can use V1). vllm_config = engine_args.create_engine_config(usage_context=usage_context) - # V1 AsyncLLM. - assert envs.VLLM_USE_V1 - if disable_frontend_multiprocessing: - logger.warning( - "V1 is enabled, but got --disable-frontend-multiprocessing. " - "To disable frontend multiprocessing, set VLLM_USE_V1=0." - ) + logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.") from vllm.v1.engine.async_llm import AsyncLLM diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index d0061f9d5b40..33256de6dd47 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -79,7 +79,6 @@ model_validator, ) -from vllm import envs from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id from vllm.entrypoints.score_utils import ScoreContentPartParam, ScoreMultiModalParam from vllm.logger import init_logger @@ -475,16 +474,12 @@ def validate_prompt(cls, data): @model_validator(mode="before") def check_cache_salt_support(cls, data): - if data.get("cache_salt") is not None: - if not envs.VLLM_USE_V1: - raise ValueError( - "Parameter 'cache_salt' is not supported with " - "this instance of vLLM, which uses engine V0." - ) - if not isinstance(data["cache_salt"], str) or not data["cache_salt"]: - raise ValueError( - "Parameter 'cache_salt' must be a non-empty string if provided." - ) + if data.get("cache_salt") is not None and ( + not isinstance(data["cache_salt"], str) or not data["cache_salt"] + ): + raise ValueError( + "Parameter 'cache_salt' must be a non-empty string if provided." + ) return data @model_validator(mode="before") @@ -946,10 +941,6 @@ def check_logprobs(cls, data): if prompt_logprobs < 0 and prompt_logprobs != -1: raise ValueError("`prompt_logprobs` must be a positive value or -1.") - if prompt_logprobs == -1 and not envs.VLLM_USE_V1: - raise ValueError( - "`prompt_logprobs=-1` is only supported with vLLM engine V1." - ) if (top_logprobs := data.get("top_logprobs")) is not None: if top_logprobs < 0 and top_logprobs != -1: raise ValueError("`top_logprobs` must be a positive value or -1.") @@ -1083,16 +1074,12 @@ def check_generation_prompt(cls, data): @model_validator(mode="before") @classmethod def check_cache_salt_support(cls, data): - if data.get("cache_salt") is not None: - if not envs.VLLM_USE_V1: - raise ValueError( - "Parameter 'cache_salt' is not supported with " - "this instance of vLLM, which uses engine V0." - ) - if not isinstance(data["cache_salt"], str) or not data["cache_salt"]: - raise ValueError( - "Parameter 'cache_salt' must be a non-empty string if provided." - ) + if data.get("cache_salt") is not None and ( + not isinstance(data["cache_salt"], str) or not data["cache_salt"] + ): + raise ValueError( + "Parameter 'cache_salt' must be a non-empty string if provided." + ) return data @@ -1449,10 +1436,6 @@ def check_logprobs(cls, data): if prompt_logprobs < 0 and prompt_logprobs != -1: raise ValueError("`prompt_logprobs` must be a positive value or -1.") - if prompt_logprobs == -1 and not envs.VLLM_USE_V1: - raise ValueError( - "`prompt_logprobs=-1` is only supported with vLLM engine V1." - ) if (logprobs := data.get("logprobs")) is not None and logprobs < 0: raise ValueError("`logprobs` must be a positive value.") @@ -1487,16 +1470,12 @@ def validate_prompt_and_prompt_embeds(cls, data): @model_validator(mode="before") @classmethod def check_cache_salt_support(cls, data): - if data.get("cache_salt") is not None: - if not envs.VLLM_USE_V1: - raise ValueError( - "Parameter 'cache_salt' is not supported with " - "this instance of vLLM, which uses engine V0." - ) - if not isinstance(data["cache_salt"], str) or not data["cache_salt"]: - raise ValueError( - "Parameter 'cache_salt' must be a non-empty string if provided." - ) + if data.get("cache_salt") is not None and ( + not isinstance(data["cache_salt"], str) or not data["cache_salt"] + ): + raise ValueError( + "Parameter 'cache_salt' must be a non-empty string if provided." + ) return data diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 06b4f9271b41..e4e530f0cea8 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -726,8 +726,6 @@ def tensorize_vllm_model( ) as stream: stream.write(encryption_params.key) - assert envs.VLLM_USE_V1 - from vllm.v1.engine.llm_engine import LLMEngine engine = LLMEngine.from_vllm_config(engine_config) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 5dda2ec97875..936e59117232 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -285,10 +285,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: Args: vllm_config: vLLM Config """ - - if not envs.VLLM_USE_V1: - return - model_config = vllm_config.model_config cache_config = vllm_config.cache_config @@ -329,10 +325,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: Args: vllm_config: vLLM Config """ - - if not envs.VLLM_USE_V1: - return - # Save the user input before it gets modified by MambaModelConfig mamba_block_size = vllm_config.cache_config.mamba_block_size # Enable FULL_AND_PIECEWISE by default diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 748605b4ed5a..630de816dc22 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -9,7 +9,6 @@ from transformers import BatchFeature, Gemma3Config, Gemma3Processor from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs -import vllm.envs as envs from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.logger import init_logger @@ -137,11 +136,10 @@ def get_num_crops( if not do_pan_and_scan: return 0 - if envs.VLLM_USE_V1: - logger.warning_once( - "`do_pan_and_scan=True` has suboptimal results on V1 " - "because of the simplified attention pattern being used." - ) + logger.warning_once( + "`do_pan_and_scan=True` has suboptimal results on V1 " + "because of the simplified attention pattern being used." + ) # Based on Gemma3ImageProcessor.pan_and_scan if image_width >= image_height: diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 069078850217..e5ebd8138b0a 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -12,7 +12,6 @@ from transformers import PretrainedConfig from typing_extensions import deprecated -import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed import ( get_tensor_model_parallel_rank, @@ -576,11 +575,8 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: pin_memory = is_pin_memory_available() uva_available = is_uva_available() - if envs.VLLM_USE_V1: - assert uva_available, "V1 CPU offloading requires uva (pin memory) support" - uva_offloading = True - else: - uva_offloading = False + assert uva_available, "V1 CPU offloading requires uva (pin memory) support" + uva_offloading = True # offload parameters to CPU # use pin_memory if possible, which helps cudagraph capture speed diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index b864c52dfbc8..cb70041e9744 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -9,7 +9,6 @@ import numpy.typing as npt from PIL import Image -import vllm.envs as envs from vllm.config.multimodal import ( AudioDummyOptions, BaseDummyOptions, @@ -306,18 +305,6 @@ def get_encoder_dummy_data( if processor.pad_dummy_encoder_prompt: num_tokens_to_pad = max(total_len, seq_len) - total_len encoder_prompt_token_ids.extend([0] * num_tokens_to_pad) - # NOTE: Whisper allows total_len > seq_len. - elif total_len > seq_len and not envs.VLLM_USE_V1: - # `max_num_batched_tokens` is defined by `SchedulerConfig` - logger.warning_once( - "The encoder sequence length used for profiling (max_num_batched_tokens / max_num_seqs = %d) " # noqa: E501 - "is too short to hold the multi-modal embeddings in the worst case (%d tokens in total, out of which %s are reserved for multi-modal embeddings). " # noqa: E501 - "This may cause certain multi-modal inputs to fail during inference, even when the input text is short. " # noqa: E501 - "To avoid this, you should increase `max_model_len`, reduce `max_num_seqs`, and/or reduce `mm_counts`.", # noqa: E501 - seq_len, - total_len, - str(self._get_mm_num_tokens(mm_inputs)), - ) return DummyEncoderData(encoder_prompt_token_ids)