From 3765f9fbc0b477b645fa8257f18ec827557a250c Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 7 Oct 2025 20:59:45 +0800 Subject: [PATCH 01/15] draft Signed-off-by: Isotr0py --- .../processing/test_tensor_schema.py | 2 +- .../model_loader/base_loader.py | 2 +- .../model_loader/bitsandbytes_loader.py | 3 +- .../model_loader/gguf_loader.py | 2 +- .../model_loader/tensorizer_loader.py | 2 +- vllm/model_executor/model_loader/tpu.py | 2 +- vllm/model_executor/model_loader/utils.py | 10 --- vllm/model_executor/models/deepseek_vl2.py | 2 +- vllm/model_executor/models/minicpmv.py | 2 +- vllm/model_executor/models/whisper.py | 2 +- vllm/utils/__init__.py | 63 --------------- vllm/utils/torch_utils.py | 78 +++++++++++++++++++ 12 files changed, 88 insertions(+), 82 deletions(-) create mode 100644 vllm/utils/torch_utils.py diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index 2c4d109c3687..bb74c6cabe43 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -25,7 +25,6 @@ init_distributed_environment, initialize_model_parallel, ) -from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.interfaces import ( SupportsMultiModal, supports_multimodal, @@ -35,6 +34,7 @@ from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.utils import is_list_of +from vllm.utils.torch_utils import set_default_torch_dtype from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS from ...utils import dummy_hf_overrides diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index 6106a1ab8a85..94dfa478245d 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -11,8 +11,8 @@ from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, - set_default_torch_dtype, ) +from vllm.utils.torch_utils import set_default_torch_dtype logger = init_logger(__name__) diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index 8c1ff0300b24..f63d087bacf4 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -32,7 +32,7 @@ RowParallelLinear, ) from vllm.model_executor.model_loader.base_loader import BaseModelLoader -from vllm.model_executor.model_loader.utils import ParamMapping, set_default_torch_dtype +from vllm.model_executor.model_loader.utils import ParamMapping from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, @@ -48,6 +48,7 @@ set_weight_attrs, ) from vllm.platforms import current_platform +from vllm.utils.torch_utils import set_default_torch_dtype logger = init_logger(__name__) diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 93dc754a571c..5083119d46f7 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -15,13 +15,13 @@ from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, - set_default_torch_dtype, ) from vllm.model_executor.model_loader.weight_utils import ( 
get_gguf_extra_tensor_names, get_gguf_weight_type_map, gguf_quant_weights_iterator, ) +from vllm.utils.torch_utils import set_default_torch_dtype class GGUFModelLoader(BaseModelLoader): diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index 5585a74f8926..250562dd17d6 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -23,8 +23,8 @@ from vllm.model_executor.model_loader.utils import ( get_model_architecture, initialize_model, - set_default_torch_dtype, ) +from vllm.utils.torch_utils import set_default_torch_dtype logger = init_logger(__name__) diff --git a/vllm/model_executor/model_loader/tpu.py b/vllm/model_executor/model_loader/tpu.py index fc97003de8e3..5ad024b8bd3f 100644 --- a/vllm/model_executor/model_loader/tpu.py +++ b/vllm/model_executor/model_loader/tpu.py @@ -15,8 +15,8 @@ from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, - set_default_torch_dtype, ) +from vllm.utils.torch_utils import set_default_torch_dtype logger = init_logger(__name__) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index ba8d53c0ba14..ac650adca285 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utilities for selecting and loading models.""" -import contextlib import inspect import warnings from contextlib import contextmanager @@ -33,15 +32,6 @@ logger = init_logger(__name__) -@contextlib.contextmanager -def set_default_torch_dtype(dtype: torch.dtype): - """Sets the default torch dtype to the given dtype.""" - old_dtype = torch.get_default_dtype() - torch.set_default_dtype(dtype) - yield - torch.set_default_dtype(old_dtype) - - def initialize_model( vllm_config: VllmConfig, *, diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 8226e88c47a2..0cebdf3f3465 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -18,7 +18,6 @@ from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.transformers import replace_linear_class from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( @@ -51,6 +50,7 @@ from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.utils import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape +from vllm.utils.torch_utils import set_default_torch_dtype from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import ( diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 09f973e98db9..8665a778a284 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -49,7 +49,6 @@ Resampler2, get_2d_sincos_pos_embed, ) -from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.minicpm import MiniCPMForCausalLM from vllm.model_executor.models.module_mapping import 
MultiModelKeys @@ -88,6 +87,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils import flatten_2d_lists from vllm.utils.tensor_schema import TensorSchema, TensorShape +from vllm.utils.torch_utils import set_default_torch_dtype from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import ( diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 397556cbbcc4..1e17b5ed92d2 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -34,7 +34,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( @@ -53,6 +52,7 @@ from vllm.transformers_utils.processor import cached_get_processor from vllm.utils.jsontree import json_map_leaves from vllm.utils.tensor_schema import TensorSchema, TensorShape +from vllm.utils.torch_utils import set_default_torch_dtype from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription from .utils import ( diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index c06bbbbb23ab..3da480217a4d 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -168,15 +168,6 @@ } -@contextlib.contextmanager -def set_default_torch_num_threads(num_threads: int): - """Sets the default number of threads for PyTorch to the given value.""" - old_num_threads = torch.get_num_threads() - torch.set_num_threads(num_threads) - yield - torch.set_num_threads(old_num_threads) - - P = ParamSpec("P") T = TypeVar("T") U = TypeVar("U") @@ -1237,60 +1228,6 @@ def async_tensor_h2d( return t.to(device=target_device, non_blocking=True) -def get_dtype_size(dtype: torch.dtype) -> int: - """Get the size of the data type in bytes.""" - return torch.tensor([], dtype=dtype).element_size() - - -# bool = 0, int = 1, float = 2, complex = 3 -def _get_precision_level(dtype: torch.dtype) -> int: - # NOTE: Complex dtypes return `is_floating_point=False` - return (dtype != torch.bool) + dtype.is_floating_point + dtype.is_complex * 2 - - -def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype): - """ - Test whether it is lossless to cast a tensor from - `src_dtype` to `tgt_dtype`. - """ - if src_dtype == tgt_dtype: - return True - - src_level = _get_precision_level(src_dtype) - tgt_level = _get_precision_level(tgt_dtype) - - if src_level < tgt_level: - return True - if src_level > tgt_level: - return False - - # Compare integral types - if not src_dtype.is_floating_point and not src_dtype.is_complex: - src_info = torch.iinfo(src_dtype) - tgt_info = torch.iinfo(tgt_dtype) - return src_info.min >= tgt_info.min and src_info.max <= tgt_info.max - - # Compare floating-point types - src_info = torch.finfo(src_dtype) - tgt_info = torch.finfo(tgt_dtype) - return ( - src_info.min >= tgt_info.min - and src_info.max <= tgt_info.max - and src_info.resolution >= tgt_info.resolution - ) - - -def common_broadcastable_dtype(dtypes: Collection[torch.dtype]): - """ - Get the common `dtype` where all of the other `dtypes` can be - cast to it without losing any information. 
- """ - return max( - dtypes, - key=lambda dtype: sum(is_lossless_cast(dt, dtype) for dt in dtypes), - ) - - def as_list(maybe_list: Iterable[T]) -> list[T]: """Convert iterable to list, unless it's already a list.""" return maybe_list if isinstance(maybe_list, list) else list(maybe_list) diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py new file mode 100644 index 000000000000..3e78deda8dac --- /dev/null +++ b/vllm/utils/torch_utils.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +from collections.abc import Collection + +import torch + + +@contextlib.contextmanager +def set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(old_dtype) + + +@contextlib.contextmanager +def set_default_torch_num_threads(num_threads: int): + """Sets the default number of threads for PyTorch to the given value.""" + old_num_threads = torch.get_num_threads() + torch.set_num_threads(num_threads) + yield + torch.set_num_threads(old_num_threads) + + +def get_dtype_size(dtype: torch.dtype) -> int: + """Get the size of the data type in bytes.""" + return torch.tensor([], dtype=dtype).element_size() + + +# bool = 0, int = 1, float = 2, complex = 3 +def _get_precision_level(dtype: torch.dtype) -> int: + # NOTE: Complex dtypes return `is_floating_point=False` + return (dtype != torch.bool) + dtype.is_floating_point + dtype.is_complex * 2 + + +def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype): + """ + Test whether it is lossless to cast a tensor from + `src_dtype` to `tgt_dtype`. + """ + if src_dtype == tgt_dtype: + return True + + src_level = _get_precision_level(src_dtype) + tgt_level = _get_precision_level(tgt_dtype) + + if src_level < tgt_level: + return True + if src_level > tgt_level: + return False + + # Compare integral types + if not src_dtype.is_floating_point and not src_dtype.is_complex: + src_info = torch.iinfo(src_dtype) + tgt_info = torch.iinfo(tgt_dtype) + return src_info.min >= tgt_info.min and src_info.max <= tgt_info.max + + # Compare floating-point types + src_info = torch.finfo(src_dtype) + tgt_info = torch.finfo(tgt_dtype) + return ( + src_info.min >= tgt_info.min + and src_info.max <= tgt_info.max + and src_info.resolution >= tgt_info.resolution + ) + + +def common_broadcastable_dtype(dtypes: Collection[torch.dtype]): + """ + Get the common `dtype` where all of the other `dtypes` can be + cast to it without losing any information. 
+ """ + return max( + dtypes, + key=lambda dtype: sum(is_lossless_cast(dt, dtype) for dt in dtypes), + ) From 7dffb1eb6415179ab32580fa15ffcadc41ab8f75 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 15 Oct 2025 22:58:00 +0800 Subject: [PATCH 02/15] update mop context Signed-off-by: Isotr0py --- tests/conftest.py | 3 ++- tests/v1/engine/test_async_llm.py | 2 +- tests/v1/engine/test_engine_core.py | 2 +- tests/v1/engine/test_engine_core_client.py | 2 +- vllm/model_executor/models/internvl.py | 2 +- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 9126b3d668b9..4ae423298a07 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,7 +57,8 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams from vllm.transformers_utils.utils import maybe_model_redirect -from vllm.utils import is_list_of, set_default_torch_num_threads +from vllm.utils import is_list_of +from vllm.utils.torch_utils import set_default_torch_num_threads logger = init_logger(__name__) diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index b9fa55314278..c9605ea1b07c 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -15,7 +15,7 @@ from vllm.outputs import RequestOutput from vllm.platforms import current_platform from vllm.sampling_params import RequestOutputKind -from vllm.utils import set_default_torch_num_threads +from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.metrics.loggers import ( AggregatedLoggingStatLogger, diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 997b2b74bb6b..7e39cd781bae 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -12,7 +12,7 @@ from vllm import SamplingParams from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform -from vllm.utils import set_default_torch_num_threads +from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core import EngineCore from vllm.v1.executor.abstract import Executor, UniProcExecutor diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 32eeaebbca91..770560a5e549 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -21,7 +21,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.usage.usage_lib import UsageContext -from vllm.utils import set_default_torch_num_threads +from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core import EngineCore from vllm.v1.engine.core_client import AsyncMPClient, EngineCoreClient, SyncMPClient diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 05b822d6fdbf..e2d2647f0177 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -51,8 +51,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import set_default_torch_num_threads from vllm.utils.tensor_schema import TensorSchema, TensorShape +from vllm.utils.torch_utils import set_default_torch_num_threads from 
.interfaces import ( MultiModalEmbeddings, From 20876bef1cf32977708501253156307cdf6896de Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 15 Oct 2025 23:25:48 +0800 Subject: [PATCH 03/15] update Signed-off-by: Isotr0py --- tests/utils_/test_utils.py | 3 +-- vllm/config/model.py | 3 ++- vllm/v1/kv_cache_interface.py | 3 ++- vllm/v1/worker/gpu_model_runner.py | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index af5fc758f2c2..2335e4a2b134 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -28,12 +28,10 @@ MemorySnapshot, PlaceholderModule, bind_kv_cache, - common_broadcastable_dtype, current_stream, deprecate_kwargs, get_open_port, get_tcp_uri, - is_lossless_cast, join_host_port, make_zmq_path, make_zmq_socket, @@ -46,6 +44,7 @@ swap_dict_values, unique_filepath, ) +from vllm.utils.torch_utils import common_broadcastable_dtype, is_lossless_cast from ..utils import create_new_process_for_each_test, error_on_warning diff --git a/vllm/config/model.py b/vllm/config/model.py index 2be939eb654d..04a35f92a1c1 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -38,7 +38,8 @@ ) from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri from vllm.transformers_utils.utils import maybe_model_redirect -from vllm.utils import LayerBlockType, LazyLoader, common_broadcastable_dtype +from vllm.utils import LayerBlockType, LazyLoader +from vllm.utils.torch_utils import common_broadcastable_dtype if TYPE_CHECKING: from transformers import PretrainedConfig diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index a9ef1b92c243..392519f8fa9a 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -10,7 +10,8 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils import cdiv, get_dtype_size +from vllm.utils import cdiv +from vllm.utils.torch_utils import get_dtype_size logger = init_logger(__name__) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9e394dbb592e..dd2a86a257c5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -78,13 +78,13 @@ GiB_bytes, cdiv, check_use_alibi, - get_dtype_size, is_pin_memory_available, length_from_prompt_token_ids_or_embeds, round_up, supports_dynamo, ) from vllm.utils.jsontree import json_map_leaves +from vllm.utils.torch_utils import get_dtype_size from vllm.v1.attention.backends.flash_attn import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( From a4ac619e5108acc4ec2380476e4b82adb12f4bdb Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Thu, 16 Oct 2025 01:02:05 +0800 Subject: [PATCH 04/15] move STR_DTYPE_TO_TORCH_DTYPE and kv_caches utils Signed-off-by: Isotr0py --- .../kernels/bench_per_token_quant_fp8.py | 3 +- benchmarks/kernels/benchmark_activation.py | 3 +- benchmarks/kernels/benchmark_layernorm.py | 3 +- .../kernels/benchmark_paged_attention.py | 4 +- benchmarks/kernels/benchmark_quant.py | 3 +- .../kernels/benchmark_reshape_and_cache.py | 4 +- .../benchmark_reshape_and_cache_flash.py | 4 +- tests/kernels/attention/conftest.py | 5 +- .../kernels/attention/test_prefix_prefill.py | 2 +- .../multimodal/pooling/test_intern_vit.py | 2 +- tests/models/multimodal/pooling/test_radio.py | 2 +- tests/v1/attention/test_attention_backends.py | 3 +- tests/v1/attention/test_mla_backends.py | 3 
+- .../layers/mamba/mamba_utils.py | 5 +- vllm/model_executor/models/config.py | 3 +- vllm/utils/__init__.py | 148 ------------------ vllm/utils/torch_utils.py | 148 ++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 6 +- vllm/v1/worker/tpu_worker.py | 3 +- 19 files changed, 185 insertions(+), 169 deletions(-) diff --git a/benchmarks/kernels/bench_per_token_quant_fp8.py b/benchmarks/kernels/bench_per_token_quant_fp8.py index 9a52ea7f47e3..d33b84fc3601 100644 --- a/benchmarks/kernels/bench_per_token_quant_fp8.py +++ b/benchmarks/kernels/bench_per_token_quant_fp8.py @@ -10,7 +10,8 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.triton_utils import triton -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE def with_triton_mode(fn): diff --git a/benchmarks/kernels/benchmark_activation.py b/benchmarks/kernels/benchmark_activation.py index 93edbcc9391f..7662655b5efa 100644 --- a/benchmarks/kernels/benchmark_activation.py +++ b/benchmarks/kernels/benchmark_activation.py @@ -10,7 +10,8 @@ from vllm.model_executor.custom_op import CustomOp from vllm.platforms import current_platform from vllm.triton_utils import triton -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE batch_size_range = [1, 16, 32, 64, 128] seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py index 69978ec6b23e..bcfa64c3f425 100644 --- a/benchmarks/kernels/benchmark_layernorm.py +++ b/benchmarks/kernels/benchmark_layernorm.py @@ -7,7 +7,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.platforms import current_platform -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE @torch.inference_mode() diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 8f9907952d24..1b1e71adeec4 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -9,9 +9,9 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import ( +from vllm.utils import FlexibleArgumentParser +from vllm.utils.torch_utils import ( STR_DTYPE_TO_TORCH_DTYPE, - FlexibleArgumentParser, create_kv_caches_with_random, ) diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py index 6ab26f5f1adf..61427a77b4e3 100644 --- a/benchmarks/kernels/benchmark_quant.py +++ b/benchmarks/kernels/benchmark_quant.py @@ -7,7 +7,8 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE @torch.inference_mode() diff --git a/benchmarks/kernels/benchmark_reshape_and_cache.py b/benchmarks/kernels/benchmark_reshape_and_cache.py index d4b564d2ec6c..e0ff09d4b397 100644 --- a/benchmarks/kernels/benchmark_reshape_and_cache.py +++ 
b/benchmarks/kernels/benchmark_reshape_and_cache.py @@ -9,9 +9,9 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import ( +from vllm.utils import FlexibleArgumentParser +from vllm.utils.torch_utils import ( STR_DTYPE_TO_TORCH_DTYPE, - FlexibleArgumentParser, create_kv_caches_with_random, ) diff --git a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py index 93df14f0d95c..29f1b2ccdcf6 100644 --- a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py +++ b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py @@ -12,9 +12,9 @@ ) from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import ( +from vllm.utils import FlexibleArgumentParser +from vllm.utils.torch_utils import ( STR_DTYPE_TO_TORCH_DTYPE, - FlexibleArgumentParser, create_kv_caches_with_random_flash, ) diff --git a/tests/kernels/attention/conftest.py b/tests/kernels/attention/conftest.py index b080a71bd54e..e520267320c0 100644 --- a/tests/kernels/attention/conftest.py +++ b/tests/kernels/attention/conftest.py @@ -3,7 +3,10 @@ import pytest -from vllm.utils import create_kv_caches_with_random, create_kv_caches_with_random_flash +from vllm.utils.torch_utils import ( + create_kv_caches_with_random, + create_kv_caches_with_random_flash, +) @pytest.fixture() diff --git a/tests/kernels/attention/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py index 5ff2624cd7a4..65972d02f2f6 100644 --- a/tests/kernels/attention/test_prefix_prefill.py +++ b/tests/kernels/attention/test_prefix_prefill.py @@ -15,7 +15,7 @@ from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.platforms import current_platform -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE NUM_HEADS = [64] NUM_QUERIES_PER_KV = [1, 64] diff --git a/tests/models/multimodal/pooling/test_intern_vit.py b/tests/models/multimodal/pooling/test_intern_vit.py index 74e30c4307fa..5a97848216b8 100644 --- a/tests/models/multimodal/pooling/test_intern_vit.py +++ b/tests/models/multimodal/pooling/test_intern_vit.py @@ -7,7 +7,7 @@ from transformers import AutoConfig, AutoModel, CLIPImageProcessor from vllm.distributed import cleanup_dist_env_and_memory -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from ....conftest import ImageTestAssets diff --git a/tests/models/multimodal/pooling/test_radio.py b/tests/models/multimodal/pooling/test_radio.py index 414e99a71e7b..8929563d8b05 100644 --- a/tests/models/multimodal/pooling/test_radio.py +++ b/tests/models/multimodal/pooling/test_radio.py @@ -9,7 +9,7 @@ from vllm.distributed import cleanup_dist_env_and_memory from vllm.model_executor.models.radio import RadioModel from vllm.transformers_utils.configs.radio import RadioConfig -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from ....conftest import ImageTestAssets diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 07706d4b956c..4d88367ffd02 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -18,7 +18,8 @@ from vllm.attention.backends.registry import _Backend from 
vllm.config import ModelConfig from vllm.platforms import current_platform -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer +from vllm.utils import cdiv, is_torch_equal_or_newer +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, set_kv_cache_layout, diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index f41f63ed2af4..81fd6433b0c8 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -22,7 +22,8 @@ from vllm.attention.backends.registry import _Backend from vllm.attention.ops.flashmla import is_flashmla_dense_supported from vllm.config.vllm import set_current_vllm_config -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.utils import cdiv +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import FullAttentionSpec diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 41ab7f3fecdb..91a45623582d 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -6,7 +6,10 @@ from vllm.config.cache import MambaDType from vllm.config.model import ModelDType from vllm.distributed import divide -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype +from vllm.utils.torch_utils import ( + STR_DTYPE_TO_TORCH_DTYPE, + get_kv_cache_torch_dtype, +) class MambaStateDtypeCalculator: diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 662f2c9209f4..da5d80f9828e 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -6,7 +6,8 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.models import ModelRegistry -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, round_up +from vllm.utils import cdiv, round_up +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec if TYPE_CHECKING: diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 4137f9b5c852..d90212864441 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -40,7 +40,6 @@ from collections import UserDict, defaultdict from collections.abc import ( Callable, - Collection, Generator, Hashable, Iterable, @@ -135,18 +134,6 @@ CYAN = "\033[1;36m" RESET = "\033[0;0m" -STR_DTYPE_TO_TORCH_DTYPE = { - "float32": torch.float32, - "half": torch.half, - "bfloat16": torch.bfloat16, - "float": torch.float, - "fp8": torch.uint8, - "fp8_e4m3": torch.uint8, - "fp8_e5m2": torch.uint8, - "int8": torch.int8, - "fp8_inc": torch.float8_e4m3fn, - "fp8_ds_mla": torch.uint8, -} TORCH_DTYPE_TO_NUMPY_DTYPE = { torch.float16: np.float16, @@ -445,141 +432,6 @@ def round_down(x: int, y: int) -> int: return (x // y) * y -def _generate_random_fp8( - tensor: torch.Tensor, - low: float, - high: float, -) -> None: - # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type, - # it may occur Inf or NaN if we directly use torch.randint - # to generate random data for fp8 data. - # For example, s.11111.00 in fp8e5m2 format represents Inf. 
- # | E4M3 | E5M2 - # -----|-------------|------------------- - # Inf | N/A | s.11111.00 - # NaN | s.1111.111 | s.11111.{01,10,11} - from vllm import _custom_ops as ops - - tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) - tensor_tmp.uniform_(low, high) - ops.convert_fp8(tensor, tensor_tmp) - del tensor_tmp - - -def get_kv_cache_torch_dtype( - cache_dtype: str | torch.dtype | None, - model_dtype: str | torch.dtype | None = None, -) -> torch.dtype: - if isinstance(cache_dtype, str): - if cache_dtype == "auto": - if isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE: - torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] - elif isinstance(model_dtype, torch.dtype): - torch_dtype = model_dtype - else: - raise ValueError(f"Invalid model dtype: {model_dtype}") - elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE: - torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] - else: - raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") - elif isinstance(cache_dtype, torch.dtype): - torch_dtype = cache_dtype - else: - raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") - return torch_dtype - - -def create_kv_caches_with_random_flash( - num_blocks: int, - block_size: int, - num_layers: int, - num_heads: int, - head_size: int, - cache_dtype: str | torch.dtype | None, - model_dtype: str | torch.dtype | None = None, - seed: int | None = None, - device: str | None = "cuda", - cache_layout: str | None = "NHD", -) -> tuple[list[torch.Tensor], list[torch.Tensor]]: - from vllm.platforms import current_platform - - current_platform.seed_everything(seed) - - dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) - generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) - assert cache_layout in ("NHD", "HND") - stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4) - - kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] for i in stride_order) - scale = head_size**-0.5 - - key_caches: list[torch.Tensor] = [] - value_caches: list[torch.Tensor] = [] - - for _ in range(num_layers): - key_value_cache = torch.empty( - size=kv_cache_allocation_shape, dtype=dtype, device=device - ).permute(*stride_order) - if cache_dtype in ["auto", "half", "bfloat16", "float"]: - key_value_cache.uniform_(-scale, scale) - elif cache_dtype == "fp8": - _generate_random_fp8(key_value_cache, -scale, scale) - else: - raise ValueError(f"Does not support key cache of type {cache_dtype}") - key_caches.append(key_value_cache[:, 0]) - value_caches.append(key_value_cache[:, 1]) - return key_caches, value_caches - - -def create_kv_caches_with_random( - num_blocks: int, - block_size: int, - num_layers: int, - num_heads: int, - head_size: int, - cache_dtype: str | torch.dtype | None, - model_dtype: str | torch.dtype | None = None, - seed: int | None = None, - device: str | None = "cuda", -) -> tuple[list[torch.Tensor], list[torch.Tensor]]: - if cache_dtype == "fp8" and head_size % 16: - raise ValueError( - f"Does not support key cache of type fp8 with head_size {head_size}" - ) - from vllm.platforms import current_platform - - current_platform.seed_everything(seed) - - dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) - - scale = head_size**-0.5 - x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) - key_caches: list[torch.Tensor] = [] - for _ in range(num_layers): - key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device) - if cache_dtype in ["auto", "half", 
"bfloat16", "float"]: - key_cache.uniform_(-scale, scale) - elif cache_dtype == "fp8": - _generate_random_fp8(key_cache, -scale, scale) - else: - raise ValueError(f"Does not support key cache of type {cache_dtype}") - key_caches.append(key_cache) - - value_cache_shape = (num_blocks, num_heads, head_size, block_size) - value_caches: list[torch.Tensor] = [] - for _ in range(num_layers): - value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) - if cache_dtype in ["auto", "half", "bfloat16", "float"]: - value_cache.uniform_(-scale, scale) - elif cache_dtype == "fp8": - _generate_random_fp8(value_cache, -scale, scale) - else: - raise ValueError(f"Does not support value cache of type {cache_dtype}") - value_caches.append(value_cache) - return key_caches, value_caches - - @cache def is_pin_memory_available() -> bool: from vllm.platforms import current_platform diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 3e78deda8dac..422db050fa73 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -5,6 +5,19 @@ import torch +STR_DTYPE_TO_TORCH_DTYPE = { + "float32": torch.float32, + "half": torch.half, + "bfloat16": torch.bfloat16, + "float": torch.float, + "fp8": torch.uint8, + "fp8_e4m3": torch.uint8, + "fp8_e5m2": torch.uint8, + "int8": torch.int8, + "fp8_inc": torch.float8_e4m3fn, + "fp8_ds_mla": torch.uint8, +} + @contextlib.contextmanager def set_default_torch_dtype(dtype: torch.dtype): @@ -76,3 +89,138 @@ def common_broadcastable_dtype(dtypes: Collection[torch.dtype]): dtypes, key=lambda dtype: sum(is_lossless_cast(dt, dtype) for dt in dtypes), ) + + +def _generate_random_fp8( + tensor: torch.Tensor, + low: float, + high: float, +) -> None: + # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type, + # it may occur Inf or NaN if we directly use torch.randint + # to generate random data for fp8 data. + # For example, s.11111.00 in fp8e5m2 format represents Inf. 
+ # | E4M3 | E5M2 + # -----|-------------|------------------- + # Inf | N/A | s.11111.00 + # NaN | s.1111.111 | s.11111.{01,10,11} + from vllm import _custom_ops as ops + + tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) + tensor_tmp.uniform_(low, high) + ops.convert_fp8(tensor, tensor_tmp) + del tensor_tmp + + +def get_kv_cache_torch_dtype( + cache_dtype: str | torch.dtype | None, + model_dtype: str | torch.dtype | None = None, +) -> torch.dtype: + if isinstance(cache_dtype, str): + if cache_dtype == "auto": + if isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] + elif isinstance(model_dtype, torch.dtype): + torch_dtype = model_dtype + else: + raise ValueError(f"Invalid model dtype: {model_dtype}") + elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + elif isinstance(cache_dtype, torch.dtype): + torch_dtype = cache_dtype + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + return torch_dtype + + +def create_kv_caches_with_random_flash( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: str | torch.dtype | None, + model_dtype: str | torch.dtype | None = None, + seed: int | None = None, + device: str | None = "cuda", + cache_layout: str | None = "NHD", +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + from vllm.platforms import current_platform + + current_platform.seed_everything(seed) + + dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) + assert cache_layout in ("NHD", "HND") + stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4) + + kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] for i in stride_order) + scale = head_size**-0.5 + + key_caches: list[torch.Tensor] = [] + value_caches: list[torch.Tensor] = [] + + for _ in range(num_layers): + key_value_cache = torch.empty( + size=kv_cache_allocation_shape, dtype=dtype, device=device + ).permute(*stride_order) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_value_cache.uniform_(-scale, scale) + elif cache_dtype == "fp8": + _generate_random_fp8(key_value_cache, -scale, scale) + else: + raise ValueError(f"Does not support key cache of type {cache_dtype}") + key_caches.append(key_value_cache[:, 0]) + value_caches.append(key_value_cache[:, 1]) + return key_caches, value_caches + + +def create_kv_caches_with_random( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: str | torch.dtype | None, + model_dtype: str | torch.dtype | None = None, + seed: int | None = None, + device: str | None = "cuda", +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + if cache_dtype == "fp8" and head_size % 16: + raise ValueError( + f"Does not support key cache of type fp8 with head_size {head_size}" + ) + from vllm.platforms import current_platform + + current_platform.seed_everything(seed) + + dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + + scale = head_size**-0.5 + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_caches: list[torch.Tensor] = [] + for _ in range(num_layers): + key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device) + if cache_dtype in ["auto", "half", 
"bfloat16", "float"]: + key_cache.uniform_(-scale, scale) + elif cache_dtype == "fp8": + _generate_random_fp8(key_cache, -scale, scale) + else: + raise ValueError(f"Does not support key cache of type {cache_dtype}") + key_caches.append(key_cache) + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_caches: list[torch.Tensor] = [] + for _ in range(num_layers): + value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + value_cache.uniform_(-scale, scale) + elif cache_dtype == "fp8": + _generate_random_fp8(value_cache, -scale, scale) + else: + raise ValueError(f"Does not support value cache of type {cache_dtype}") + value_caches.append(value_cache) + return key_caches, value_caches diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index dd2a86a257c5..d06244c7f280 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -73,7 +73,6 @@ from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import ( - STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, GiB_bytes, cdiv, @@ -84,7 +83,10 @@ supports_dynamo, ) from vllm.utils.jsontree import json_map_leaves -from vllm.utils.torch_utils import get_dtype_size +from vllm.utils.torch_utils import ( + STR_DTYPE_TO_TORCH_DTYPE, + get_dtype_size, +) from vllm.v1.attention.backends.flash_attn import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 9bce362120ac..1a758386d1c9 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -26,7 +26,8 @@ from vllm.platforms import current_platform from vllm.platforms.tpu import USE_TPU_INFERENCE from vllm.tasks import SupportedTask -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.utils import cdiv +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput From 47601ee7001363440b53d6140297090934169ce8 Mon Sep 17 00:00:00 2001 From: isotr0py <2037008807@qq.com> Date: Thu, 16 Oct 2025 14:50:28 +0800 Subject: [PATCH 05/15] move current_stream Signed-off-by: isotr0py <2037008807@qq.com> --- tests/utils_/test_utils.py | 7 +- .../device_communicators/pynccl.py | 2 +- .../device_communicators/ray_communicator.py | 2 +- .../kv_connector/v1/p2p/p2p_nccl_engine.py | 3 +- vllm/utils/__init__.py | 65 ------------------ vllm/utils/torch_utils.py | 66 +++++++++++++++++++ vllm/v1/worker/ubatching.py | 2 +- 7 files changed, 76 insertions(+), 71 deletions(-) diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index c505e6fd9c60..ea64514dbbb9 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -26,7 +26,6 @@ MemorySnapshot, PlaceholderModule, bind_kv_cache, - current_stream, get_open_port, get_tcp_uri, join_host_port, @@ -39,7 +38,11 @@ swap_dict_values, unique_filepath, ) -from vllm.utils.torch_utils import common_broadcastable_dtype, is_lossless_cast +from vllm.utils.torch_utils import ( + common_broadcastable_dtype, + current_stream, + is_lossless_cast, +) from ..utils import create_new_process_for_each_test diff --git 
a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index f08330879178..6aaa48cc14bf 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -19,7 +19,7 @@ ) from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger -from vllm.utils import current_stream +from vllm.utils.torch_utils import current_stream logger = init_logger(__name__) diff --git a/vllm/distributed/device_communicators/ray_communicator.py b/vllm/distributed/device_communicators/ray_communicator.py index 732a40770f25..3b02b885e786 100644 --- a/vllm/distributed/device_communicators/ray_communicator.py +++ b/vllm/distributed/device_communicators/ray_communicator.py @@ -14,7 +14,7 @@ ) from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger -from vllm.utils import current_stream +from vllm.utils.torch_utils import current_stream logger = init_logger(__name__) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py index 7714359a5091..5b32a9756637 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -25,7 +25,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501 TensorMemoryPool, ) -from vllm.utils import current_stream, get_ip +from vllm.utils import get_ip +from vllm.utils.torch_utils import current_stream logger = logging.getLogger(__name__) diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index d90212864441..14e60b23ca79 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -521,17 +521,6 @@ def make_tensor_with_pad( return tensor -def async_tensor_h2d( - data: list, - dtype: torch.dtype, - target_device: str | torch.device, - pin_memory: bool, -) -> torch.Tensor: - """Asynchronously create a tensor and copy it from host to device.""" - t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") - return t.to(device=target_device, non_blocking=True) - - def as_list(maybe_list: Iterable[T]) -> list[T]: """Convert iterable to list, unless it's already a list.""" return maybe_list if isinstance(maybe_list, list) else list(maybe_list) @@ -676,60 +665,6 @@ def find_nccl_include_paths() -> list[str] | None: return out or None -prev_set_stream = torch.cuda.set_stream - -_current_stream_tls = threading.local() - - -def _patched_set_stream(stream: torch.cuda.Stream) -> None: - _current_stream_tls.value = stream - prev_set_stream(stream) - - -torch.cuda.set_stream = _patched_set_stream - - -class _StreamPlaceholder: - def __init__(self): - self.synchronize = lambda: None - - -def current_stream() -> torch.cuda.Stream: - """ - replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`. - it turns out that `torch.cuda.current_stream()` is quite expensive, - as it will construct a new stream object at each call. - here we patch `torch.cuda.set_stream` to keep track of the current stream - directly, so that we can avoid calling `torch.cuda.current_stream()`. - - the underlying hypothesis is that we do not call `torch._C._cuda_setStream` - from C/C++ code. 
- """ - from vllm.platforms import current_platform - - if not hasattr(_current_stream_tls, "value") or _current_stream_tls.value is None: - # when this function is called before any stream is set, - # we return the default stream. - # On ROCm using the default 0 stream in combination with RCCL - # is hurting performance. Therefore creating a dedicated stream - # per process - if current_platform.is_rocm(): - # torch.cuda.set_stream here is the alias of _pathed_set_stream - torch.cuda.set_stream(torch.cuda.Stream()) - elif current_platform.is_cpu(): - _current_stream_tls.value = _StreamPlaceholder() - else: - current_stream = current_platform.current_stream - if current_stream is not None: - _current_stream_tls.value = current_stream() - else: - raise ValueError( - "Fail to set current stream, current platform " - "may not support current_stream with torch API" - ) - return _current_stream_tls.value - - def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None: """Set up function tracing for the current thread, if enabled via the VLLM_TRACE_FUNCTION environment variable diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 422db050fa73..34133c2ed43a 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +import threading from collections.abc import Collection import torch @@ -224,3 +225,68 @@ def create_kv_caches_with_random( raise ValueError(f"Does not support value cache of type {cache_dtype}") value_caches.append(value_cache) return key_caches, value_caches + + +def async_tensor_h2d( + data: list, + dtype: torch.dtype, + target_device: str | torch.device, + pin_memory: bool, +) -> torch.Tensor: + """Asynchronously create a tensor and copy it from host to device.""" + t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") + return t.to(device=target_device, non_blocking=True) + + +prev_set_stream = torch.cuda.set_stream + +_current_stream_tls = threading.local() + + +def _patched_set_stream(stream: torch.cuda.Stream) -> None: + _current_stream_tls.value = stream + prev_set_stream(stream) + + +torch.cuda.set_stream = _patched_set_stream + + +class _StreamPlaceholder: + def __init__(self): + self.synchronize = lambda: None + + +def current_stream() -> torch.cuda.Stream: + """ + replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`. + it turns out that `torch.cuda.current_stream()` is quite expensive, + as it will construct a new stream object at each call. + here we patch `torch.cuda.set_stream` to keep track of the current stream + directly, so that we can avoid calling `torch.cuda.current_stream()`. + + the underlying hypothesis is that we do not call `torch._C._cuda_setStream` + from C/C++ code. + """ + from vllm.platforms import current_platform + + if not hasattr(_current_stream_tls, "value") or _current_stream_tls.value is None: + # when this function is called before any stream is set, + # we return the default stream. + # On ROCm using the default 0 stream in combination with RCCL + # is hurting performance. 
Therefore creating a dedicated stream + # per process + if current_platform.is_rocm(): + # torch.cuda.set_stream here is the alias of _pathed_set_stream + torch.cuda.set_stream(torch.cuda.Stream()) + elif current_platform.is_cpu(): + _current_stream_tls.value = _StreamPlaceholder() + else: + current_stream = current_platform.current_stream + if current_stream is not None: + _current_stream_tls.value = current_stream() + else: + raise ValueError( + "Fail to set current stream, current platform " + "may not support current_stream with torch API" + ) + return _current_stream_tls.value diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 867ce2b93036..6edcb7848638 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -7,7 +7,7 @@ from vllm import forward_context from vllm.forward_context import ForwardContext -from vllm.utils import current_stream +from vllm.utils.torch_utils import current_stream _THREAD_ID_TO_CONTEXT: dict = {} _CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [None, None] From 7cacad09b6b37d3ac4baef0732dc6de8d2faadca Mon Sep 17 00:00:00 2001 From: isotr0py <2037008807@qq.com> Date: Thu, 16 Oct 2025 16:46:44 +0800 Subject: [PATCH 06/15] move cuda_device_count_stateless Signed-off-by: isotr0py <2037008807@qq.com> --- tests/compile/test_basic_correctness.py | 2 +- tests/distributed/test_utils.py | 2 +- .../moe/test_modular_kernel_combinations.py | 3 +- tests/utils.py | 2 +- tests/utils_/test_utils.py | 2 +- tests/v1/shutdown/test_delete.py | 2 +- tests/v1/shutdown/test_forward_error.py | 2 +- tests/v1/shutdown/test_startup_error.py | 2 +- vllm/config/parallel.py | 3 +- .../device_communicators/all_reduce_utils.py | 3 +- .../device_communicators/custom_all_reduce.py | 2 +- .../device_communicators/quick_all_reduce.py | 2 +- vllm/platforms/cuda.py | 3 +- vllm/platforms/rocm.py | 2 +- vllm/usage/usage_lib.py | 3 +- vllm/utils/__init__.py | 45 +----------------- vllm/utils/torch_utils.py | 46 +++++++++++++++++++ 17 files changed, 67 insertions(+), 59 deletions(-) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 954774a8e398..132a838b8d44 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -5,7 +5,7 @@ import pytest from vllm.config import CompilationMode -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from ..utils import compare_all_settings diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py index 2a6936fcd4c2..c10c2565811b 100644 --- a/tests/distributed/test_utils.py +++ b/tests/distributed/test_utils.py @@ -11,10 +11,10 @@ from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.utils import StatelessProcessGroup from vllm.utils import ( - cuda_device_count_stateless, get_open_port, update_environment_variables, ) +from vllm.utils.torch_utils import cuda_device_count_stateless from ..utils import multi_gpu_test diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index a86185a2dc46..a7beb313011a 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -13,8 +13,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.config import VllmConfig, set_current_vllm_config from vllm.platforms import current_platform -from vllm.utils 
import cuda_device_count_stateless, has_deep_ep, has_deep_gemm, has_pplx +from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +from vllm.utils.torch_utils import cuda_device_count_stateless from .modular_kernel_tools.common import ( Config, diff --git a/tests/utils.py b/tests/utils.py index 5bfdf703390e..ad5139c7f8ba 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -46,9 +46,9 @@ from vllm.utils import ( FlexibleArgumentParser, GB_bytes, - cuda_device_count_stateless, get_open_port, ) +from vllm.utils.torch_utils import cuda_device_count_stateless if current_platform.is_rocm(): from amdsmi import ( diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index ea64514dbbb9..d7f702f421d3 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -413,7 +413,7 @@ def test_bind_kv_cache_non_attention(): def test_bind_kv_cache_pp(): - with patch("vllm.utils.cuda_device_count_stateless", lambda: 2): + with patch("vllm.utils.torch_utils.cuda_device_count_stateless", lambda: 2): # this test runs with 1 GPU, but we simulate 2 GPUs cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2)) with set_current_vllm_config(cfg): diff --git a/tests/v1/shutdown/test_delete.py b/tests/v1/shutdown/test_delete.py index d94357827864..255515948433 100644 --- a/tests/v1/shutdown/test_delete.py +++ b/tests/v1/shutdown/test_delete.py @@ -12,7 +12,7 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.sampling_params import RequestOutputKind -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.v1.engine.async_llm import AsyncLLM MODELS = ["meta-llama/Llama-3.2-1B"] diff --git a/tests/v1/shutdown/test_forward_error.py b/tests/v1/shutdown/test_forward_error.py index 383348e88540..e65d46dfa43a 100644 --- a/tests/v1/shutdown/test_forward_error.py +++ b/tests/v1/shutdown/test_forward_error.py @@ -14,7 +14,7 @@ from vllm import LLM, AsyncEngineArgs, SamplingParams from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.exceptions import EngineDeadError diff --git a/tests/v1/shutdown/test_startup_error.py b/tests/v1/shutdown/test_startup_error.py index 019c0c4d7cf0..3877fceae00c 100644 --- a/tests/v1/shutdown/test_startup_error.py +++ b/tests/v1/shutdown/test_startup_error.py @@ -13,7 +13,7 @@ from vllm.distributed import get_tensor_model_parallel_rank from vllm.engine.arg_utils import AsyncEngineArgs from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.v1.engine.async_llm import AsyncLLM MODELS = ["meta-llama/Llama-3.2-1B"] diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 944a1e8666f4..5c7fbbebe3cc 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -15,7 +15,8 @@ from vllm.config.utils import config from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import cuda_device_count_stateless, get_open_ports_list +from vllm.utils import get_open_ports_list +from vllm.utils.torch_utils import cuda_device_count_stateless if 
TYPE_CHECKING: from ray.runtime_env import RuntimeEnv diff --git a/vllm/distributed/device_communicators/all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py index 9e99fd01a919..413b8096d2e3 100644 --- a/vllm/distributed/device_communicators/all_reduce_utils.py +++ b/vllm/distributed/device_communicators/all_reduce_utils.py @@ -19,7 +19,8 @@ import vllm.envs as envs from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.logger import init_logger -from vllm.utils import cuda_device_count_stateless, update_environment_variables +from vllm.utils import update_environment_variables +from vllm.utils.torch_utils import cuda_device_count_stateless logger = init_logger(__name__) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 4bc737494cb5..4b82f3b5d396 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -17,7 +17,7 @@ from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless try: ops.meta_size() diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index 7a9574963526..9c7765883cfd 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -13,7 +13,7 @@ from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless logger = init_logger(__name__) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index a6b9df7c1446..c736e084a38d 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -16,7 +16,8 @@ import vllm._C # noqa import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import cuda_device_count_stateless, import_pynvml +from vllm.utils import import_pynvml +from vllm.utils.torch_utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index b25b96889309..68e6c06c8814 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -9,7 +9,7 @@ import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 27a4f89e0045..4211535131a4 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -21,7 +21,8 @@ import vllm.envs as envs from vllm.connections import global_http_connection from vllm.logger import init_logger -from vllm.utils import cuda_device_count_stateless, cuda_get_device_properties +from vllm.utils import cuda_get_device_properties +from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 14e60b23ca79..88393dd2235e 100644 --- 
a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -49,7 +49,7 @@ ) from concurrent.futures.process import ProcessPoolExecutor from dataclasses import dataclass, field -from functools import cache, lru_cache, partial, wraps +from functools import cache, partial, wraps from pathlib import Path from typing import ( TYPE_CHECKING, @@ -686,49 +686,6 @@ def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None: enable_trace_function_call(log_path) -@lru_cache(maxsize=8) -def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int: - # Note: cuda_visible_devices is not used, but we keep it as an argument for - # LRU Cache purposes. - - # Code below is based on - # https://github.com/pytorch/pytorch/blob/ - # c1cd946818442aca8c7f812b16d187ce1586c3bc/ - # torch/cuda/__init__.py#L831C1-L831C17 - import torch.cuda - import torch.version - - from vllm.platforms import current_platform - - if not torch.cuda._is_compiled(): - return 0 - if current_platform.is_rocm(): - # ROCm uses amdsmi instead of nvml for stateless device count - # This requires a sufficiently modern version of Torch 2.4.0 - raw_count = ( - torch.cuda._device_count_amdsmi() - if (hasattr(torch.cuda, "_device_count_amdsmi")) - else -1 - ) - else: - raw_count = torch.cuda._device_count_nvml() - r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count - return r - - -def cuda_device_count_stateless() -> int: - """Get number of CUDA devices, caching based on the value of - CUDA_VISIBLE_DEVICES at the time of call. - - This should be used instead of torch.cuda.device_count() - unless CUDA_VISIBLE_DEVICES has already been set to the desired - value.""" - - # This can be removed and simply replaced with torch.cuda.get_device_count - # after https://github.com/pytorch/pytorch/pull/122815 is released. - return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) - - def cuda_is_initialized() -> bool: """Check if CUDA is initialized.""" if not torch.cuda._is_compiled(): diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 34133c2ed43a..bcb2a26139fa 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -3,9 +3,12 @@ import contextlib import threading from collections.abc import Collection +from functools import lru_cache import torch +import vllm.envs as envs + STR_DTYPE_TO_TORCH_DTYPE = { "float32": torch.float32, "half": torch.half, @@ -290,3 +293,46 @@ def current_stream() -> torch.cuda.Stream: "may not support current_stream with torch API" ) return _current_stream_tls.value + + +@lru_cache(maxsize=8) +def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int: + # Note: cuda_visible_devices is not used, but we keep it as an argument for + # LRU Cache purposes. 
+ + # Code below is based on + # https://github.com/pytorch/pytorch/blob/ + # c1cd946818442aca8c7f812b16d187ce1586c3bc/ + # torch/cuda/__init__.py#L831C1-L831C17 + import torch.cuda + import torch.version + + from vllm.platforms import current_platform + + if not torch.cuda._is_compiled(): + return 0 + if current_platform.is_rocm(): + # ROCm uses amdsmi instead of nvml for stateless device count + # This requires a sufficiently modern version of Torch 2.4.0 + raw_count = ( + torch.cuda._device_count_amdsmi() + if (hasattr(torch.cuda, "_device_count_amdsmi")) + else -1 + ) + else: + raw_count = torch.cuda._device_count_nvml() + r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count + return r + + +def cuda_device_count_stateless() -> int: + """Get number of CUDA devices, caching based on the value of + CUDA_VISIBLE_DEVICES at the time of call. + + This should be used instead of torch.cuda.device_count() + unless CUDA_VISIBLE_DEVICES has already been set to the desired + value.""" + + # This can be removed and simply replaced with torch.cuda.get_device_count + # after https://github.com/pytorch/pytorch/pull/122815 is released. + return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) From b97b512f8ac0d54f91a8153a69f8db5ee105214d Mon Sep 17 00:00:00 2001 From: isotr0py <2037008807@qq.com> Date: Thu, 16 Oct 2025 21:27:30 +0800 Subject: [PATCH 07/15] weak_ref_tensors and get_cuda_view_from_cpu_tensor Signed-off-by: isotr0py <2037008807@qq.com> --- tests/kernels/core/test_uva.py | 3 +- vllm/compilation/cuda_graph.py | 2 +- vllm/model_executor/models/utils.py | 2 +- vllm/utils/__init__.py | 50 -------------------------- vllm/utils/torch_utils.py | 55 +++++++++++++++++++++++++++++ 5 files changed, 59 insertions(+), 53 deletions(-) diff --git a/tests/kernels/core/test_uva.py b/tests/kernels/core/test_uva.py index 73738175e5c7..dee92976eb6f 100644 --- a/tests/kernels/core/test_uva.py +++ b/tests/kernels/core/test_uva.py @@ -3,7 +3,8 @@ import pytest import torch -from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available +from vllm.utils import is_uva_available +from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index fe20a5f7e63e..a2e0abfebc2c 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -17,7 +17,7 @@ from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import weak_ref_tensors +from vllm.utils.torch_utils import weak_ref_tensors logger = init_logger(__name__) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 71abfe98813d..c6fa50ccbc66 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -25,10 +25,10 @@ from vllm.utils import ( cdiv, direct_register_custom_op, - get_cuda_view_from_cpu_tensor, is_pin_memory_available, is_uva_available, ) +from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor logger = init_logger(__name__) diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 2ed8e98811a8..f0d283d7d3fd 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -77,13 +77,11 @@ from argparse import Namespace from vllm.config import ModelConfig, VllmConfig - from vllm.sequence import IntermediateTensors else: Namespace = object 
ModelConfig = object VllmConfig = object - IntermediateTensors = object logger = init_logger(__name__) @@ -1193,54 +1191,6 @@ def value(self): return self._value -def weak_ref_tensor(tensor: Any) -> Any: - """ - Create a weak reference to a tensor. - The new tensor will share the same data as the original tensor, - but will not keep the original tensor alive. - """ - if isinstance(tensor, torch.Tensor): - return torch.ops._C.weak_ref_tensor(tensor) - else: - return tensor - - -def weak_ref_tensors( - tensors: torch.Tensor - | list[torch.Tensor] - | tuple[torch.Tensor] - | IntermediateTensors, -) -> torch.Tensor | list[Any] | tuple[Any] | Any: - """ - Convenience function to create weak references to tensors, - for single tensor, list of tensors or tuple of tensors. - """ - if isinstance(tensors, torch.Tensor): - return weak_ref_tensor(tensors) - if isinstance(tensors, list): - return [weak_ref_tensor(t) for t in tensors] - if isinstance(tensors, tuple): - return tuple(weak_ref_tensor(t) for t in tensors) - - # For IntermediateTensors used in pipeline parallelism - from vllm.sequence import IntermediateTensors - - if isinstance(tensors, IntermediateTensors): - ret = IntermediateTensors( - {key: weak_ref_tensor(val) for key, val in tensors.tensors.items()} - ) - return ret - raise ValueError("Invalid type for tensors") - - -def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor: - """ - Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA). - """ - assert cpu_tensor.is_pinned(), "CPU tensor must be pinned" - return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor) - - def import_from_path(module_name: str, file_path: str | os.PathLike): """ Import a Python file according to its file path. diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index bcb2a26139fa..27eb12050168 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -4,11 +4,18 @@ import threading from collections.abc import Collection from functools import lru_cache +from typing import TYPE_CHECKING, Any import torch import vllm.envs as envs +if TYPE_CHECKING: + from vllm.sequence import IntermediateTensors +else: + IntermediateTensors = object + + STR_DTYPE_TO_TORCH_DTYPE = { "float32": torch.float32, "half": torch.half, @@ -336,3 +343,51 @@ def cuda_device_count_stateless() -> int: # This can be removed and simply replaced with torch.cuda.get_device_count # after https://github.com/pytorch/pytorch/pull/122815 is released. return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) + + +def weak_ref_tensor(tensor: Any) -> Any: + """ + Create a weak reference to a tensor. + The new tensor will share the same data as the original tensor, + but will not keep the original tensor alive. + """ + if isinstance(tensor, torch.Tensor): + return torch.ops._C.weak_ref_tensor(tensor) + else: + return tensor + + +def weak_ref_tensors( + tensors: torch.Tensor + | list[torch.Tensor] + | tuple[torch.Tensor] + | IntermediateTensors, +) -> torch.Tensor | list[Any] | tuple[Any] | Any: + """ + Convenience function to create weak references to tensors, + for single tensor, list of tensors or tuple of tensors. 
+    """
+    if isinstance(tensors, torch.Tensor):
+        return weak_ref_tensor(tensors)
+    if isinstance(tensors, list):
+        return [weak_ref_tensor(t) for t in tensors]
+    if isinstance(tensors, tuple):
+        return tuple(weak_ref_tensor(t) for t in tensors)
+
+    # For IntermediateTensors used in pipeline parallelism
+    from vllm.sequence import IntermediateTensors
+
+    if isinstance(tensors, IntermediateTensors):
+        ret = IntermediateTensors(
+            {key: weak_ref_tensor(val) for key, val in tensors.tensors.items()}
+        )
+        return ret
+    raise ValueError("Invalid type for tensors")
+
+
+def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
+    """
+    Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA).
+    """
+    assert cpu_tensor.is_pinned(), "CPU tensor must be pinned"
+    return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)

From 8648ac226a0b9e6ad4392d660f9809da2161d72e Mon Sep 17 00:00:00 2001
From: isotr0py <2037008807@qq.com>
Date: Thu, 16 Oct 2025 22:28:26 +0800
Subject: [PATCH 08/15] move torch version helper and ops registration

Signed-off-by: isotr0py <2037008807@qq.com>
---
 .../compile/piecewise/test_full_cudagraph.py | 2 +-
 .../compile/piecewise/test_multiple_graphs.py | 2 +-
 tests/compile/piecewise/test_simple.py | 2 +-
 tests/compile/piecewise/test_toy_llama.py | 2 +-
 tests/compile/silly_attention.py | 2 +-
 tests/compile/test_aot_compile.py | 2 +-
 tests/compile/test_config.py | 2 +-
 tests/compile/test_decorator.py | 2 +-
 tests/compile/test_full_graph.py | 2 +-
 tests/compile/test_fusion_attn.py | 2 +-
 tests/distributed/test_sequence_parallel.py | 2 +-
 tests/v1/attention/test_attention_backends.py | 4 +-
 vllm/attention/layer.py | 3 +-
 vllm/attention/ops/rocm_aiter_mla.py | 2 +-
 vllm/compilation/backends.py | 3 +-
 vllm/compilation/collective_fusion.py | 2 +-
 vllm/compilation/compiler_interface.py | 2 +-
 vllm/compilation/decorators.py | 3 +-
 vllm/compilation/inductor_pass.py | 2 +-
 vllm/config/compilation.py | 3 +-
 .../device_communicators/pynccl.py | 2 +-
 vllm/distributed/parallel_state.py | 4 +-
 vllm/distributed/utils.py | 3 +-
 vllm/env_override.py | 2 +-
 vllm/envs.py | 2 +-
 vllm/lora/ops/triton_ops/lora_expand_op.py | 2 +-
 vllm/lora/ops/triton_ops/lora_shrink_op.py | 2 +-
 .../layers/fused_moe/flashinfer_trtllm_moe.py | 2 +-
 .../layers/fused_moe/fused_moe.py | 2 +-
 .../layers/fused_moe/rocm_aiter_fused_moe.py | 2 +-
 vllm/model_executor/layers/fused_moe/utils.py | 3 +-
 vllm/model_executor/layers/layernorm.py | 2 +-
 .../layers/mamba/linear_attn.py | 2 +-
 .../layers/mamba/mamba_mixer.py | 2 +-
 .../layers/mamba/mamba_mixer2.py | 2 +-
 .../model_executor/layers/mamba/short_conv.py | 2 +-
 .../layers/quantization/bitsandbytes.py | 2 +-
 .../layers/quantization/fp_quant.py | 2 +-
 .../layers/quantization/gguf.py | 2 +-
 .../quantization/kernels/scaled_mm/aiter.py | 2 +-
 .../quark/schemes/quark_ocp_mx.py | 2 +-
 .../layers/quantization/utils/fp8_utils.py | 2 +-
 .../layers/quantization/utils/mxfp4_utils.py | 2 +-
 .../layers/quantization/utils/mxfp6_utils.py | 2 +-
 .../layers/quantization/utils/w8a8_utils.py | 2 +-
 .../layers/rotary_embedding/common.py | 2 +-
 .../rotary_embedding/rocm_aiter_rope_ops.py | 2 +-
 vllm/model_executor/layers/utils.py | 2 +-
 vllm/model_executor/models/deepseek_v2.py | 2 +-
 vllm/model_executor/models/plamo2.py | 2 +-
 vllm/model_executor/models/qwen3_next.py | 2 +-
 .../model_executor/models/transformers_moe.py | 2 +-
 vllm/platforms/__init__.py | 3 +-
 vllm/utils/__init__.py | 137 -----------------
 vllm/utils/torch_utils.py | 139 +++++++++++++++++-
vllm/v1/attention/backends/flex_attention.py | 3 +- vllm/v1/attention/backends/rocm_aiter_fa.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 2 +- 58 files changed, 205 insertions(+), 195 deletions(-) diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index e01b58220959..c6d4b5272dbc 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -11,7 +11,7 @@ from vllm import LLM, SamplingParams from vllm.config import CompilationConfig from vllm.platforms import current_platform -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer @contextlib.contextmanager diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index 246239b87d5f..700f57ffb068 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -20,7 +20,7 @@ set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from .. import silly_attention # noqa: F401 diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index f61a0a4eb740..9d4e3f0f300f 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -19,7 +19,7 @@ set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from ..silly_attention import get_global_counter, reset_global_counter diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 500cca87d96e..175ca4a23043 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -27,7 +27,7 @@ set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from .. 
import silly_attention # noqa: F401 diff --git a/tests/compile/silly_attention.py b/tests/compile/silly_attention.py index f33c5772906a..29c02f6e6a1d 100644 --- a/tests/compile/silly_attention.py +++ b/tests/compile/silly_attention.py @@ -8,7 +8,7 @@ import torch from torch.library import Library -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op # Shared library for all compilation test operations # Using "silly" namespace to match existing test expectations diff --git a/tests/compile/test_aot_compile.py b/tests/compile/test_aot_compile.py index 1701d85fe84e..b2734af575a1 100644 --- a/tests/compile/test_aot_compile.py +++ b/tests/compile/test_aot_compile.py @@ -15,7 +15,7 @@ set_current_vllm_config, ) from vllm.forward_context import set_forward_context -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer def reference_fn(x: torch.Tensor): diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 7f51c763da73..20e3c5039cc0 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -5,7 +5,7 @@ from vllm.compilation.counter import compilation_counter from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.config.compilation import CompilationMode -from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer +from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer def test_version(): diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py index e459bc539f2b..c9d01f2317d2 100644 --- a/tests/compile/test_decorator.py +++ b/tests/compile/test_decorator.py @@ -15,7 +15,7 @@ set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from . 
import silly_attention # noqa: F401 diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 2d290771f9ad..248a9f3c7730 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -14,7 +14,7 @@ from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.platforms import current_platform -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer from ..utils import create_new_process_for_each_test diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 4d6f4b471a3a..a35fb9c8c31f 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -33,7 +33,7 @@ ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.v1.kv_cache_interface import AttentionSpec FP8_DTYPE = current_platform.fp8_dtype() diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index deefdf22ba06..c35f6a3c2507 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -18,7 +18,7 @@ from vllm.config.compilation import CompilationMode from vllm.config.model import RunnerOption from vllm.logger import init_logger -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer from ..models.registry import HF_EXAMPLE_MODELS from ..utils import compare_two_settings, create_new_process_for_each_test diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 4d88367ffd02..7a5216d1738d 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -18,8 +18,8 @@ from vllm.attention.backends.registry import _Backend from vllm.config import ModelConfig from vllm.platforms import current_platform -from vllm.utils import cdiv, is_torch_equal_or_newer -from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils import cdiv +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, set_kv_cache_layout, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 9f879f7272e2..9b6ea3f0b8f5 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -34,7 +34,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.platforms import current_platform -from vllm.utils import GiB_bytes, direct_register_custom_op +from vllm.utils import GiB_bytes +from vllm.utils.torch_utils import direct_register_custom_op FP8_DTYPE = current_platform.fp8_dtype() logger = init_logger(__name__) diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py index 8fc034dd721b..6308f63cc4e7 100644 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -5,7 +5,7 @@ import torch from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer +from vllm.utils.torch_utils import direct_register_custom_op, 
is_torch_equal_or_newer def get_aiter_mla_metadata( diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 91be7e85af51..7f4c9e6af13b 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -24,7 +24,8 @@ from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname +from vllm.utils import resolve_obj_by_qualname +from vllm.utils.torch_utils import is_torch_equal_or_newer from .caching import VllmSerializableFunction from .compiler_interface import ( diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 7c85c89bcd7a..fb1f78f21a05 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -18,7 +18,7 @@ ) from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from .inductor_pass import enable_fake_mode from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index e2369a635ad1..0a3f0769db94 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -16,7 +16,7 @@ import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.config import VllmConfig -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer class CompilerInterface: diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 20d4681e2c78..abe61cce0dd8 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -21,7 +21,8 @@ from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.sequence import IntermediateTensors -from vllm.utils import resolve_obj_by_qualname, supports_dynamo +from vllm.utils import resolve_obj_by_qualname +from vllm.utils.torch_utils import supports_dynamo from .monitor import start_monitoring_torch_compile diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 4b263fa6f5a2..9af635a929b4 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -14,7 +14,7 @@ from torch import fx from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer if is_torch_equal_or_newer("2.6"): from torch._inductor.custom_graph_pass import CustomGraphPass diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index a34fb0bf920c..620a521a9a7e 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -16,7 +16,8 @@ from vllm.config.utils import config from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname +from vllm.utils import resolve_obj_by_qualname +from vllm.utils.torch_utils import is_torch_equal_or_newer if TYPE_CHECKING: from vllm.config import VllmConfig diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 6aaa48cc14bf..ad3c8676fafd 100644 --- 
a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -30,7 +30,7 @@ def register_nccl_symmetric_ops(pynccl_comm): from vllm.distributed.device_communicators.pynccl_allocator import ( nccl_symm_mem_context, ) - from vllm.utils import direct_register_custom_op + from vllm.utils.torch_utils import direct_register_custom_op global _NCCL_SYMM_OPS_REGISTERED if _NCCL_SYMM_OPS_REGISTERED: diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 67a8c6f7c053..10e2ba1d5925 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -50,9 +50,11 @@ from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.utils import ( - direct_register_custom_op, get_distributed_init_method, resolve_obj_by_qualname, +) +from vllm.utils.torch_utils import ( + direct_register_custom_op, supports_custom_op, ) diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index a3d9dbe83a12..a5df81e55e36 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -29,7 +29,8 @@ import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import get_tcp_uri, is_torch_equal_or_newer +from vllm.utils import get_tcp_uri +from vllm.utils.torch_utils import is_torch_equal_or_newer logger = init_logger(__name__) diff --git a/vllm/env_override.py b/vllm/env_override.py index f4ac48584cb7..30071f8ea46c 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -5,7 +5,7 @@ import torch from vllm.logger import init_logger -from vllm.utils import is_torch_equal +from vllm.utils.torch_utils import is_torch_equal logger = init_logger(__name__) diff --git a/vllm/envs.py b/vllm/envs.py index 7dcfabe3e044..a7d294ab8298 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -246,7 +246,7 @@ def maybe_convert_bool(value: str | None) -> bool | None: def use_aot_compile() -> bool: - from vllm.utils import is_torch_equal_or_newer + from vllm.utils.torch_utils import is_torch_equal_or_newer default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0" return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1" diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py index c8330455985a..fd4c1364de7e 100644 --- a/vllm/lora/ops/triton_ops/lora_expand_op.py +++ b/vllm/lora/ops/triton_ops/lora_expand_op.py @@ -12,7 +12,7 @@ from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op @triton.jit diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py index 9cba8f494448..8c58915e3f79 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink_op.py +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -12,7 +12,7 @@ from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op @triton.jit diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 698d12d5eadd..f21fe16c5108 100644 --- 
a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, ) -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def flashinfer_fused_moe_blockscale_fp8( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 256f4964b654..c88982c1522b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -51,8 +51,8 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used +from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index b572baecd753..6edbb17c0a8e 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -11,7 +11,7 @@ FusedMoEQuantConfig, ) from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op class QuantMethod(IntEnum): diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index e5957474630c..0627ea50d821 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -23,8 +23,9 @@ mxfp8_e4m3_quantize, ) from vllm.triton_utils import tl, triton -from vllm.utils import cdiv, is_torch_equal_or_newer +from vllm.utils import cdiv from vllm.utils.flashinfer import flashinfer_fp4_quantize +from vllm.utils.torch_utils import is_torch_equal_or_newer @triton.jit diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index a689bc7be00f..fe0d3a9e319c 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -13,7 +13,7 @@ vllm_kernel_override_batch_invariant, ) from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def is_rocm_aiter_rmsnorm_enabled() -> bool: diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py index b5a37b2582e5..a8f7f652452f 100644 --- a/vllm/model_executor/layers/mamba/linear_attn.py +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -35,7 +35,7 @@ MambaStateShapeCalculator, ) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata if TYPE_CHECKING: diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 8f7317556f77..a9a0c216474b 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -37,7 
+37,7 @@ selective_state_update, ) from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index b0ee327a8234..fb45afa33dad 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -46,7 +46,7 @@ sharded_weight_loader, ) from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata # Added by the IBM Team, 2024 diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py index afaa706929a2..04efa8a8b373 100644 --- a/vllm/model_executor/layers/mamba/short_conv.py +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -27,7 +27,7 @@ causal_conv1d_fn, causal_conv1d_update, ) -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 81cf86a7d0ee..ccd9b311cc93 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -23,7 +23,7 @@ QuantizationMethods, ) from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op class BitsAndBytesConfig(QuantizationConfig): diff --git a/vllm/model_executor/layers/quantization/fp_quant.py b/vllm/model_executor/layers/quantization/fp_quant.py index f00ea17ab677..15a253cef0b7 100644 --- a/vllm/model_executor/layers/quantization/fp_quant.py +++ b/vllm/model_executor/layers/quantization/fp_quant.py @@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op class FPQuantConfig(QuantizationConfig): diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 84cd07a0c174..8a914c57a9f7 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -28,7 +28,7 @@ ) from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index 5e133aac10fa..a19396a162bc 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -7,7 +7,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.platforms import current_platform -from 
vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from .cutlass import CutlassScaledMMLinearKernel from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py index 1bc1171843d5..c25c522dea55 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py @@ -45,7 +45,7 @@ def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool: from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 from aiter.ops.triton.quant import dynamic_mxfp4_quant - from vllm.utils import direct_register_custom_op + from vllm.utils.torch_utils import direct_register_custom_op if is_rocm_aiter_fp4_asm_gemm_enabled(): from aiter import gemm_a4w4, per_1x32_f4_quant_hip diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 51af40a11914..cb2d075c1a9d 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -28,13 +28,13 @@ ) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op from vllm.utils.deep_gemm import ( fp8_gemm_nt, is_deep_gemm_e8m0_used, is_deep_gemm_supported, should_use_deepgemm_for_fp8_linear, ) +from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 231d7dc6ce41..5e87cadfb107 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -7,7 +7,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer +from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py index 2249e9658970..2b5659e30097 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py @@ -3,7 +3,7 @@ import torch from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def _quant_dequant_mxfp6( diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 4fda4d76a980..0d036ffdd286 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -12,8 +12,8 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer +from vllm.utils.torch_utils import direct_register_custom_op # Input scaling factors are no longer 
optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index f1b34f178574..9e6ec9fdd523 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -10,7 +10,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op if current_platform.is_cuda(): from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb diff --git a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py index 223350d43267..a01d14f7b3a1 100644 --- a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py +++ b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py @@ -5,7 +5,7 @@ import vllm.envs as envs from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def is_rocm_triton_rotary_embedding_enabled() -> bool: diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 87ffcb48c8c0..c1a48fa200ca 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -9,7 +9,7 @@ from vllm import _custom_ops as ops from vllm import envs from vllm.platforms import CpuArchEnum, current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def shuffle_weight(w: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 5b55b685dacf..d2bc037b97dd 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -79,8 +79,8 @@ from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils import direct_register_custom_op from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerBackend, DeepseekV32IndexerMetadata, diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index b35a8c6b66f2..09293f63f70e 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -64,7 +64,7 @@ ) from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 27e7a3ead45d..06e94734376c 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -71,7 +71,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Qwen3NextConfig from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from 
vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from .interfaces import ( diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers_moe.py index 5267e447902f..43ea9a4869ed 100644 --- a/vllm/model_executor/models/transformers_moe.py +++ b/vllm/model_executor/models/transformers_moe.py @@ -28,7 +28,7 @@ from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from .interfaces import MixtureOfExperts, SupportsMultiModal from .transformers import ( diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index b9140b4fe676..99651a408b31 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -7,7 +7,8 @@ from vllm import envs from vllm.plugins import PLATFORM_PLUGINS_GROUP, load_plugins_by_group -from vllm.utils import resolve_obj_by_qualname, supports_xccl +from vllm.utils import resolve_obj_by_qualname +from vllm.utils.torch_utils import supports_xccl from .interface import CpuArchEnum, Platform, PlatformEnum diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index f0d283d7d3fd..d9cabfb1ec7b 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -64,9 +64,6 @@ import yaml import zmq import zmq.asyncio -from packaging import version -from packaging.version import Version -from torch.library import Library from typing_extensions import Never import vllm.envs as envs @@ -1145,27 +1142,6 @@ def load_config_file(self, file_path: str) -> list[str]: return processed_args -# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0. -# In particular, the FakeScalarType is not supported for earlier versions of -# PyTorch which breaks dynamo for any ops registered using ScalarType. -def supports_dynamo() -> bool: - base_torch_version = Version(Version(torch.__version__).base_version) - return base_torch_version >= Version("2.4.0") - - -# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform -def supports_xccl() -> bool: - return ( - is_torch_equal_or_newer("2.8.0.dev") and torch.distributed.is_xccl_available() - ) - - -# Some backends use pytorch version < 2.4.0 which doesn't -# support `torch.library.custom_op`. -def supports_custom_op() -> bool: - return hasattr(torch.library, "custom_op") - - class AtomicCounter: """An atomic, thread-safe counter""" @@ -1438,70 +1414,6 @@ def __getattr__(self, key: str): ) -# create a library to hold the custom op -vllm_lib = Library("vllm", "FRAGMENT") # noqa - - -def direct_register_custom_op( - op_name: str, - op_func: Callable, - mutates_args: list[str] | None = None, - fake_impl: Callable | None = None, - target_lib: Library | None = None, - dispatch_key: str | None = None, - tags: tuple[torch.Tag, ...] = (), -): - """ - `torch.library.custom_op` can have significant overhead because it - needs to consider complicated dispatching logic. This function - directly registers a custom op and dispatches it to the CUDA backend. - See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5 - for more details. - - By default, the custom op is registered to the vLLM library. If you - want to register it to a different library, you can pass the library - object to the `target_lib` argument. - - IMPORTANT: the lifetime of the operator is tied to the lifetime of the - library object. 
If you want to bind the operator to a different library, - make sure the library object is alive when the operator is used. - """ - if not supports_custom_op(): - from vllm.platforms import current_platform - - assert not current_platform.is_cuda_alike(), ( - "cuda platform needs torch>=2.4 to support custom op, " - "chances are you are using an old version of pytorch " - "or a custom build of pytorch. It is recommended to " - "use vLLM in a fresh new environment and let it install " - "the required dependencies." - ) - return - - if mutates_args is None: - mutates_args = [] - - if dispatch_key is None: - from vllm.platforms import current_platform - - dispatch_key = current_platform.dispatch_key - - import torch.library - - if hasattr(torch.library, "infer_schema"): - schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) - else: - # for pytorch 2.4 - import torch._custom_op.impl - - schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) - my_lib = target_lib or vllm_lib - my_lib.define(op_name + schema_str, tags=tags) - my_lib.impl(op_name, op_func, dispatch_key=dispatch_key) - if fake_impl is not None: - my_lib._register_fake(op_name, fake_impl) - - def resolve_obj_by_qualname(qualname: str) -> Any: """ Resolve an object by its fully-qualified class name. @@ -2233,55 +2145,6 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: raise ValueError(f"Unsupported hash function: {hash_fn_name}") -def is_torch_equal_or_newer(target: str) -> bool: - """Check if the installed torch version is >= the target version. - - Args: - target: a version string, like "2.6.0". - - Returns: - Whether the condition meets. - """ - try: - return _is_torch_equal_or_newer(str(torch.__version__), target) - except Exception: - # Fallback to PKG-INFO to load the package info, needed by the doc gen. - return Version(importlib.metadata.version("torch")) >= Version(target) - - -# Helper function used in testing. -def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool: - torch_version = version.parse(torch_version) - return torch_version >= version.parse(target) - - -def _is_torch_equal(target: str) -> bool: - assert target.count(".") == 2 - torch_version = str(torch.__version__) - torch_version = version.parse(torch_version) - # torch version is like "2.6.0.dev20240101" or "2.6.0.dev20240101+cpu" - # or "2.6.0+cu128" but never "2.6.0.1" - return ( - torch_version >= version.parse(target) - and version.parse(target + ".1") > torch_version - ) - - -def is_torch_equal(target: str) -> bool: - """Check if the installed torch version is == the target version. - - Args: - target: a version string, like "2.6.0". - - Returns: - Whether the condition meets. - """ - try: - return _is_torch_equal(target) - except Exception: - return Version(importlib.metadata.version("torch")) == Version(target) - - @cache def _has_module(module_name: str) -> bool: """Return True if *module_name* can be found in the current environment. 
diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 27eb12050168..e7d14c10cbc1 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -1,12 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +import importlib.metadata import threading -from collections.abc import Collection +from collections.abc import Callable, Collection from functools import lru_cache from typing import TYPE_CHECKING, Any import torch +from packaging import version +from packaging.version import Version +from torch.library import Library import vllm.envs as envs @@ -391,3 +395,136 @@ def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor: """ assert cpu_tensor.is_pinned(), "CPU tensor must be pinned" return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor) + + +# Helper function used in testing. +def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool: + torch_version = version.parse(torch_version) + return torch_version >= version.parse(target) + + +def is_torch_equal_or_newer(target: str) -> bool: + """Check if the installed torch version is >= the target version. + + Args: + target: a version string, like "2.6.0". + + Returns: + Whether the condition meets. + """ + try: + return _is_torch_equal_or_newer(str(torch.__version__), target) + except Exception: + # Fallback to PKG-INFO to load the package info, needed by the doc gen. + return Version(importlib.metadata.version("torch")) >= Version(target) + + +def _is_torch_equal(target: str) -> bool: + assert target.count(".") == 2 + torch_version = str(torch.__version__) + torch_version = version.parse(torch_version) + # torch version is like "2.6.0.dev20240101" or "2.6.0.dev20240101+cpu" + # or "2.6.0+cu128" but never "2.6.0.1" + return ( + torch_version >= version.parse(target) + and version.parse(target + ".1") > torch_version + ) + + +def is_torch_equal(target: str) -> bool: + """Check if the installed torch version is == the target version. + + Args: + target: a version string, like "2.6.0". + + Returns: + Whether the condition meets. + """ + try: + return _is_torch_equal(target) + except Exception: + return Version(importlib.metadata.version("torch")) == Version(target) + + +# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0. +# In particular, the FakeScalarType is not supported for earlier versions of +# PyTorch which breaks dynamo for any ops registered using ScalarType. +def supports_dynamo() -> bool: + return is_torch_equal_or_newer("2.4.0") + + +# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform +def supports_xccl() -> bool: + return ( + is_torch_equal_or_newer("2.8.0.dev") and torch.distributed.is_xccl_available() + ) + + +# Some backends use pytorch version < 2.4.0 which doesn't +# support `torch.library.custom_op`. +def supports_custom_op() -> bool: + return hasattr(torch.library, "custom_op") + + +# create a library to hold the custom op +vllm_lib = Library("vllm", "FRAGMENT") # noqa + + +def direct_register_custom_op( + op_name: str, + op_func: Callable, + mutates_args: list[str] | None = None, + fake_impl: Callable | None = None, + target_lib: Library | None = None, + dispatch_key: str | None = None, + tags: tuple[torch.Tag, ...] = (), +): + """ + `torch.library.custom_op` can have significant overhead because it + needs to consider complicated dispatching logic. This function + directly registers a custom op and dispatches it to the CUDA backend. 
+ See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5 + for more details. + + By default, the custom op is registered to the vLLM library. If you + want to register it to a different library, you can pass the library + object to the `target_lib` argument. + + IMPORTANT: the lifetime of the operator is tied to the lifetime of the + library object. If you want to bind the operator to a different library, + make sure the library object is alive when the operator is used. + """ + if not supports_custom_op(): + from vllm.platforms import current_platform + + assert not current_platform.is_cuda_alike(), ( + "cuda platform needs torch>=2.4 to support custom op, " + "chances are you are using an old version of pytorch " + "or a custom build of pytorch. It is recommended to " + "use vLLM in a fresh new environment and let it install " + "the required dependencies." + ) + return + + if mutates_args is None: + mutates_args = [] + + if dispatch_key is None: + from vllm.platforms import current_platform + + dispatch_key = current_platform.dispatch_key + + import torch.library + + if hasattr(torch.library, "infer_schema"): + schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) + else: + # for pytorch 2.4 + import torch._custom_op.impl + + schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) + my_lib = target_lib or vllm_lib + my_lib.define(op_name + schema_str, tags=tags) + my_lib.impl(op_name, op_func, dispatch_key=dispatch_key) + if fake_impl is not None: + my_lib._register_fake(op_name, fake_impl) diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 902872bb25b3..d14f949b6579 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -28,7 +28,8 @@ from vllm.model_executor.layers.batch_invariant import ( vllm_kernel_override_batch_invariant, ) -from vllm.utils import cdiv, is_torch_equal_or_newer +from vllm.utils import cdiv +from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 7c73611d4a58..f7a4114a0a70 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -29,7 +29,7 @@ import aiter from vllm.triton_utils import tl, triton - from vllm.utils import direct_register_custom_op + from vllm.utils.torch_utils import direct_register_custom_op @triton.jit def _vllm_layout_trans_kernel( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ad917126f07b..7a73177ba7d8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -80,12 +80,12 @@ is_pin_memory_available, length_from_prompt_token_ids_or_embeds, round_up, - supports_dynamo, ) from vllm.utils.jsontree import json_map_leaves from vllm.utils.torch_utils import ( STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, + supports_dynamo, ) from vllm.v1.attention.backends.flash_attn import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder From b0853f810d0db2aaf3b52cda091a3d6d4cad3bd1 Mon Sep 17 00:00:00 2001 From: isotr0py <2037008807@qq.com> Date: Thu, 16 Oct 2025 22:36:41 +0800 Subject: [PATCH 09/15] rename torch_utils.py to torch.py Signed-off-by: isotr0py <2037008807@qq.com> --- 
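Note: with this rename, the relocated helpers live in vllm.utils.torch instead of
vllm.utils.torch_utils, so call sites import from the new path. Below is a minimal
usage sketch under that assumption; my_add_one and my_add_one_fake are hypothetical
example functions for illustration only and are not part of this series.

    import torch

    from vllm.utils.torch import (
        STR_DTYPE_TO_TORCH_DTYPE,
        cuda_device_count_stateless,
        direct_register_custom_op,
        is_torch_equal_or_newer,
    )

    # Device count that respects the current CUDA_VISIBLE_DEVICES value
    # (cached per value of that variable).
    num_gpus = cuda_device_count_stateless()

    # Map a config dtype string to a torch dtype.
    dtype = STR_DTYPE_TO_TORCH_DTYPE["half"]

    # Gate features on the installed torch version.
    torch_is_recent = is_torch_equal_or_newer("2.6.0")

    # Register a custom op directly on the shared "vllm" library; it then
    # dispatches as torch.ops.vllm.my_add_one on the platform's dispatch key.
    def my_add_one(x: torch.Tensor) -> torch.Tensor:
        return x + 1

    def my_add_one_fake(x: torch.Tensor) -> torch.Tensor:
        # Fake impl only mirrors the output shape/dtype for tracing.
        return torch.empty_like(x)

    direct_register_custom_op(
        op_name="my_add_one",
        op_func=my_add_one,
        mutates_args=[],
        fake_impl=my_add_one_fake,
    )
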
benchmarks/kernels/bench_per_token_quant_fp8.py | 2 +- benchmarks/kernels/benchmark_activation.py | 2 +- benchmarks/kernels/benchmark_layernorm.py | 2 +- benchmarks/kernels/benchmark_paged_attention.py | 2 +- benchmarks/kernels/benchmark_quant.py | 2 +- benchmarks/kernels/benchmark_reshape_and_cache.py | 2 +- benchmarks/kernels/benchmark_reshape_and_cache_flash.py | 2 +- tests/compile/piecewise/test_full_cudagraph.py | 2 +- tests/compile/piecewise/test_multiple_graphs.py | 2 +- tests/compile/piecewise/test_simple.py | 2 +- tests/compile/piecewise/test_toy_llama.py | 2 +- tests/compile/silly_attention.py | 2 +- tests/compile/test_aot_compile.py | 2 +- tests/compile/test_basic_correctness.py | 2 +- tests/compile/test_config.py | 2 +- tests/compile/test_decorator.py | 2 +- tests/compile/test_full_graph.py | 2 +- tests/compile/test_fusion_attn.py | 2 +- tests/conftest.py | 2 +- tests/distributed/test_sequence_parallel.py | 2 +- tests/distributed/test_utils.py | 2 +- tests/kernels/attention/conftest.py | 2 +- tests/kernels/attention/test_prefix_prefill.py | 2 +- tests/kernels/core/test_uva.py | 2 +- tests/kernels/moe/test_modular_kernel_combinations.py | 2 +- tests/models/multimodal/pooling/test_intern_vit.py | 2 +- tests/models/multimodal/pooling/test_radio.py | 2 +- tests/models/multimodal/processing/test_tensor_schema.py | 2 +- tests/utils.py | 2 +- tests/utils_/test_utils.py | 2 +- tests/v1/attention/test_attention_backends.py | 2 +- tests/v1/attention/test_mla_backends.py | 2 +- tests/v1/engine/test_async_llm.py | 2 +- tests/v1/engine/test_engine_core.py | 2 +- tests/v1/engine/test_engine_core_client.py | 2 +- tests/v1/shutdown/test_delete.py | 2 +- tests/v1/shutdown/test_forward_error.py | 2 +- tests/v1/shutdown/test_startup_error.py | 2 +- vllm/attention/layer.py | 2 +- vllm/attention/ops/rocm_aiter_mla.py | 2 +- vllm/compilation/backends.py | 2 +- vllm/compilation/collective_fusion.py | 2 +- vllm/compilation/compiler_interface.py | 2 +- vllm/compilation/cuda_graph.py | 2 +- vllm/compilation/decorators.py | 2 +- vllm/compilation/inductor_pass.py | 2 +- vllm/config/compilation.py | 2 +- vllm/config/model.py | 2 +- vllm/config/parallel.py | 2 +- vllm/distributed/device_communicators/all_reduce_utils.py | 2 +- vllm/distributed/device_communicators/custom_all_reduce.py | 2 +- vllm/distributed/device_communicators/pynccl.py | 4 ++-- vllm/distributed/device_communicators/quick_all_reduce.py | 2 +- vllm/distributed/device_communicators/ray_communicator.py | 2 +- .../kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py | 2 +- vllm/distributed/parallel_state.py | 2 +- vllm/distributed/utils.py | 2 +- vllm/env_override.py | 2 +- vllm/envs.py | 2 +- vllm/lora/ops/triton_ops/lora_expand_op.py | 2 +- vllm/lora/ops/triton_ops/lora_shrink_op.py | 2 +- vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py | 2 +- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py | 2 +- vllm/model_executor/layers/fused_moe/utils.py | 2 +- vllm/model_executor/layers/layernorm.py | 2 +- vllm/model_executor/layers/mamba/linear_attn.py | 2 +- vllm/model_executor/layers/mamba/mamba_mixer.py | 2 +- vllm/model_executor/layers/mamba/mamba_mixer2.py | 2 +- vllm/model_executor/layers/mamba/mamba_utils.py | 2 +- vllm/model_executor/layers/mamba/short_conv.py | 2 +- vllm/model_executor/layers/quantization/bitsandbytes.py | 2 +- vllm/model_executor/layers/quantization/fp_quant.py | 2 +- vllm/model_executor/layers/quantization/gguf.py | 2 +- 
.../layers/quantization/kernels/scaled_mm/aiter.py | 2 +- .../layers/quantization/quark/schemes/quark_ocp_mx.py | 2 +- vllm/model_executor/layers/quantization/utils/fp8_utils.py | 2 +- vllm/model_executor/layers/quantization/utils/mxfp4_utils.py | 2 +- vllm/model_executor/layers/quantization/utils/mxfp6_utils.py | 2 +- vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 2 +- vllm/model_executor/layers/rotary_embedding/common.py | 2 +- .../layers/rotary_embedding/rocm_aiter_rope_ops.py | 2 +- vllm/model_executor/layers/utils.py | 2 +- vllm/model_executor/model_loader/base_loader.py | 2 +- vllm/model_executor/model_loader/bitsandbytes_loader.py | 2 +- vllm/model_executor/model_loader/gguf_loader.py | 2 +- vllm/model_executor/model_loader/tensorizer_loader.py | 2 +- vllm/model_executor/model_loader/tpu.py | 2 +- vllm/model_executor/models/config.py | 2 +- vllm/model_executor/models/deepseek_v2.py | 2 +- vllm/model_executor/models/deepseek_vl2.py | 2 +- vllm/model_executor/models/internvl.py | 2 +- vllm/model_executor/models/minicpmv.py | 2 +- vllm/model_executor/models/plamo2.py | 2 +- vllm/model_executor/models/qwen3_next.py | 2 +- vllm/model_executor/models/transformers_moe.py | 2 +- vllm/model_executor/models/utils.py | 2 +- vllm/model_executor/models/whisper.py | 2 +- vllm/platforms/__init__.py | 2 +- vllm/platforms/cuda.py | 2 +- vllm/platforms/rocm.py | 2 +- vllm/usage/usage_lib.py | 2 +- vllm/utils/{torch_utils.py => torch.py} | 0 vllm/v1/attention/backends/flex_attention.py | 2 +- vllm/v1/attention/backends/rocm_aiter_fa.py | 2 +- vllm/v1/kv_cache_interface.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 2 +- vllm/v1/worker/tpu_worker.py | 2 +- vllm/v1/worker/ubatching.py | 2 +- 109 files changed, 109 insertions(+), 109 deletions(-) rename vllm/utils/{torch_utils.py => torch.py} (100%) diff --git a/benchmarks/kernels/bench_per_token_quant_fp8.py b/benchmarks/kernels/bench_per_token_quant_fp8.py index d33b84fc3601..59ed54659872 100644 --- a/benchmarks/kernels/bench_per_token_quant_fp8.py +++ b/benchmarks/kernels/bench_per_token_quant_fp8.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.triton_utils import triton from vllm.utils import FlexibleArgumentParser -from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE def with_triton_mode(fn): diff --git a/benchmarks/kernels/benchmark_activation.py b/benchmarks/kernels/benchmark_activation.py index 7662655b5efa..51a69974ac71 100644 --- a/benchmarks/kernels/benchmark_activation.py +++ b/benchmarks/kernels/benchmark_activation.py @@ -11,7 +11,7 @@ from vllm.platforms import current_platform from vllm.triton_utils import triton from vllm.utils import FlexibleArgumentParser -from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE batch_size_range = [1, 16, 32, 64, 128] seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py index bcfa64c3f425..53d4d3174f80 100644 --- a/benchmarks/kernels/benchmark_layernorm.py +++ b/benchmarks/kernels/benchmark_layernorm.py @@ -8,7 +8,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser -from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE @torch.inference_mode() 
diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 1b1e71adeec4..82b1efdd20f6 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -10,7 +10,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser -from vllm.utils.torch_utils import ( +from vllm.utils.torch import ( STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random, ) diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py index 61427a77b4e3..88af6f25e358 100644 --- a/benchmarks/kernels/benchmark_quant.py +++ b/benchmarks/kernels/benchmark_quant.py @@ -8,7 +8,7 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser -from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE @torch.inference_mode() diff --git a/benchmarks/kernels/benchmark_reshape_and_cache.py b/benchmarks/kernels/benchmark_reshape_and_cache.py index e0ff09d4b397..16b42e701824 100644 --- a/benchmarks/kernels/benchmark_reshape_and_cache.py +++ b/benchmarks/kernels/benchmark_reshape_and_cache.py @@ -10,7 +10,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser -from vllm.utils.torch_utils import ( +from vllm.utils.torch import ( STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random, ) diff --git a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py index 29f1b2ccdcf6..b360256ec876 100644 --- a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py +++ b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py @@ -13,7 +13,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser -from vllm.utils.torch_utils import ( +from vllm.utils.torch import ( STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random_flash, ) diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index c6d4b5272dbc..a6edbb410ef6 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -11,7 +11,7 @@ from vllm import LLM, SamplingParams from vllm.config import CompilationConfig from vllm.platforms import current_platform -from vllm.utils.torch_utils import is_torch_equal_or_newer +from vllm.utils.torch import is_torch_equal_or_newer @contextlib.contextmanager diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index 700f57ffb068..eaf625c934f5 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -20,7 +20,7 @@ set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils.torch_utils import is_torch_equal_or_newer +from vllm.utils.torch import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from .. 
import silly_attention # noqa: F401 diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 9d4e3f0f300f..841392463094 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -19,7 +19,7 @@ set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils.torch_utils import is_torch_equal_or_newer +from vllm.utils.torch import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from ..silly_attention import get_global_counter, reset_global_counter diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 175ca4a23043..da18f7dca26f 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -27,7 +27,7 @@ set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils.torch_utils import is_torch_equal_or_newer +from vllm.utils.torch import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from .. import silly_attention # noqa: F401 diff --git a/tests/compile/silly_attention.py b/tests/compile/silly_attention.py index 29c02f6e6a1d..1e055db6f68e 100644 --- a/tests/compile/silly_attention.py +++ b/tests/compile/silly_attention.py @@ -8,7 +8,7 @@ import torch from torch.library import Library -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op # Shared library for all compilation test operations # Using "silly" namespace to match existing test expectations diff --git a/tests/compile/test_aot_compile.py b/tests/compile/test_aot_compile.py index b2734af575a1..9fb4fae2b74a 100644 --- a/tests/compile/test_aot_compile.py +++ b/tests/compile/test_aot_compile.py @@ -15,7 +15,7 @@ set_current_vllm_config, ) from vllm.forward_context import set_forward_context -from vllm.utils.torch_utils import is_torch_equal_or_newer +from vllm.utils.torch import is_torch_equal_or_newer def reference_fn(x: torch.Tensor): diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 132a838b8d44..0d8a811953c8 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -5,7 +5,7 @@ import pytest from vllm.config import CompilationMode -from vllm.utils.torch_utils import cuda_device_count_stateless +from vllm.utils.torch import cuda_device_count_stateless from ..utils import compare_all_settings diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 20e3c5039cc0..e666aea3af54 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -5,7 +5,7 @@ from vllm.compilation.counter import compilation_counter from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.config.compilation import CompilationMode -from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer +from vllm.utils.torch import _is_torch_equal_or_newer, is_torch_equal_or_newer def test_version(): diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py index c9d01f2317d2..7b5149e89e52 100644 --- a/tests/compile/test_decorator.py +++ b/tests/compile/test_decorator.py @@ -15,7 +15,7 @@ set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils.torch_utils import is_torch_equal_or_newer +from 
vllm.utils.torch import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from . import silly_attention # noqa: F401 diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 248a9f3c7730..af8849a62e84 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -14,7 +14,7 @@ from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.platforms import current_platform -from vllm.utils.torch_utils import is_torch_equal_or_newer +from vllm.utils.torch import is_torch_equal_or_newer from ..utils import create_new_process_for_each_test diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index a35fb9c8c31f..b47eecf407d0 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -33,7 +33,7 @@ ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform -from vllm.utils.torch_utils import is_torch_equal_or_newer +from vllm.utils.torch import is_torch_equal_or_newer from vllm.v1.kv_cache_interface import AttentionSpec FP8_DTYPE = current_platform.fp8_dtype() diff --git a/tests/conftest.py b/tests/conftest.py index 45ccc38a4538..722100c4d99e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,7 +58,7 @@ from vllm.sampling_params import BeamSearchParams from vllm.transformers_utils.utils import maybe_model_redirect from vllm.utils.collections import is_list_of -from vllm.utils.torch_utils import set_default_torch_num_threads +from vllm.utils.torch import set_default_torch_num_threads logger = init_logger(__name__) diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index c35f6a3c2507..5974865d2387 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -18,7 +18,7 @@ from vllm.config.compilation import CompilationMode from vllm.config.model import RunnerOption from vllm.logger import init_logger -from vllm.utils.torch_utils import is_torch_equal_or_newer +from vllm.utils.torch import is_torch_equal_or_newer from ..models.registry import HF_EXAMPLE_MODELS from ..utils import compare_two_settings, create_new_process_for_each_test diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py index c10c2565811b..3b8fa43d0b7a 100644 --- a/tests/distributed/test_utils.py +++ b/tests/distributed/test_utils.py @@ -14,7 +14,7 @@ get_open_port, update_environment_variables, ) -from vllm.utils.torch_utils import cuda_device_count_stateless +from vllm.utils.torch import cuda_device_count_stateless from ..utils import multi_gpu_test diff --git a/tests/kernels/attention/conftest.py b/tests/kernels/attention/conftest.py index e520267320c0..6704fa99359f 100644 --- a/tests/kernels/attention/conftest.py +++ b/tests/kernels/attention/conftest.py @@ -3,7 +3,7 @@ import pytest -from vllm.utils.torch_utils import ( +from vllm.utils.torch import ( create_kv_caches_with_random, create_kv_caches_with_random_flash, ) diff --git a/tests/kernels/attention/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py index 65972d02f2f6..e75cbb350d73 100644 --- a/tests/kernels/attention/test_prefix_prefill.py +++ b/tests/kernels/attention/test_prefix_prefill.py @@ -15,7 +15,7 @@ from vllm.attention.ops.chunked_prefill_paged_decode import 
chunked_prefill_paged_decode from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.platforms import current_platform -from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE NUM_HEADS = [64] NUM_QUERIES_PER_KV = [1, 64] diff --git a/tests/kernels/core/test_uva.py b/tests/kernels/core/test_uva.py index dee92976eb6f..aaa4ec311afc 100644 --- a/tests/kernels/core/test_uva.py +++ b/tests/kernels/core/test_uva.py @@ -4,7 +4,7 @@ import torch from vllm.utils import is_uva_available -from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor +from vllm.utils.torch import get_cuda_view_from_cpu_tensor CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index a7beb313011a..4403a69f79f1 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -15,7 +15,7 @@ from vllm.platforms import current_platform from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe -from vllm.utils.torch_utils import cuda_device_count_stateless +from vllm.utils.torch import cuda_device_count_stateless from .modular_kernel_tools.common import ( Config, diff --git a/tests/models/multimodal/pooling/test_intern_vit.py b/tests/models/multimodal/pooling/test_intern_vit.py index 5a97848216b8..155cabfe3a88 100644 --- a/tests/models/multimodal/pooling/test_intern_vit.py +++ b/tests/models/multimodal/pooling/test_intern_vit.py @@ -7,7 +7,7 @@ from transformers import AutoConfig, AutoModel, CLIPImageProcessor from vllm.distributed import cleanup_dist_env_and_memory -from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE from ....conftest import ImageTestAssets diff --git a/tests/models/multimodal/pooling/test_radio.py b/tests/models/multimodal/pooling/test_radio.py index 8929563d8b05..47b359b45de9 100644 --- a/tests/models/multimodal/pooling/test_radio.py +++ b/tests/models/multimodal/pooling/test_radio.py @@ -9,7 +9,7 @@ from vllm.distributed import cleanup_dist_env_and_memory from vllm.model_executor.models.radio import RadioModel from vllm.transformers_utils.configs.radio import RadioConfig -from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE from ....conftest import ImageTestAssets diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index 8de8ecfe9d83..bbfdee780f9e 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -35,7 +35,7 @@ from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.utils.collections import is_list_of -from vllm.utils.torch_utils import set_default_torch_dtype +from vllm.utils.torch import set_default_torch_dtype from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS from ...utils import dummy_hf_overrides diff --git a/tests/utils.py b/tests/utils.py index ad5139c7f8ba..dd7b7c14911c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -48,7 +48,7 @@ GB_bytes, get_open_port, ) -from vllm.utils.torch_utils import cuda_device_count_stateless +from 
vllm.utils.torch import cuda_device_count_stateless if current_platform.is_rocm(): from amdsmi import ( diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index 5ce9a9604b08..0c8750bd150b 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -37,7 +37,7 @@ split_zmq_path, unique_filepath, ) -from vllm.utils.torch_utils import ( +from vllm.utils.torch import ( common_broadcastable_dtype, current_stream, is_lossless_cast, diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 7a5216d1738d..e8cc45e74f5f 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -19,7 +19,7 @@ from vllm.config import ModelConfig from vllm.platforms import current_platform from vllm.utils import cdiv -from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, is_torch_equal_or_newer +from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, set_kv_cache_layout, diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 81fd6433b0c8..146c91ea763d 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -23,7 +23,7 @@ from vllm.attention.ops.flashmla import is_flashmla_dense_supported from vllm.config.vllm import set_current_vllm_config from vllm.utils import cdiv -from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import FullAttentionSpec diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index c9605ea1b07c..23752b524899 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -15,7 +15,7 @@ from vllm.outputs import RequestOutput from vllm.platforms import current_platform from vllm.sampling_params import RequestOutputKind -from vllm.utils.torch_utils import set_default_torch_num_threads +from vllm.utils.torch import set_default_torch_num_threads from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.metrics.loggers import ( AggregatedLoggingStatLogger, diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 7e39cd781bae..4c2260558dfe 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -12,7 +12,7 @@ from vllm import SamplingParams from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform -from vllm.utils.torch_utils import set_default_torch_num_threads +from vllm.utils.torch import set_default_torch_num_threads from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core import EngineCore from vllm.v1.executor.abstract import Executor, UniProcExecutor diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 770560a5e549..6d989c872962 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -21,7 +21,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.usage.usage_lib import UsageContext -from vllm.utils.torch_utils import set_default_torch_num_threads +from vllm.utils.torch import set_default_torch_num_threads from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core import 
EngineCore from vllm.v1.engine.core_client import AsyncMPClient, EngineCoreClient, SyncMPClient diff --git a/tests/v1/shutdown/test_delete.py b/tests/v1/shutdown/test_delete.py index 255515948433..c3610c02b8c3 100644 --- a/tests/v1/shutdown/test_delete.py +++ b/tests/v1/shutdown/test_delete.py @@ -12,7 +12,7 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.sampling_params import RequestOutputKind -from vllm.utils.torch_utils import cuda_device_count_stateless +from vllm.utils.torch import cuda_device_count_stateless from vllm.v1.engine.async_llm import AsyncLLM MODELS = ["meta-llama/Llama-3.2-1B"] diff --git a/tests/v1/shutdown/test_forward_error.py b/tests/v1/shutdown/test_forward_error.py index e65d46dfa43a..c7587d2dd55b 100644 --- a/tests/v1/shutdown/test_forward_error.py +++ b/tests/v1/shutdown/test_forward_error.py @@ -14,7 +14,7 @@ from vllm import LLM, AsyncEngineArgs, SamplingParams from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.utils.torch_utils import cuda_device_count_stateless +from vllm.utils.torch import cuda_device_count_stateless from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.exceptions import EngineDeadError diff --git a/tests/v1/shutdown/test_startup_error.py b/tests/v1/shutdown/test_startup_error.py index 3877fceae00c..499b9d123d92 100644 --- a/tests/v1/shutdown/test_startup_error.py +++ b/tests/v1/shutdown/test_startup_error.py @@ -13,7 +13,7 @@ from vllm.distributed import get_tensor_model_parallel_rank from vllm.engine.arg_utils import AsyncEngineArgs from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.utils.torch_utils import cuda_device_count_stateless +from vllm.utils.torch import cuda_device_count_stateless from vllm.v1.engine.async_llm import AsyncLLM MODELS = ["meta-llama/Llama-3.2-1B"] diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 9b6ea3f0b8f5..e5754ec09c53 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -35,7 +35,7 @@ from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.platforms import current_platform from vllm.utils import GiB_bytes -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op FP8_DTYPE = current_platform.fp8_dtype() logger = init_logger(__name__) diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py index 6308f63cc4e7..552367b16d5c 100644 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -5,7 +5,7 @@ import torch from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer +from vllm.utils.torch import direct_register_custom_op, is_torch_equal_or_newer def get_aiter_mla_metadata( diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 7f4c9e6af13b..a6e4fcc52e71 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -25,7 +25,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import resolve_obj_by_qualname -from vllm.utils.torch_utils import is_torch_equal_or_newer +from vllm.utils.torch import is_torch_equal_or_newer from .caching import VllmSerializableFunction from .compiler_interface import ( diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 
fb1f78f21a05..7d21f0354630 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -18,7 +18,7 @@ ) from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op from .inductor_pass import enable_fake_mode from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 0a3f0769db94..383dcaa3af21 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -16,7 +16,7 @@ import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.config import VllmConfig -from vllm.utils.torch_utils import is_torch_equal_or_newer +from vllm.utils.torch import is_torch_equal_or_newer class CompilerInterface: diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index a2e0abfebc2c..0e0fea450a31 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -17,7 +17,7 @@ from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils.torch_utils import weak_ref_tensors +from vllm.utils.torch import weak_ref_tensors logger = init_logger(__name__) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index abe61cce0dd8..506ba1767f92 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -22,7 +22,7 @@ from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from vllm.utils import resolve_obj_by_qualname -from vllm.utils.torch_utils import supports_dynamo +from vllm.utils.torch import supports_dynamo from .monitor import start_monitoring_torch_compile diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 9af635a929b4..23d8171375b5 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -14,7 +14,7 @@ from torch import fx from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily -from vllm.utils.torch_utils import is_torch_equal_or_newer +from vllm.utils.torch import is_torch_equal_or_newer if is_torch_equal_or_newer("2.6"): from torch._inductor.custom_graph_pass import CustomGraphPass diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 620a521a9a7e..b8b4c8bc5f28 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -17,7 +17,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import resolve_obj_by_qualname -from vllm.utils.torch_utils import is_torch_equal_or_newer +from vllm.utils.torch import is_torch_equal_or_newer if TYPE_CHECKING: from vllm.config import VllmConfig diff --git a/vllm/config/model.py b/vllm/config/model.py index 68070d204b3c..cf106ea37e28 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -42,7 +42,7 @@ from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri from vllm.transformers_utils.utils import maybe_model_redirect from vllm.utils import LayerBlockType, LazyLoader -from vllm.utils.torch_utils import common_broadcastable_dtype +from vllm.utils.torch import common_broadcastable_dtype if TYPE_CHECKING: from transformers import PretrainedConfig diff --git a/vllm/config/parallel.py 
b/vllm/config/parallel.py index aa1ac4ab8f0b..8a5e4b427b8e 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -19,7 +19,7 @@ ) from vllm.platforms import current_platform from vllm.utils import get_open_ports_list -from vllm.utils.torch_utils import cuda_device_count_stateless +from vllm.utils.torch import cuda_device_count_stateless if TYPE_CHECKING: from ray.runtime_env import RuntimeEnv diff --git a/vllm/distributed/device_communicators/all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py index 09c89ea31d05..cc98b94ace10 100644 --- a/vllm/distributed/device_communicators/all_reduce_utils.py +++ b/vllm/distributed/device_communicators/all_reduce_utils.py @@ -23,7 +23,7 @@ vllm_kernel_override_batch_invariant, ) from vllm.utils import update_environment_variables -from vllm.utils.torch_utils import cuda_device_count_stateless +from vllm.utils.torch import cuda_device_count_stateless logger = init_logger(__name__) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 4b82f3b5d396..dfb9c6f5afc3 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -17,7 +17,7 @@ from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils.torch_utils import cuda_device_count_stateless +from vllm.utils.torch import cuda_device_count_stateless try: ops.meta_size() diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index ad3c8676fafd..baccaecdacd5 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -19,7 +19,7 @@ ) from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger -from vllm.utils.torch_utils import current_stream +from vllm.utils.torch import current_stream logger = init_logger(__name__) @@ -30,7 +30,7 @@ def register_nccl_symmetric_ops(pynccl_comm): from vllm.distributed.device_communicators.pynccl_allocator import ( nccl_symm_mem_context, ) - from vllm.utils.torch_utils import direct_register_custom_op + from vllm.utils.torch import direct_register_custom_op global _NCCL_SYMM_OPS_REGISTERED if _NCCL_SYMM_OPS_REGISTERED: diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index 9c7765883cfd..86f28bee6408 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -13,7 +13,7 @@ from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils.torch_utils import cuda_device_count_stateless +from vllm.utils.torch import cuda_device_count_stateless logger = init_logger(__name__) diff --git a/vllm/distributed/device_communicators/ray_communicator.py b/vllm/distributed/device_communicators/ray_communicator.py index 3b02b885e786..4158b06264af 100644 --- a/vllm/distributed/device_communicators/ray_communicator.py +++ b/vllm/distributed/device_communicators/ray_communicator.py @@ -14,7 +14,7 @@ ) from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger -from vllm.utils.torch_utils import current_stream +from vllm.utils.torch import current_stream logger = 
init_logger(__name__) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py index 5b32a9756637..1c59f9d8cce0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -26,7 +26,7 @@ TensorMemoryPool, ) from vllm.utils import get_ip -from vllm.utils.torch_utils import current_stream +from vllm.utils.torch import current_stream logger = logging.getLogger(__name__) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 10e2ba1d5925..35ef01a1e6a6 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -53,7 +53,7 @@ get_distributed_init_method, resolve_obj_by_qualname, ) -from vllm.utils.torch_utils import ( +from vllm.utils.torch import ( direct_register_custom_op, supports_custom_op, ) diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index a5df81e55e36..ab6d50a6c0e3 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -30,7 +30,7 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.utils import get_tcp_uri -from vllm.utils.torch_utils import is_torch_equal_or_newer +from vllm.utils.torch import is_torch_equal_or_newer logger = init_logger(__name__) diff --git a/vllm/env_override.py b/vllm/env_override.py index 30071f8ea46c..9c1789ef7fa2 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -5,7 +5,7 @@ import torch from vllm.logger import init_logger -from vllm.utils.torch_utils import is_torch_equal +from vllm.utils.torch import is_torch_equal logger = init_logger(__name__) diff --git a/vllm/envs.py b/vllm/envs.py index a7d294ab8298..4c9927abb9bf 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -246,7 +246,7 @@ def maybe_convert_bool(value: str | None) -> bool | None: def use_aot_compile() -> bool: - from vllm.utils.torch_utils import is_torch_equal_or_newer + from vllm.utils.torch import is_torch_equal_or_newer default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0" return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1" diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py index fd4c1364de7e..bee0ea2edf56 100644 --- a/vllm/lora/ops/triton_ops/lora_expand_op.py +++ b/vllm/lora/ops/triton_ops/lora_expand_op.py @@ -12,7 +12,7 @@ from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs from vllm.triton_utils import tl, triton -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op @triton.jit diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py index 8c58915e3f79..71ab13b4f36b 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink_op.py +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -12,7 +12,7 @@ from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs from vllm.triton_utils import tl, triton -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op @triton.jit diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index f21fe16c5108..1430a5fe99af 100644 --- 
a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, ) -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op def flashinfer_fused_moe_blockscale_fp8( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index c88982c1522b..24a039ea0db6 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -52,7 +52,7 @@ from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used -from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer +from vllm.utils.torch import direct_register_custom_op, is_torch_equal_or_newer from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 6edbb17c0a8e..97f2c2957f8e 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -11,7 +11,7 @@ FusedMoEQuantConfig, ) from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op class QuantMethod(IntEnum): diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 0627ea50d821..0cbafe18f275 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -25,7 +25,7 @@ from vllm.triton_utils import tl, triton from vllm.utils import cdiv from vllm.utils.flashinfer import flashinfer_fp4_quantize -from vllm.utils.torch_utils import is_torch_equal_or_newer +from vllm.utils.torch import is_torch_equal_or_newer @triton.jit diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index fe0d3a9e319c..acefc11255a9 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -13,7 +13,7 @@ vllm_kernel_override_batch_invariant, ) from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op def is_rocm_aiter_rmsnorm_enabled() -> bool: diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py index a8f7f652452f..d77d1f7898c5 100644 --- a/vllm/model_executor/layers/mamba/linear_attn.py +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -35,7 +35,7 @@ MambaStateShapeCalculator, ) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata if TYPE_CHECKING: diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index a9a0c216474b..8913755dd1fa 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -37,7 +37,7 @@ selective_state_update, ) from vllm.model_executor.utils import 
set_weight_attrs -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index fb45afa33dad..c9c4f8070664 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -46,7 +46,7 @@ sharded_weight_loader, ) from vllm.model_executor.utils import set_weight_attrs -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata # Added by the IBM Team, 2024 diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 91a45623582d..f36070300f3d 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -6,7 +6,7 @@ from vllm.config.cache import MambaDType from vllm.config.model import ModelDType from vllm.distributed import divide -from vllm.utils.torch_utils import ( +from vllm.utils.torch import ( STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype, ) diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py index 04efa8a8b373..0530e987c343 100644 --- a/vllm/model_executor/layers/mamba/short_conv.py +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -27,7 +27,7 @@ causal_conv1d_fn, causal_conv1d_update, ) -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index ccd9b311cc93..a353c132fa2a 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -23,7 +23,7 @@ QuantizationMethods, ) from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op class BitsAndBytesConfig(QuantizationConfig): diff --git a/vllm/model_executor/layers/quantization/fp_quant.py b/vllm/model_executor/layers/quantization/fp_quant.py index 15a253cef0b7..585088a73511 100644 --- a/vllm/model_executor/layers/quantization/fp_quant.py +++ b/vllm/model_executor/layers/quantization/fp_quant.py @@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op class FPQuantConfig(QuantizationConfig): diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 8a914c57a9f7..4bee58afdf03 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -28,7 +28,7 @@ ) from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.utils import set_weight_attrs -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op logger = 
init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index a19396a162bc..eb3d31c7162d 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -7,7 +7,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op from .cutlass import CutlassScaledMMLinearKernel from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py index c25c522dea55..50cbaad78f17 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py @@ -45,7 +45,7 @@ def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool: from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 from aiter.ops.triton.quant import dynamic_mxfp4_quant - from vllm.utils.torch_utils import direct_register_custom_op + from vllm.utils.torch import direct_register_custom_op if is_rocm_aiter_fp4_asm_gemm_enabled(): from aiter import gemm_a4w4, per_1x32_f4_quant_hip diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index cb2d075c1a9d..b91d145da7a9 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -34,7 +34,7 @@ is_deep_gemm_supported, should_use_deepgemm_for_fp8_linear, ) -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 5e87cadfb107..3fea2d0b6b8c 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -7,7 +7,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer +from vllm.utils.torch import direct_register_custom_op, is_torch_equal_or_newer logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py index 2b5659e30097..e33b8c6b096d 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py @@ -3,7 +3,7 @@ import torch from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op def _quant_dequant_mxfp6( diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 0d036ffdd286..20dbb6bac13c 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import 
GroupShape from vllm.platforms import current_platform from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index 9e6ec9fdd523..bcdb3176289c 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -10,7 +10,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op if current_platform.is_cuda(): from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb diff --git a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py index a01d14f7b3a1..e12235371545 100644 --- a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py +++ b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py @@ -5,7 +5,7 @@ import vllm.envs as envs from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op def is_rocm_triton_rotary_embedding_enabled() -> bool: diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index c1a48fa200ca..e699c3ac7153 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -9,7 +9,7 @@ from vllm import _custom_ops as ops from vllm import envs from vllm.platforms import CpuArchEnum, current_platform -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op def shuffle_weight(w: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index 94dfa478245d..537a69894333 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -12,7 +12,7 @@ initialize_model, process_weights_after_loading, ) -from vllm.utils.torch_utils import set_default_torch_dtype +from vllm.utils.torch import set_default_torch_dtype logger = init_logger(__name__) diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index 97c7a20bc4d5..8e88963c69e0 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -48,7 +48,7 @@ set_weight_attrs, ) from vllm.platforms import current_platform -from vllm.utils.torch_utils import set_default_torch_dtype +from vllm.utils.torch import set_default_torch_dtype logger = init_logger(__name__) diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 7db1fc167c4f..437d3337cf19 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -21,7 +21,7 @@ get_gguf_weight_type_map, gguf_quant_weights_iterator, ) -from vllm.utils.torch_utils import set_default_torch_dtype +from vllm.utils.torch import set_default_torch_dtype class 
GGUFModelLoader(BaseModelLoader): diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index 2b3704cfebba..1fbaa95fed58 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -23,7 +23,7 @@ get_model_architecture, initialize_model, ) -from vllm.utils.torch_utils import set_default_torch_dtype +from vllm.utils.torch import set_default_torch_dtype logger = init_logger(__name__) diff --git a/vllm/model_executor/model_loader/tpu.py b/vllm/model_executor/model_loader/tpu.py index fc142f1f07fa..6b86e1def6fa 100644 --- a/vllm/model_executor/model_loader/tpu.py +++ b/vllm/model_executor/model_loader/tpu.py @@ -15,7 +15,7 @@ initialize_model, process_weights_after_loading, ) -from vllm.utils.torch_utils import set_default_torch_dtype +from vllm.utils.torch import set_default_torch_dtype logger = init_logger(__name__) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index da5d80f9828e..e557f38817ec 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -7,7 +7,7 @@ from vllm.logger import init_logger from vllm.model_executor.models import ModelRegistry from vllm.utils import cdiv, round_up -from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec if TYPE_CHECKING: diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index d2bc037b97dd..28a7025bedf4 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -80,7 +80,7 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerBackend, DeepseekV32IndexerMetadata, diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index a4c9360e50e0..65cd6b085534 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -50,7 +50,7 @@ from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.utils.collections import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape -from vllm.utils.torch_utils import set_default_torch_dtype +from vllm.utils.torch import set_default_torch_dtype from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import ( diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index e2d2647f0177..ac2325dda4a6 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -52,7 +52,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils.tensor_schema import TensorSchema, TensorShape -from vllm.utils.torch_utils import set_default_torch_num_threads +from vllm.utils.torch import set_default_torch_num_threads from .interfaces import ( MultiModalEmbeddings, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index b4a558ad6970..2c5337387229 100644 --- 
a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -87,7 +87,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils.collections import flatten_2d_lists from vllm.utils.tensor_schema import TensorSchema, TensorShape -from vllm.utils.torch_utils import set_default_torch_dtype +from vllm.utils.torch import set_default_torch_dtype from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import ( diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 09293f63f70e..e886fac410a9 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -64,7 +64,7 @@ ) from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 06e94734376c..e097a471359f 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -71,7 +71,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Qwen3NextConfig from vllm.triton_utils import tl, triton -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from .interfaces import ( diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers_moe.py index 43ea9a4869ed..8ca79bbe42fc 100644 --- a/vllm/model_executor/models/transformers_moe.py +++ b/vllm/model_executor/models/transformers_moe.py @@ -28,7 +28,7 @@ from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch import direct_register_custom_op from .interfaces import MixtureOfExperts, SupportsMultiModal from .transformers import ( diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index c6fa50ccbc66..93346c142f3d 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -28,7 +28,7 @@ is_pin_memory_available, is_uva_available, ) -from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor +from vllm.utils.torch import get_cuda_view_from_cpu_tensor logger = init_logger(__name__) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index ccfe1871ef07..62048a81972f 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -52,7 +52,7 @@ from vllm.transformers_utils.processor import cached_get_processor from vllm.utils.jsontree import json_map_leaves from vllm.utils.tensor_schema import TensorSchema, TensorShape -from vllm.utils.torch_utils import set_default_torch_dtype +from vllm.utils.torch import set_default_torch_dtype from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription from .utils import ( diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 99651a408b31..13a5152af48b 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -8,7 +8,7 @@ from vllm import envs from vllm.plugins import 
PLATFORM_PLUGINS_GROUP, load_plugins_by_group from vllm.utils import resolve_obj_by_qualname -from vllm.utils.torch_utils import supports_xccl +from vllm.utils.torch import supports_xccl from .interface import CpuArchEnum, Platform, PlatformEnum diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index c736e084a38d..943cca678a22 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -17,7 +17,7 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.utils import import_pynvml -from vllm.utils.torch_utils import cuda_device_count_stateless +from vllm.utils.torch import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 68e6c06c8814..ed27df55ee01 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -9,7 +9,7 @@ import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils.torch_utils import cuda_device_count_stateless +from vllm.utils.torch import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 4211535131a4..2c7abd2d7ecc 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -22,7 +22,7 @@ from vllm.connections import global_http_connection from vllm.logger import init_logger from vllm.utils import cuda_get_device_properties -from vllm.utils.torch_utils import cuda_device_count_stateless +from vllm.utils.torch import cuda_device_count_stateless from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch.py similarity index 100% rename from vllm/utils/torch_utils.py rename to vllm/utils/torch.py diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index d14f949b6579..a6ee8936fca5 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -29,7 +29,7 @@ vllm_kernel_override_batch_invariant, ) from vllm.utils import cdiv -from vllm.utils.torch_utils import is_torch_equal_or_newer +from vllm.utils.torch import is_torch_equal_or_newer from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index f7a4114a0a70..f9c1b69d491e 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -29,7 +29,7 @@ import aiter from vllm.triton_utils import tl, triton - from vllm.utils.torch_utils import direct_register_custom_op + from vllm.utils.torch import direct_register_custom_op @triton.jit def _vllm_layout_trans_kernel( diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 392519f8fa9a..ab795493eea4 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -11,7 +11,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import cdiv -from vllm.utils.torch_utils import get_dtype_size +from vllm.utils.torch import get_dtype_size logger = init_logger(__name__) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7a73177ba7d8..1836f58d0b6e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -82,7 +82,7 @@ round_up, ) from vllm.utils.jsontree import json_map_leaves -from 
vllm.utils.torch_utils import ( +from vllm.utils.torch import ( STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, supports_dynamo, diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 1a758386d1c9..ba839373356b 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -27,7 +27,7 @@ from vllm.platforms.tpu import USE_TPU_INFERENCE from vllm.tasks import SupportedTask from vllm.utils import cdiv -from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 6edcb7848638..33ce20943d61 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -7,7 +7,7 @@ from vllm import forward_context from vllm.forward_context import ForwardContext -from vllm.utils.torch_utils import current_stream +from vllm.utils.torch import current_stream _THREAD_ID_TO_CONTEXT: dict = {} _CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [None, None] From 32923fd947bb87e7f22d2ddbaed3da7a8e08b45b Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Thu, 16 Oct 2025 22:58:23 +0800 Subject: [PATCH 10/15] fix import Signed-off-by: Isotr0py --- vllm/model_executor/layers/fused_moe/layer.py | 3 ++- vllm/model_executor/layers/quantization/mxfp4.py | 2 +- vllm/model_executor/models/utils.py | 3 +-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index de4ed58e0cf4..a34a93c4f99c 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -52,8 +52,9 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from vllm.utils import cdiv, direct_register_custom_op, has_deep_ep, has_pplx, round_up +from vllm.utils import cdiv, has_deep_ep, has_pplx, round_up from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +from vllm.utils.torch import direct_register_custom_op from vllm.v1.worker.ubatching import dbo_current_ubatch_id if current_platform.is_cuda_alike(): diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index a7f9fdcb5513..f7a7741e8d46 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -48,11 +48,11 @@ from vllm.scalar_type import scalar_types from vllm.utils import ( has_triton_kernels, - is_torch_equal_or_newer, next_power_of_2, round_up, ) from vllm.utils.flashinfer import has_flashinfer +from vllm.utils.torch import is_torch_equal_or_newer logger = init_logger(__name__) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 93346c142f3d..9a95d614a9f0 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -24,11 +24,10 @@ from vllm.sequence import IntermediateTensors from vllm.utils import ( cdiv, - direct_register_custom_op, is_pin_memory_available, is_uva_available, ) -from vllm.utils.torch import get_cuda_view_from_cpu_tensor +from vllm.utils.torch import direct_register_custom_op, get_cuda_view_from_cpu_tensor logger = init_logger(__name__) From 9a787667fc7cb4152a45a06d490a89484dde4798 Mon 
Sep 17 00:00:00 2001 From: isotr0py <2037008807@qq.com> Date: Thu, 16 Oct 2025 23:06:58 +0800 Subject: [PATCH 11/15] rename back to torch_utils to avoid conflicts Signed-off-by: isotr0py <2037008807@qq.com> --- benchmarks/kernels/bench_per_token_quant_fp8.py | 2 +- benchmarks/kernels/benchmark_activation.py | 2 +- benchmarks/kernels/benchmark_layernorm.py | 2 +- benchmarks/kernels/benchmark_paged_attention.py | 2 +- benchmarks/kernels/benchmark_quant.py | 2 +- benchmarks/kernels/benchmark_reshape_and_cache.py | 2 +- benchmarks/kernels/benchmark_reshape_and_cache_flash.py | 2 +- tests/compile/piecewise/test_full_cudagraph.py | 2 +- tests/compile/piecewise/test_multiple_graphs.py | 2 +- tests/compile/piecewise/test_simple.py | 2 +- tests/compile/piecewise/test_toy_llama.py | 2 +- tests/compile/silly_attention.py | 2 +- tests/compile/test_aot_compile.py | 2 +- tests/compile/test_basic_correctness.py | 2 +- tests/compile/test_config.py | 2 +- tests/compile/test_decorator.py | 2 +- tests/compile/test_full_graph.py | 2 +- tests/compile/test_fusion_attn.py | 2 +- tests/conftest.py | 2 +- tests/distributed/test_sequence_parallel.py | 2 +- tests/distributed/test_utils.py | 2 +- tests/kernels/attention/conftest.py | 2 +- tests/kernels/attention/test_prefix_prefill.py | 2 +- tests/kernels/core/test_uva.py | 2 +- tests/kernels/moe/test_modular_kernel_combinations.py | 2 +- tests/models/multimodal/pooling/test_intern_vit.py | 2 +- tests/models/multimodal/pooling/test_radio.py | 2 +- tests/models/multimodal/processing/test_tensor_schema.py | 2 +- tests/utils.py | 2 +- tests/utils_/test_utils.py | 2 +- tests/v1/attention/test_attention_backends.py | 2 +- tests/v1/attention/test_mla_backends.py | 2 +- tests/v1/engine/test_async_llm.py | 2 +- tests/v1/engine/test_engine_core.py | 2 +- tests/v1/engine/test_engine_core_client.py | 2 +- tests/v1/shutdown/test_delete.py | 2 +- tests/v1/shutdown/test_forward_error.py | 2 +- tests/v1/shutdown/test_startup_error.py | 2 +- vllm/attention/layer.py | 2 +- vllm/attention/ops/rocm_aiter_mla.py | 2 +- vllm/compilation/backends.py | 2 +- vllm/compilation/collective_fusion.py | 2 +- vllm/compilation/compiler_interface.py | 2 +- vllm/compilation/cuda_graph.py | 2 +- vllm/compilation/decorators.py | 2 +- vllm/compilation/inductor_pass.py | 2 +- vllm/config/compilation.py | 2 +- vllm/config/model.py | 2 +- vllm/config/parallel.py | 2 +- vllm/distributed/device_communicators/all_reduce_utils.py | 2 +- vllm/distributed/device_communicators/custom_all_reduce.py | 2 +- vllm/distributed/device_communicators/pynccl.py | 4 ++-- vllm/distributed/device_communicators/quick_all_reduce.py | 2 +- vllm/distributed/device_communicators/ray_communicator.py | 2 +- .../kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py | 2 +- vllm/distributed/parallel_state.py | 2 +- vllm/distributed/utils.py | 2 +- vllm/env_override.py | 2 +- vllm/envs.py | 2 +- vllm/lora/ops/triton_ops/lora_expand_op.py | 2 +- vllm/lora/ops/triton_ops/lora_shrink_op.py | 2 +- .../model_executor/layers/fused_moe/flashinfer_trtllm_moe.py | 2 +- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 2 +- vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py | 2 +- vllm/model_executor/layers/fused_moe/utils.py | 2 +- vllm/model_executor/layers/layernorm.py | 2 +- vllm/model_executor/layers/mamba/linear_attn.py | 2 +- vllm/model_executor/layers/mamba/mamba_mixer.py | 2 +- vllm/model_executor/layers/mamba/mamba_mixer2.py | 2 +- 
vllm/model_executor/layers/mamba/mamba_utils.py | 2 +- vllm/model_executor/layers/mamba/short_conv.py | 2 +- vllm/model_executor/layers/quantization/bitsandbytes.py | 2 +- vllm/model_executor/layers/quantization/fp_quant.py | 2 +- vllm/model_executor/layers/quantization/gguf.py | 2 +- .../layers/quantization/kernels/scaled_mm/aiter.py | 2 +- vllm/model_executor/layers/quantization/mxfp4.py | 2 +- .../layers/quantization/quark/schemes/quark_ocp_mx.py | 2 +- vllm/model_executor/layers/quantization/utils/fp8_utils.py | 2 +- vllm/model_executor/layers/quantization/utils/mxfp4_utils.py | 2 +- vllm/model_executor/layers/quantization/utils/mxfp6_utils.py | 2 +- vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 2 +- vllm/model_executor/layers/rotary_embedding/common.py | 2 +- .../layers/rotary_embedding/rocm_aiter_rope_ops.py | 2 +- vllm/model_executor/layers/utils.py | 2 +- vllm/model_executor/model_loader/base_loader.py | 2 +- vllm/model_executor/model_loader/bitsandbytes_loader.py | 2 +- vllm/model_executor/model_loader/gguf_loader.py | 2 +- vllm/model_executor/model_loader/tensorizer_loader.py | 2 +- vllm/model_executor/model_loader/tpu.py | 2 +- vllm/model_executor/models/config.py | 2 +- vllm/model_executor/models/deepseek_v2.py | 2 +- vllm/model_executor/models/deepseek_vl2.py | 2 +- vllm/model_executor/models/internvl.py | 2 +- vllm/model_executor/models/minicpmv.py | 2 +- vllm/model_executor/models/plamo2.py | 2 +- vllm/model_executor/models/qwen3_next.py | 2 +- vllm/model_executor/models/transformers_moe.py | 2 +- vllm/model_executor/models/utils.py | 5 ++++- vllm/model_executor/models/whisper.py | 2 +- vllm/platforms/__init__.py | 2 +- vllm/platforms/cuda.py | 2 +- vllm/platforms/rocm.py | 2 +- vllm/usage/usage_lib.py | 2 +- vllm/utils/{torch.py => torch_utils.py} | 0 vllm/v1/attention/backends/flex_attention.py | 2 +- vllm/v1/attention/backends/rocm_aiter_fa.py | 2 +- vllm/v1/kv_cache_interface.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 2 +- vllm/v1/worker/tpu_worker.py | 2 +- vllm/v1/worker/ubatching.py | 2 +- 111 files changed, 114 insertions(+), 111 deletions(-) rename vllm/utils/{torch.py => torch_utils.py} (100%) diff --git a/benchmarks/kernels/bench_per_token_quant_fp8.py b/benchmarks/kernels/bench_per_token_quant_fp8.py index 59ed54659872..d33b84fc3601 100644 --- a/benchmarks/kernels/bench_per_token_quant_fp8.py +++ b/benchmarks/kernels/bench_per_token_quant_fp8.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.triton_utils import triton from vllm.utils import FlexibleArgumentParser -from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE def with_triton_mode(fn): diff --git a/benchmarks/kernels/benchmark_activation.py b/benchmarks/kernels/benchmark_activation.py index 51a69974ac71..7662655b5efa 100644 --- a/benchmarks/kernels/benchmark_activation.py +++ b/benchmarks/kernels/benchmark_activation.py @@ -11,7 +11,7 @@ from vllm.platforms import current_platform from vllm.triton_utils import triton from vllm.utils import FlexibleArgumentParser -from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE batch_size_range = [1, 16, 32, 64, 128] seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py index 53d4d3174f80..bcfa64c3f425 100644 --- a/benchmarks/kernels/benchmark_layernorm.py +++ 
b/benchmarks/kernels/benchmark_layernorm.py @@ -8,7 +8,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser -from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE @torch.inference_mode() diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 82b1efdd20f6..1b1e71adeec4 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -10,7 +10,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser -from vllm.utils.torch import ( +from vllm.utils.torch_utils import ( STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random, ) diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py index 88af6f25e358..61427a77b4e3 100644 --- a/benchmarks/kernels/benchmark_quant.py +++ b/benchmarks/kernels/benchmark_quant.py @@ -8,7 +8,7 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser -from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE @torch.inference_mode() diff --git a/benchmarks/kernels/benchmark_reshape_and_cache.py b/benchmarks/kernels/benchmark_reshape_and_cache.py index 16b42e701824..e0ff09d4b397 100644 --- a/benchmarks/kernels/benchmark_reshape_and_cache.py +++ b/benchmarks/kernels/benchmark_reshape_and_cache.py @@ -10,7 +10,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser -from vllm.utils.torch import ( +from vllm.utils.torch_utils import ( STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random, ) diff --git a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py index b360256ec876..29f1b2ccdcf6 100644 --- a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py +++ b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py @@ -13,7 +13,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser -from vllm.utils.torch import ( +from vllm.utils.torch_utils import ( STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random_flash, ) diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index a6edbb410ef6..c6d4b5272dbc 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -11,7 +11,7 @@ from vllm import LLM, SamplingParams from vllm.config import CompilationConfig from vllm.platforms import current_platform -from vllm.utils.torch import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer @contextlib.contextmanager diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index eaf625c934f5..700f57ffb068 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -20,7 +20,7 @@ set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils.torch import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer # This import automatically registers 
`torch.ops.silly.attention` from .. import silly_attention # noqa: F401 diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 841392463094..9d4e3f0f300f 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -19,7 +19,7 @@ set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils.torch import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from ..silly_attention import get_global_counter, reset_global_counter diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index da18f7dca26f..175ca4a23043 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -27,7 +27,7 @@ set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils.torch import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from .. import silly_attention # noqa: F401 diff --git a/tests/compile/silly_attention.py b/tests/compile/silly_attention.py index 1e055db6f68e..29c02f6e6a1d 100644 --- a/tests/compile/silly_attention.py +++ b/tests/compile/silly_attention.py @@ -8,7 +8,7 @@ import torch from torch.library import Library -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op # Shared library for all compilation test operations # Using "silly" namespace to match existing test expectations diff --git a/tests/compile/test_aot_compile.py b/tests/compile/test_aot_compile.py index 9fb4fae2b74a..b2734af575a1 100644 --- a/tests/compile/test_aot_compile.py +++ b/tests/compile/test_aot_compile.py @@ -15,7 +15,7 @@ set_current_vllm_config, ) from vllm.forward_context import set_forward_context -from vllm.utils.torch import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer def reference_fn(x: torch.Tensor): diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 0d8a811953c8..132a838b8d44 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -5,7 +5,7 @@ import pytest from vllm.config import CompilationMode -from vllm.utils.torch import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from ..utils import compare_all_settings diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index e666aea3af54..20e3c5039cc0 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -5,7 +5,7 @@ from vllm.compilation.counter import compilation_counter from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.config.compilation import CompilationMode -from vllm.utils.torch import _is_torch_equal_or_newer, is_torch_equal_or_newer +from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer def test_version(): diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py index 7b5149e89e52..c9d01f2317d2 100644 --- a/tests/compile/test_decorator.py +++ b/tests/compile/test_decorator.py @@ -15,7 +15,7 @@ set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils.torch import 
is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from . import silly_attention # noqa: F401 diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index af8849a62e84..248a9f3c7730 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -14,7 +14,7 @@ from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.platforms import current_platform -from vllm.utils.torch import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer from ..utils import create_new_process_for_each_test diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index b47eecf407d0..a35fb9c8c31f 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -33,7 +33,7 @@ ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform -from vllm.utils.torch import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.v1.kv_cache_interface import AttentionSpec FP8_DTYPE = current_platform.fp8_dtype() diff --git a/tests/conftest.py b/tests/conftest.py index 722100c4d99e..45ccc38a4538 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,7 +58,7 @@ from vllm.sampling_params import BeamSearchParams from vllm.transformers_utils.utils import maybe_model_redirect from vllm.utils.collections import is_list_of -from vllm.utils.torch import set_default_torch_num_threads +from vllm.utils.torch_utils import set_default_torch_num_threads logger = init_logger(__name__) diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index 5974865d2387..c35f6a3c2507 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -18,7 +18,7 @@ from vllm.config.compilation import CompilationMode from vllm.config.model import RunnerOption from vllm.logger import init_logger -from vllm.utils.torch import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer from ..models.registry import HF_EXAMPLE_MODELS from ..utils import compare_two_settings, create_new_process_for_each_test diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py index 3b8fa43d0b7a..c10c2565811b 100644 --- a/tests/distributed/test_utils.py +++ b/tests/distributed/test_utils.py @@ -14,7 +14,7 @@ get_open_port, update_environment_variables, ) -from vllm.utils.torch import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from ..utils import multi_gpu_test diff --git a/tests/kernels/attention/conftest.py b/tests/kernels/attention/conftest.py index 6704fa99359f..e520267320c0 100644 --- a/tests/kernels/attention/conftest.py +++ b/tests/kernels/attention/conftest.py @@ -3,7 +3,7 @@ import pytest -from vllm.utils.torch import ( +from vllm.utils.torch_utils import ( create_kv_caches_with_random, create_kv_caches_with_random_flash, ) diff --git a/tests/kernels/attention/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py index e75cbb350d73..65972d02f2f6 100644 --- a/tests/kernels/attention/test_prefix_prefill.py +++ b/tests/kernels/attention/test_prefix_prefill.py @@ -15,7 +15,7 @@ from 
vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.platforms import current_platform -from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE NUM_HEADS = [64] NUM_QUERIES_PER_KV = [1, 64] diff --git a/tests/kernels/core/test_uva.py b/tests/kernels/core/test_uva.py index aaa4ec311afc..dee92976eb6f 100644 --- a/tests/kernels/core/test_uva.py +++ b/tests/kernels/core/test_uva.py @@ -4,7 +4,7 @@ import torch from vllm.utils import is_uva_available -from vllm.utils.torch import get_cuda_view_from_cpu_tensor +from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index 4403a69f79f1..a7beb313011a 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -15,7 +15,7 @@ from vllm.platforms import current_platform from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe -from vllm.utils.torch import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from .modular_kernel_tools.common import ( Config, diff --git a/tests/models/multimodal/pooling/test_intern_vit.py b/tests/models/multimodal/pooling/test_intern_vit.py index 155cabfe3a88..5a97848216b8 100644 --- a/tests/models/multimodal/pooling/test_intern_vit.py +++ b/tests/models/multimodal/pooling/test_intern_vit.py @@ -7,7 +7,7 @@ from transformers import AutoConfig, AutoModel, CLIPImageProcessor from vllm.distributed import cleanup_dist_env_and_memory -from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from ....conftest import ImageTestAssets diff --git a/tests/models/multimodal/pooling/test_radio.py b/tests/models/multimodal/pooling/test_radio.py index 47b359b45de9..8929563d8b05 100644 --- a/tests/models/multimodal/pooling/test_radio.py +++ b/tests/models/multimodal/pooling/test_radio.py @@ -9,7 +9,7 @@ from vllm.distributed import cleanup_dist_env_and_memory from vllm.model_executor.models.radio import RadioModel from vllm.transformers_utils.configs.radio import RadioConfig -from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from ....conftest import ImageTestAssets diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index bbfdee780f9e..8de8ecfe9d83 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -35,7 +35,7 @@ from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.utils.collections import is_list_of -from vllm.utils.torch import set_default_torch_dtype +from vllm.utils.torch_utils import set_default_torch_dtype from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS from ...utils import dummy_hf_overrides diff --git a/tests/utils.py b/tests/utils.py index dd7b7c14911c..ad5139c7f8ba 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -48,7 +48,7 @@ GB_bytes, get_open_port, ) -from 
vllm.utils.torch import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless if current_platform.is_rocm(): from amdsmi import ( diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index 0c8750bd150b..5ce9a9604b08 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -37,7 +37,7 @@ split_zmq_path, unique_filepath, ) -from vllm.utils.torch import ( +from vllm.utils.torch_utils import ( common_broadcastable_dtype, current_stream, is_lossless_cast, diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index e8cc45e74f5f..7a5216d1738d 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -19,7 +19,7 @@ from vllm.config import ModelConfig from vllm.platforms import current_platform from vllm.utils import cdiv -from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE, is_torch_equal_or_newer +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, set_kv_cache_layout, diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 146c91ea763d..81fd6433b0c8 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -23,7 +23,7 @@ from vllm.attention.ops.flashmla import is_flashmla_dense_supported from vllm.config.vllm import set_current_vllm_config from vllm.utils import cdiv -from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import FullAttentionSpec diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index 23752b524899..c9605ea1b07c 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -15,7 +15,7 @@ from vllm.outputs import RequestOutput from vllm.platforms import current_platform from vllm.sampling_params import RequestOutputKind -from vllm.utils.torch import set_default_torch_num_threads +from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.metrics.loggers import ( AggregatedLoggingStatLogger, diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 4c2260558dfe..7e39cd781bae 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -12,7 +12,7 @@ from vllm import SamplingParams from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform -from vllm.utils.torch import set_default_torch_num_threads +from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core import EngineCore from vllm.v1.executor.abstract import Executor, UniProcExecutor diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 6d989c872962..770560a5e549 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -21,7 +21,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.usage.usage_lib import UsageContext -from vllm.utils.torch import set_default_torch_num_threads +from vllm.utils.torch_utils import set_default_torch_num_threads from 
vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core import EngineCore from vllm.v1.engine.core_client import AsyncMPClient, EngineCoreClient, SyncMPClient diff --git a/tests/v1/shutdown/test_delete.py b/tests/v1/shutdown/test_delete.py index c3610c02b8c3..255515948433 100644 --- a/tests/v1/shutdown/test_delete.py +++ b/tests/v1/shutdown/test_delete.py @@ -12,7 +12,7 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.sampling_params import RequestOutputKind -from vllm.utils.torch import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.v1.engine.async_llm import AsyncLLM MODELS = ["meta-llama/Llama-3.2-1B"] diff --git a/tests/v1/shutdown/test_forward_error.py b/tests/v1/shutdown/test_forward_error.py index c7587d2dd55b..e65d46dfa43a 100644 --- a/tests/v1/shutdown/test_forward_error.py +++ b/tests/v1/shutdown/test_forward_error.py @@ -14,7 +14,7 @@ from vllm import LLM, AsyncEngineArgs, SamplingParams from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.utils.torch import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.exceptions import EngineDeadError diff --git a/tests/v1/shutdown/test_startup_error.py b/tests/v1/shutdown/test_startup_error.py index 499b9d123d92..3877fceae00c 100644 --- a/tests/v1/shutdown/test_startup_error.py +++ b/tests/v1/shutdown/test_startup_error.py @@ -13,7 +13,7 @@ from vllm.distributed import get_tensor_model_parallel_rank from vllm.engine.arg_utils import AsyncEngineArgs from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.utils.torch import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.v1.engine.async_llm import AsyncLLM MODELS = ["meta-llama/Llama-3.2-1B"] diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index e5754ec09c53..9b6ea3f0b8f5 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -35,7 +35,7 @@ from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.platforms import current_platform from vllm.utils import GiB_bytes -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op FP8_DTYPE = current_platform.fp8_dtype() logger = init_logger(__name__) diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py index 552367b16d5c..6308f63cc4e7 100644 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -5,7 +5,7 @@ import torch from vllm.platforms import current_platform -from vllm.utils.torch import direct_register_custom_op, is_torch_equal_or_newer +from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer def get_aiter_mla_metadata( diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index a6e4fcc52e71..7f4c9e6af13b 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -25,7 +25,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import resolve_obj_by_qualname -from vllm.utils.torch import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer from .caching import VllmSerializableFunction from .compiler_interface import ( diff --git 
a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 7d21f0354630..fb1f78f21a05 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -18,7 +18,7 @@ ) from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from .inductor_pass import enable_fake_mode from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 383dcaa3af21..0a3f0769db94 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -16,7 +16,7 @@ import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.config import VllmConfig -from vllm.utils.torch import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer class CompilerInterface: diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 0e0fea450a31..a2e0abfebc2c 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -17,7 +17,7 @@ from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils.torch import weak_ref_tensors +from vllm.utils.torch_utils import weak_ref_tensors logger = init_logger(__name__) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 506ba1767f92..abe61cce0dd8 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -22,7 +22,7 @@ from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from vllm.utils import resolve_obj_by_qualname -from vllm.utils.torch import supports_dynamo +from vllm.utils.torch_utils import supports_dynamo from .monitor import start_monitoring_torch_compile diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 23d8171375b5..9af635a929b4 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -14,7 +14,7 @@ from torch import fx from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily -from vllm.utils.torch import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer if is_torch_equal_or_newer("2.6"): from torch._inductor.custom_graph_pass import CustomGraphPass diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index b8b4c8bc5f28..620a521a9a7e 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -17,7 +17,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import resolve_obj_by_qualname -from vllm.utils.torch import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer if TYPE_CHECKING: from vllm.config import VllmConfig diff --git a/vllm/config/model.py b/vllm/config/model.py index cf106ea37e28..68070d204b3c 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -42,7 +42,7 @@ from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri from vllm.transformers_utils.utils import maybe_model_redirect from vllm.utils import LayerBlockType, LazyLoader -from vllm.utils.torch import common_broadcastable_dtype +from vllm.utils.torch_utils import common_broadcastable_dtype if TYPE_CHECKING: from 
transformers import PretrainedConfig diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 8a5e4b427b8e..aa1ac4ab8f0b 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -19,7 +19,7 @@ ) from vllm.platforms import current_platform from vllm.utils import get_open_ports_list -from vllm.utils.torch import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless if TYPE_CHECKING: from ray.runtime_env import RuntimeEnv diff --git a/vllm/distributed/device_communicators/all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py index cc98b94ace10..09c89ea31d05 100644 --- a/vllm/distributed/device_communicators/all_reduce_utils.py +++ b/vllm/distributed/device_communicators/all_reduce_utils.py @@ -23,7 +23,7 @@ vllm_kernel_override_batch_invariant, ) from vllm.utils import update_environment_variables -from vllm.utils.torch import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless logger = init_logger(__name__) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index dfb9c6f5afc3..4b82f3b5d396 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -17,7 +17,7 @@ from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils.torch import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless try: ops.meta_size() diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index baccaecdacd5..ad3c8676fafd 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -19,7 +19,7 @@ ) from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger -from vllm.utils.torch import current_stream +from vllm.utils.torch_utils import current_stream logger = init_logger(__name__) @@ -30,7 +30,7 @@ def register_nccl_symmetric_ops(pynccl_comm): from vllm.distributed.device_communicators.pynccl_allocator import ( nccl_symm_mem_context, ) - from vllm.utils.torch import direct_register_custom_op + from vllm.utils.torch_utils import direct_register_custom_op global _NCCL_SYMM_OPS_REGISTERED if _NCCL_SYMM_OPS_REGISTERED: diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index 86f28bee6408..9c7765883cfd 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -13,7 +13,7 @@ from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils.torch import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless logger = init_logger(__name__) diff --git a/vllm/distributed/device_communicators/ray_communicator.py b/vllm/distributed/device_communicators/ray_communicator.py index 4158b06264af..3b02b885e786 100644 --- a/vllm/distributed/device_communicators/ray_communicator.py +++ b/vllm/distributed/device_communicators/ray_communicator.py @@ -14,7 +14,7 @@ ) from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger -from vllm.utils.torch import 
current_stream +from vllm.utils.torch_utils import current_stream logger = init_logger(__name__) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py index 1c59f9d8cce0..5b32a9756637 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -26,7 +26,7 @@ TensorMemoryPool, ) from vllm.utils import get_ip -from vllm.utils.torch import current_stream +from vllm.utils.torch_utils import current_stream logger = logging.getLogger(__name__) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 35ef01a1e6a6..10e2ba1d5925 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -53,7 +53,7 @@ get_distributed_init_method, resolve_obj_by_qualname, ) -from vllm.utils.torch import ( +from vllm.utils.torch_utils import ( direct_register_custom_op, supports_custom_op, ) diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index ab6d50a6c0e3..a5df81e55e36 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -30,7 +30,7 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.utils import get_tcp_uri -from vllm.utils.torch import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer logger = init_logger(__name__) diff --git a/vllm/env_override.py b/vllm/env_override.py index 9c1789ef7fa2..30071f8ea46c 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -5,7 +5,7 @@ import torch from vllm.logger import init_logger -from vllm.utils.torch import is_torch_equal +from vllm.utils.torch_utils import is_torch_equal logger = init_logger(__name__) diff --git a/vllm/envs.py b/vllm/envs.py index 4c9927abb9bf..a7d294ab8298 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -246,7 +246,7 @@ def maybe_convert_bool(value: str | None) -> bool | None: def use_aot_compile() -> bool: - from vllm.utils.torch import is_torch_equal_or_newer + from vllm.utils.torch_utils import is_torch_equal_or_newer default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0" return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1" diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py index bee0ea2edf56..fd4c1364de7e 100644 --- a/vllm/lora/ops/triton_ops/lora_expand_op.py +++ b/vllm/lora/ops/triton_ops/lora_expand_op.py @@ -12,7 +12,7 @@ from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs from vllm.triton_utils import tl, triton -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op @triton.jit diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py index 71ab13b4f36b..8c58915e3f79 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink_op.py +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -12,7 +12,7 @@ from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs from vllm.triton_utils import tl, triton -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op @triton.jit diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py 
b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 1430a5fe99af..f21fe16c5108 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, ) -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def flashinfer_fused_moe_blockscale_fp8( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 24a039ea0db6..c88982c1522b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -52,7 +52,7 @@ from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used -from vllm.utils.torch import direct_register_custom_op, is_torch_equal_or_newer +from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a34a93c4f99c..a9b48d0c8021 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -54,7 +54,7 @@ from vllm.platforms.interface import CpuArchEnum from vllm.utils import cdiv, has_deep_ep, has_pplx, round_up from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.worker.ubatching import dbo_current_ubatch_id if current_platform.is_cuda_alike(): diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 97f2c2957f8e..6edbb17c0a8e 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -11,7 +11,7 @@ FusedMoEQuantConfig, ) from vllm.platforms import current_platform -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op class QuantMethod(IntEnum): diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 0cbafe18f275..0627ea50d821 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -25,7 +25,7 @@ from vllm.triton_utils import tl, triton from vllm.utils import cdiv from vllm.utils.flashinfer import flashinfer_fp4_quantize -from vllm.utils.torch import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer @triton.jit diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index acefc11255a9..fe0d3a9e319c 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -13,7 +13,7 @@ vllm_kernel_override_batch_invariant, ) from vllm.platforms import current_platform -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def is_rocm_aiter_rmsnorm_enabled() -> bool: diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py index d77d1f7898c5..a8f7f652452f 100644 --- 
a/vllm/model_executor/layers/mamba/linear_attn.py +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -35,7 +35,7 @@ MambaStateShapeCalculator, ) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata if TYPE_CHECKING: diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 8913755dd1fa..a9a0c216474b 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -37,7 +37,7 @@ selective_state_update, ) from vllm.model_executor.utils import set_weight_attrs -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index c9c4f8070664..fb45afa33dad 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -46,7 +46,7 @@ sharded_weight_loader, ) from vllm.model_executor.utils import set_weight_attrs -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata # Added by the IBM Team, 2024 diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index f36070300f3d..91a45623582d 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -6,7 +6,7 @@ from vllm.config.cache import MambaDType from vllm.config.model import ModelDType from vllm.distributed import divide -from vllm.utils.torch import ( +from vllm.utils.torch_utils import ( STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype, ) diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py index 0530e987c343..04efa8a8b373 100644 --- a/vllm/model_executor/layers/mamba/short_conv.py +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -27,7 +27,7 @@ causal_conv1d_fn, causal_conv1d_update, ) -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index a353c132fa2a..ccd9b311cc93 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -23,7 +23,7 @@ QuantizationMethods, ) from vllm.platforms import current_platform -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op class BitsAndBytesConfig(QuantizationConfig): diff --git a/vllm/model_executor/layers/quantization/fp_quant.py b/vllm/model_executor/layers/quantization/fp_quant.py index 585088a73511..15a253cef0b7 100644 --- a/vllm/model_executor/layers/quantization/fp_quant.py +++ b/vllm/model_executor/layers/quantization/fp_quant.py @@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked from vllm.model_executor.utils import 
set_weight_attrs from vllm.platforms import current_platform -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op class FPQuantConfig(QuantizationConfig): diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 4bee58afdf03..8a914c57a9f7 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -28,7 +28,7 @@ ) from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.utils import set_weight_attrs -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index eb3d31c7162d..a19396a162bc 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -7,7 +7,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from .cutlass import CutlassScaledMMLinearKernel from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index f7a7741e8d46..504120ffeabe 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -52,7 +52,7 @@ round_up, ) from vllm.utils.flashinfer import has_flashinfer -from vllm.utils.torch import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py index 50cbaad78f17..c25c522dea55 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py @@ -45,7 +45,7 @@ def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool: from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 from aiter.ops.triton.quant import dynamic_mxfp4_quant - from vllm.utils.torch import direct_register_custom_op + from vllm.utils.torch_utils import direct_register_custom_op if is_rocm_aiter_fp4_asm_gemm_enabled(): from aiter import gemm_a4w4, per_1x32_f4_quant_hip diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index b91d145da7a9..cb2d075c1a9d 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -34,7 +34,7 @@ is_deep_gemm_supported, should_use_deepgemm_for_fp8_linear, ) -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 3fea2d0b6b8c..5e87cadfb107 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -7,7 +7,7 
@@ from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils.torch import direct_register_custom_op, is_torch_equal_or_newer +from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py index e33b8c6b096d..2b5659e30097 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py @@ -3,7 +3,7 @@ import torch from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def _quant_dequant_mxfp6( diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 20dbb6bac13c..0d036ffdd286 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index bcdb3176289c..9e6ec9fdd523 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -10,7 +10,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op if current_platform.is_cuda(): from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb diff --git a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py index e12235371545..a01d14f7b3a1 100644 --- a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py +++ b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py @@ -5,7 +5,7 @@ import vllm.envs as envs from vllm.platforms import current_platform -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def is_rocm_triton_rotary_embedding_enabled() -> bool: diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index e699c3ac7153..c1a48fa200ca 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -9,7 +9,7 @@ from vllm import _custom_ops as ops from vllm import envs from vllm.platforms import CpuArchEnum, current_platform -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def shuffle_weight(w: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index 537a69894333..94dfa478245d 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ 
b/vllm/model_executor/model_loader/base_loader.py @@ -12,7 +12,7 @@ initialize_model, process_weights_after_loading, ) -from vllm.utils.torch import set_default_torch_dtype +from vllm.utils.torch_utils import set_default_torch_dtype logger = init_logger(__name__) diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index 8e88963c69e0..97c7a20bc4d5 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -48,7 +48,7 @@ set_weight_attrs, ) from vllm.platforms import current_platform -from vllm.utils.torch import set_default_torch_dtype +from vllm.utils.torch_utils import set_default_torch_dtype logger = init_logger(__name__) diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 437d3337cf19..7db1fc167c4f 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -21,7 +21,7 @@ get_gguf_weight_type_map, gguf_quant_weights_iterator, ) -from vllm.utils.torch import set_default_torch_dtype +from vllm.utils.torch_utils import set_default_torch_dtype class GGUFModelLoader(BaseModelLoader): diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index 1fbaa95fed58..2b3704cfebba 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -23,7 +23,7 @@ get_model_architecture, initialize_model, ) -from vllm.utils.torch import set_default_torch_dtype +from vllm.utils.torch_utils import set_default_torch_dtype logger = init_logger(__name__) diff --git a/vllm/model_executor/model_loader/tpu.py b/vllm/model_executor/model_loader/tpu.py index 6b86e1def6fa..fc142f1f07fa 100644 --- a/vllm/model_executor/model_loader/tpu.py +++ b/vllm/model_executor/model_loader/tpu.py @@ -15,7 +15,7 @@ initialize_model, process_weights_after_loading, ) -from vllm.utils.torch import set_default_torch_dtype +from vllm.utils.torch_utils import set_default_torch_dtype logger = init_logger(__name__) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index e557f38817ec..da5d80f9828e 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -7,7 +7,7 @@ from vllm.logger import init_logger from vllm.model_executor.models import ModelRegistry from vllm.utils import cdiv, round_up -from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec if TYPE_CHECKING: diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 28a7025bedf4..d2bc037b97dd 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -80,7 +80,7 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerBackend, DeepseekV32IndexerMetadata, diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 65cd6b085534..a4c9360e50e0 100644 --- 
a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -50,7 +50,7 @@ from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.utils.collections import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape -from vllm.utils.torch import set_default_torch_dtype +from vllm.utils.torch_utils import set_default_torch_dtype from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import ( diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index ac2325dda4a6..e2d2647f0177 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -52,7 +52,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils.tensor_schema import TensorSchema, TensorShape -from vllm.utils.torch import set_default_torch_num_threads +from vllm.utils.torch_utils import set_default_torch_num_threads from .interfaces import ( MultiModalEmbeddings, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 2c5337387229..b4a558ad6970 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -87,7 +87,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils.collections import flatten_2d_lists from vllm.utils.tensor_schema import TensorSchema, TensorShape -from vllm.utils.torch import set_default_torch_dtype +from vllm.utils.torch_utils import set_default_torch_dtype from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import ( diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index e886fac410a9..09293f63f70e 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -64,7 +64,7 @@ ) from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index e097a471359f..06e94734376c 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -71,7 +71,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Qwen3NextConfig from vllm.triton_utils import tl, triton -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from .interfaces import ( diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers_moe.py index 8ca79bbe42fc..43ea9a4869ed 100644 --- a/vllm/model_executor/models/transformers_moe.py +++ b/vllm/model_executor/models/transformers_moe.py @@ -28,7 +28,7 @@ from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.platforms import current_platform -from vllm.utils.torch import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from .interfaces import MixtureOfExperts, SupportsMultiModal from .transformers import ( diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py 
index 9a95d614a9f0..66048fbd7d3c 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -27,7 +27,10 @@ is_pin_memory_available, is_uva_available, ) -from vllm.utils.torch import direct_register_custom_op, get_cuda_view_from_cpu_tensor +from vllm.utils.torch_utils import ( + direct_register_custom_op, + get_cuda_view_from_cpu_tensor, +) logger = init_logger(__name__) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 62048a81972f..ccfe1871ef07 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -52,7 +52,7 @@ from vllm.transformers_utils.processor import cached_get_processor from vllm.utils.jsontree import json_map_leaves from vllm.utils.tensor_schema import TensorSchema, TensorShape -from vllm.utils.torch import set_default_torch_dtype +from vllm.utils.torch_utils import set_default_torch_dtype from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription from .utils import ( diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 13a5152af48b..99651a408b31 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -8,7 +8,7 @@ from vllm import envs from vllm.plugins import PLATFORM_PLUGINS_GROUP, load_plugins_by_group from vllm.utils import resolve_obj_by_qualname -from vllm.utils.torch import supports_xccl +from vllm.utils.torch_utils import supports_xccl from .interface import CpuArchEnum, Platform, PlatformEnum diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 943cca678a22..c736e084a38d 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -17,7 +17,7 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.utils import import_pynvml -from vllm.utils.torch import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index ed27df55ee01..68e6c06c8814 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -9,7 +9,7 @@ import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils.torch import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 2c7abd2d7ecc..4211535131a4 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -22,7 +22,7 @@ from vllm.connections import global_http_connection from vllm.logger import init_logger from vllm.utils import cuda_get_device_properties -from vllm.utils.torch import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) diff --git a/vllm/utils/torch.py b/vllm/utils/torch_utils.py similarity index 100% rename from vllm/utils/torch.py rename to vllm/utils/torch_utils.py diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index a6ee8936fca5..d14f949b6579 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -29,7 +29,7 @@ vllm_kernel_override_batch_invariant, ) from vllm.utils import cdiv -from vllm.utils.torch import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer from 
vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index f9c1b69d491e..f7a4114a0a70 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -29,7 +29,7 @@ import aiter from vllm.triton_utils import tl, triton - from vllm.utils.torch import direct_register_custom_op + from vllm.utils.torch_utils import direct_register_custom_op @triton.jit def _vllm_layout_trans_kernel( diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index ab795493eea4..392519f8fa9a 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -11,7 +11,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import cdiv -from vllm.utils.torch import get_dtype_size +from vllm.utils.torch_utils import get_dtype_size logger = init_logger(__name__) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1836f58d0b6e..7a73177ba7d8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -82,7 +82,7 @@ round_up, ) from vllm.utils.jsontree import json_map_leaves -from vllm.utils.torch import ( +from vllm.utils.torch_utils import ( STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, supports_dynamo, diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index ba839373356b..1a758386d1c9 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -27,7 +27,7 @@ from vllm.platforms.tpu import USE_TPU_INFERENCE from vllm.tasks import SupportedTask from vllm.utils import cdiv -from vllm.utils.torch import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 33ce20943d61..6edcb7848638 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -7,7 +7,7 @@ from vllm import forward_context from vllm.forward_context import ForwardContext -from vllm.utils.torch import current_stream +from vllm.utils.torch_utils import current_stream _THREAD_ID_TO_CONTEXT: dict = {} _CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [None, None] From e819af6c7f36293cf5888a579c34b2de95428f9b Mon Sep 17 00:00:00 2001 From: isotr0py <2037008807@qq.com> Date: Thu, 16 Oct 2025 23:52:46 +0800 Subject: [PATCH 12/15] move make_tensor_with_pad Signed-off-by: isotr0py <2037008807@qq.com> --- tests/kernels/utils.py | 2 +- tests/v1/sample/test_sampler.py | 3 +- tests/v1/sample/utils.py | 2 +- tests/v1/worker/test_gpu_input_batch.py | 3 +- vllm/utils/__init__.py | 62 ----------------------- vllm/utils/torch_utils.py | 66 ++++++++++++++++++++++++- vllm/v1/sample/ops/penalties.py | 3 +- 7 files changed, 73 insertions(+), 68 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 6c7ff984b433..eb00bc72b4b0 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -22,8 +22,8 @@ STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_XFORMERS_ATTN_VAL, - make_tensor_with_pad, ) +from vllm.utils.torch_utils import make_tensor_with_pad # For now, disable "test_aot_dispatch_dynamic" since there are some # bugs related to this test in PyTorch 2.4. 
diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index edc6acae848a..a1513acc7b8e 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -7,7 +7,8 @@ from tests.v1.sample.utils import create_allowed_token_ids from vllm.platforms import current_platform -from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.utils import is_pin_memory_available +from vllm.utils.torch_utils import make_tensor_with_pad from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler diff --git a/tests/v1/sample/utils.py b/tests/v1/sample/utils.py index 5d457762fc64..a0abb3b4c6ce 100644 --- a/tests/v1/sample/utils.py +++ b/tests/v1/sample/utils.py @@ -9,7 +9,7 @@ import torch from vllm import CompletionOutput -from vllm.utils import make_tensor_with_pad +from vllm.utils.torch_utils import make_tensor_with_pad from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor from vllm.v1.sample.metadata import SamplingMetadata diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 5ab67dcf761e..132f0a58bbf5 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -10,7 +10,8 @@ from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams -from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.utils import is_pin_memory_available +from vllm.utils.torch_utils import make_tensor_with_pad from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index d9cabfb1ec7b..5c1910d6092e 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -54,8 +54,6 @@ import cbor2 import cloudpickle -import numpy as np -import numpy.typing as npt import psutil import regex as re import setproctitle @@ -120,16 +118,6 @@ RESET = "\033[0;0m" -TORCH_DTYPE_TO_NUMPY_DTYPE = { - torch.float16: np.float16, - torch.float32: np.float32, - torch.float64: np.float64, - torch.uint8: np.uint8, - torch.int32: np.int32, - torch.int64: np.int64, -} - - T = TypeVar("T") U = TypeVar("U") @@ -447,56 +435,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): gc.collect() -def make_ndarray_with_pad( - x: list[list[T]], - pad: T, - dtype: npt.DTypeLike, - *, - max_len: int | None = None, -) -> npt.NDArray: - """ - Make a padded array from 2D inputs. - - The padding is applied to the end of each inner list until it reaches - `max_len`. - """ - if max_len is None: - # Unlike for most functions, map is faster than a genexpr over `len` - max_len = max(map(len, x), default=0) - - padded_x = np.full((len(x), max_len), pad, dtype=dtype) - for ind, blocktb in enumerate(x): - assert len(blocktb) <= max_len - padded_x[ind, : len(blocktb)] = blocktb - - return padded_x - - -def make_tensor_with_pad( - x: list[list[T]], - pad: T, - dtype: torch.dtype, - *, - max_len: int | None = None, - device: str | torch.device | None = None, - pin_memory: bool = False, -) -> torch.Tensor: - """ - Make a padded tensor from 2D inputs. - - The padding is applied to the end of each inner list until it reaches - `max_len`. 
- """ - np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype] - padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len) - - tensor = torch.from_numpy(padded_x).to(device) - if pin_memory: - tensor = tensor.pin_memory() - - return tensor - - # TODO: This function can be removed if transformer_modules classes are # serialized by value when communicating between processes def init_cached_hf_modules() -> None: diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index e7d14c10cbc1..ced2ccdfa95d 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -5,8 +5,10 @@ import threading from collections.abc import Callable, Collection from functools import lru_cache -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar +import numpy as np +import numpy.typing as npt import torch from packaging import version from packaging.version import Version @@ -33,6 +35,18 @@ "fp8_ds_mla": torch.uint8, } +TORCH_DTYPE_TO_NUMPY_DTYPE = { + torch.float16: np.float16, + torch.float32: np.float32, + torch.float64: np.float64, + torch.uint8: np.uint8, + torch.int32: np.int32, + torch.int64: np.int64, +} + + +T = TypeVar("T") + @contextlib.contextmanager def set_default_torch_dtype(dtype: torch.dtype): @@ -252,6 +266,56 @@ def async_tensor_h2d( return t.to(device=target_device, non_blocking=True) +def make_ndarray_with_pad( + x: list[list[T]], + pad: T, + dtype: npt.DTypeLike, + *, + max_len: int | None = None, +) -> npt.NDArray: + """ + Make a padded array from 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. + """ + if max_len is None: + # Unlike for most functions, map is faster than a genexpr over `len` + max_len = max(map(len, x), default=0) + + padded_x = np.full((len(x), max_len), pad, dtype=dtype) + for ind, blocktb in enumerate(x): + assert len(blocktb) <= max_len + padded_x[ind, : len(blocktb)] = blocktb + + return padded_x + + +def make_tensor_with_pad( + x: list[list[T]], + pad: T, + dtype: torch.dtype, + *, + max_len: int | None = None, + device: str | torch.device | None = None, + pin_memory: bool = False, +) -> torch.Tensor: + """ + Make a padded tensor from 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. 
+ """ + np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype] + padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len) + + tensor = torch.from_numpy(padded_x).to(device) + if pin_memory: + tensor = tensor.pin_memory() + + return tensor + + prev_set_stream = torch.cuda.set_stream _current_stream_tls = threading.local() diff --git a/vllm/v1/sample/ops/penalties.py b/vllm/v1/sample/ops/penalties.py index e49b8db47800..44f53d95dd3b 100644 --- a/vllm/v1/sample/ops/penalties.py +++ b/vllm/v1/sample/ops/penalties.py @@ -4,7 +4,8 @@ import torch from vllm.model_executor.layers.utils import apply_penalties -from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.utils import is_pin_memory_available +from vllm.utils.torch_utils import make_tensor_with_pad def apply_all_penalties( From 1e13e3a3aaf6381bb07669ae404fdd8881dec989 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Fri, 17 Oct 2025 21:09:26 +0800 Subject: [PATCH 13/15] fix Signed-off-by: Isotr0py --- vllm/model_executor/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 5ffee6cb8d8b..759b809433b1 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -7,7 +7,7 @@ import torch -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer def set_random_seed(seed: int) -> None: From 6655659af36cfd29613167bcd3dfa0d41c0e3094 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Sat, 18 Oct 2025 13:54:47 +0800 Subject: [PATCH 14/15] fix Signed-off-by: Isotr0py --- tests/compile/test_fusions_e2e.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index 7399abaec542..efb5774b7870 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -15,8 +15,8 @@ from vllm import LLM, SamplingParams from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.platforms import current_platform -from vllm.utils import is_torch_equal_or_newer from vllm.utils.flashinfer import has_flashinfer +from vllm.utils.torch_utils import is_torch_equal_or_newer from ..utils import flat_product, multi_gpu_test From 2c28a3aedf6f25b43abee34e01ad57b9143b626e Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Sat, 18 Oct 2025 22:49:06 +0800 Subject: [PATCH 15/15] move kv_cache_dtype_str_to_dtype Signed-off-by: Isotr0py --- vllm/utils/__init__.py | 9 --------- vllm/utils/torch_utils.py | 11 +++++++++++ vllm/v1/worker/gpu_model_runner.py | 2 +- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 1af15343b9d8..7cb3805fcbea 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -98,15 +98,6 @@ RESET = "\033[0;0m" -def kv_cache_dtype_str_to_dtype( - kv_cache_dtype: str, model_config: ModelConfig -) -> torch.dtype: - if kv_cache_dtype == "auto": - # Model config may not be specified for unit tests, default to float16 - return model_config.dtype if model_config else torch.half - return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] - - T = TypeVar("T") U = TypeVar("U") diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index ced2ccdfa95d..adcacb34cb7c 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -17,8 +17,10 @@ import vllm.envs as envs if TYPE_CHECKING: + from vllm.config import ModelConfig from vllm.sequence import IntermediateTensors else: + 
ModelConfig = object IntermediateTensors = object @@ -164,6 +166,15 @@ def get_kv_cache_torch_dtype( return torch_dtype +def kv_cache_dtype_str_to_dtype( + kv_cache_dtype: str, model_config: ModelConfig +) -> torch.dtype: + if kv_cache_dtype == "auto": + # Model config may not be specified for unit tests, default to float16 + return model_config.dtype if model_config else torch.half + return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] + + def create_kv_caches_with_random_flash( num_blocks: int, block_size: int, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 319bfe6aca0a..258f2f460fba 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -72,7 +72,6 @@ cdiv, check_use_alibi, is_pin_memory_available, - kv_cache_dtype_str_to_dtype, length_from_prompt_token_ids_or_embeds, round_up, ) @@ -81,6 +80,7 @@ from vllm.utils.mem_utils import DeviceMemoryProfiler from vllm.utils.torch_utils import ( get_dtype_size, + kv_cache_dtype_str_to_dtype, supports_dynamo, ) from vllm.v1.attention.backends.flash_attn import AttentionMetadata