From 92715bdfe601995f9faebad3e8d9d4e3a760ef9b Mon Sep 17 00:00:00 2001
From: Christian Pinto
Date: Mon, 14 Jul 2025 14:39:17 +0000
Subject: [PATCH 1/2] Initial implementation of hidden_states_processors via
 plugins

Signed-off-by: Christian Pinto
---
 vllm/config.py                                     |  6 +-
 vllm/engine/arg_utils.py                           |  4 +
 vllm/entrypoints/llm.py                            |  5 ++
 vllm/envs.py                                       |  5 ++
 vllm/outputs.py                                    |  6 +-
 .../hidden_states_processors/__init__.py           | 77 +++++++++++++++++++
 .../hidden_states_processors/default.py            | 19 +++++
 .../hidden_states_processors/interface.py          | 19 +++++
 vllm/v1/core/sched/scheduler.py                    |  6 ++
 vllm/v1/engine/__init__.py                         |  1 +
 vllm/v1/engine/async_llm.py                        |  7 +-
 vllm/v1/engine/llm_engine.py                       |  3 +-
 vllm/v1/engine/output_processor.py                 | 35 +++++++--
 vllm/v1/outputs.py                                 |  3 +
 vllm/v1/worker/gpu_model_runner.py                 | 18 ++++-
 15 files changed, 201 insertions(+), 13 deletions(-)
 create mode 100644 vllm/plugins/hidden_states_processors/__init__.py
 create mode 100644 vllm/plugins/hidden_states_processors/default.py
 create mode 100644 vllm/plugins/hidden_states_processors/interface.py

diff --git a/vllm/config.py b/vllm/config.py
index 07df71ec51ef..e1f55a592dcf 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -425,6 +425,9 @@ class ModelConfig:
    - "transformers" will use the Transformers model implementation."""
    override_attention_dtype: Optional[str] = None
    """Override dtype for attention"""
+    process_hidden_states: bool = False
+    """Extract the model's hidden states to be processed before the request
+    is completed. Currently only supported for embedding/pooling models."""

    def compute_hash(self) -> str:
        """
@@ -4820,7 +4823,8 @@ def __str__(self):
            f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, "  # noqa
            f"use_async_output_proc={self.model_config.use_async_output_proc}, "
            f"pooler_config={self.model_config.pooler_config!r}, "
-            f"compilation_config={self.compilation_config!r}")
+            f"compilation_config={self.compilation_config!r}, "
+            f"process_hidden_states={self.model_config.process_hidden_states}")


 _current_vllm_config: Optional[VllmConfig] = None

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 709968004718..6209671ad9c7 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -350,6 +350,7 @@ class EngineArgs:
        MultiModalConfig.mm_processor_kwargs
    disable_mm_preprocessor_cache: bool = \
        MultiModalConfig.disable_mm_preprocessor_cache
+    process_hidden_states: bool = False
    # LoRA fields
    enable_lora: bool = False
    enable_lora_bias: bool = LoRAConfig.bias_enabled
@@ -503,6 +504,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                 **model_kwargs["enable_prompt_embeds"])
        model_group.add_argument("--served-model-name",
                                 **model_kwargs["served_model_name"])
+        model_group.add_argument("--process-hidden-states",
+                                 **model_kwargs["process_hidden_states"])
        # This one is a special case because it is the
        # opposite of ModelConfig.use_async_output_proc
        model_group.add_argument(
@@ -910,6 +913,7 @@ def create_model_config(self) -> ModelConfig:
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
            override_attention_dtype=self.override_attention_dtype,
+            process_hidden_states=self.process_hidden_states,
        )

    def validate_tensorizer_args(self):
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 2c961156bc84..2250caa418f4 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -147,6 +147,9 @@ class LLM:
        compilation_config: Either an integer or a dictionary.
If it is an integer, it is used as the level of compilation
            optimization. If it is a dictionary, it can specify the full
            compilation configuration.
+        process_hidden_states: If True, load the hidden states processor
+            and process the hidden states for each request before returning
+            results to the user.
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].

    Note:
@@ -195,6 +198,7 @@ def __init__(
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
+        process_hidden_states: bool = False,
        **kwargs,
    ) -> None:
        """LLM constructor."""
@@ -268,6 +272,7 @@ def __init__(
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
+            process_hidden_states=process_hidden_states,
            **kwargs,
        )

diff --git a/vllm/envs.py b/vllm/envs.py
index 0eff741519ae..1f6f0ee73b6f 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -991,6 +991,11 @@ def get_vllm_port() -> Optional[int]:
    # The default value is "VLLM".
    "VLLM_PROCESS_NAME_PREFIX":
    lambda: os.getenv("VLLM_PROCESS_NAME_PREFIX", "VLLM"),
+    # Controls which hidden states processor plugin to load.
+    # This is used when more than one hidden states processor plugin is
+    # installed, to decide which one to use.
+    "VLLM_USE_HIDDEN_STATES_PROCESSOR":
+    lambda: os.getenv("VLLM_USE_HIDDEN_STATES_PROCESSOR", None),
 }

 # --8<-- [end:env-vars-definition]
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 9784a8894472..38efc86c89ff 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -69,9 +69,13 @@ class PoolingOutput:
        data: The extracted hidden states.
    """
    data: torch.Tensor
+    processed_hidden_states: Optional[Any] = None

    def __repr__(self) -> str:
-        return (f"PoolingOutput(data={self.data})")
+        hidden_states = ("None" if self.processed_hidden_states is None else
+                         type(self.processed_hidden_states).__name__)
+        return (f"PoolingOutput(data={self.data}, "
+                f"processed_hidden_states={hidden_states})")

    def __eq__(self, other: object) -> bool:
        return (isinstance(other, self.__class__) and bool(
diff --git a/vllm/plugins/hidden_states_processors/__init__.py b/vllm/plugins/hidden_states_processors/__init__.py
new file mode 100644
index 000000000000..a1424c68325d
--- /dev/null
+++ b/vllm/plugins/hidden_states_processors/__init__.py
@@ -0,0 +1,77 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import logging
+from typing import Optional
+
+import vllm.envs as envs
+from vllm.config import VllmConfig
+from vllm.plugins import load_plugins_by_group
+from vllm.plugins.hidden_states_processors.interface import (
+    HiddenStatesProcessor)
+from vllm.utils import resolve_obj_by_qualname
+
+logger = logging.getLogger(__name__)
+
+
+def identity_hidden_states_processor() -> str:
+    return ("vllm.plugins.hidden_states_processors."
+            "default.IdentityHiddenStatesProcessor")
+
+
+default_hidden_states_processors = {
+    "identity": identity_hidden_states_processor
+}
+
+
+def get_hidden_states_processor(
+        vllm_config: VllmConfig) -> Optional["HiddenStatesProcessor"]:
+    # Hidden states processors are loaded as plugins under the
+    # 'vllm.hidden_states_processor_plugins' group. Similar to platform
+    # plugins, these plugins register a function that returns the qualified
+    # name of the processor class to install.
+    # All plugins implement the HiddenStatesProcessor interface.
+
+    hidden_states_processor_plugins = \
+        load_plugins_by_group('vllm.hidden_states_processor_plugins')
+
+    available_plugins = {
+        **default_hidden_states_processors,
+        **hidden_states_processor_plugins
+    }
+
+    loadable_plugins = {}
+    for name, func in available_plugins.items():
+        try:
+            assert callable(func)
+            processor_cls_qualname = func()
+            if processor_cls_qualname is not None:
+                loadable_plugins[name] = processor_cls_qualname
+        except Exception:
+            logger.warning("Failed to load plugin %s", name)
+
+    num_available_plugins = len(loadable_plugins)
+
+    # Sanity check: the identity processor is always registered,
+    # so at least one plugin must be loadable.
+    assert num_available_plugins > 0
+
+    if num_available_plugins > 1 and envs.VLLM_USE_HIDDEN_STATES_PROCESSOR:
+        activated_plugin_cls = loadable_plugins[
+            envs.VLLM_USE_HIDDEN_STATES_PROCESSOR]
+        activated_plugin_name = envs.VLLM_USE_HIDDEN_STATES_PROCESSOR
+    else:
+        activated_plugin_name = list(loadable_plugins.keys())[0]
+        activated_plugin_cls = loadable_plugins[activated_plugin_name]
+        if (num_available_plugins > 1
+                and not envs.VLLM_USE_HIDDEN_STATES_PROCESSOR):
+            logger.info(
+                "Multiple hidden states processor plugins are available "
+                "but VLLM_USE_HIDDEN_STATES_PROCESSOR does not select "
+                "one; loading the first available plugin.\n"
+                "Available hidden states "
+                "processor plugins: %s", str(loadable_plugins.keys()))
+
+    logger.info("Loaded hidden states processor plugin: %s",
+                activated_plugin_name)
+    return resolve_obj_by_qualname(activated_plugin_cls)(vllm_config)
diff --git a/vllm/plugins/hidden_states_processors/default.py b/vllm/plugins/hidden_states_processors/default.py
new file mode 100644
index 000000000000..0b0ae0291d62
--- /dev/null
+++ b/vllm/plugins/hidden_states_processors/default.py
@@ -0,0 +1,19 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Any
+
+import torch
+
+from vllm.plugins.hidden_states_processors.interface import (
+    HiddenStatesProcessor)
+
+
+class IdentityHiddenStatesProcessor(HiddenStatesProcessor):
+
+    def apply(self, data: torch.Tensor) -> Any:
+        """
+        The default identity hidden states processor:
+        returns the hidden states data unchanged.
+        """
+        return data
diff --git a/vllm/plugins/hidden_states_processors/interface.py b/vllm/plugins/hidden_states_processors/interface.py
new file mode 100644
index 000000000000..726198532f10
--- /dev/null
+++ b/vllm/plugins/hidden_states_processors/interface.py
@@ -0,0 +1,19 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from abc import ABC, abstractmethod
+from typing import Any
+
+import torch
+
+from vllm.config import VllmConfig
+
+
+class HiddenStatesProcessor(ABC):
+
+    def __init__(self, vllm_config: VllmConfig):
+        self.vllm_config = vllm_config
+
+    @abstractmethod
+    def apply(self, data: torch.Tensor) -> Any:
+        ...
\ No newline at end of file diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 446f98034cb8..bd0fd08c3c2a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -757,6 +757,7 @@ def update_from_output( num_scheduled_tokens = scheduler_output.num_scheduled_tokens pooler_outputs = model_runner_output.pooler_output num_nans_in_logits = model_runner_output.num_nans_in_logits + hidden_states = model_runner_output.hidden_states outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: Optional[SpecDecodingStats] = None @@ -821,6 +822,10 @@ def update_from_output( else: stopped_preempted_reqs.add(request) + req_hidden_states = None + if hidden_states: + req_hidden_states = hidden_states[req_index] + # Extract sample logprobs if needed. if request.sampling_params is not None \ and request.sampling_params.logprobs is not None and logprobs: @@ -864,6 +869,7 @@ def update_from_output( new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, pooling_output=pooler_output, + hidden_states=req_hidden_states, stop_reason=request.stop_reason, events=request.take_events(), kv_transfer_params=kv_transfer_params, diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 79dc80d8fc54..611b08fdbc62 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -109,6 +109,7 @@ class EngineCoreOutput( new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None pooling_output: Optional[torch.Tensor] = None + hidden_states: Optional[torch.Tensor] = None finish_reason: Optional[FinishReason] = None stop_reason: Union[int, str, None] = None diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index ed0d9620f476..4f62fc9f0275 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -111,8 +111,11 @@ def __init__( ) # OutputProcessor (converts EngineCoreOutputs --> RequestOutput). - self.output_processor = OutputProcessor(self.tokenizer, - log_stats=self.log_stats) + self.output_processor = OutputProcessor( + vllm_config=vllm_config, + tokenizer=self.tokenizer, + log_stats=self.log_stats, + ) # EngineCore (starts the engine in background process). self.engine_core = EngineCoreClient.make_async_mp_client( diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index efbdffbc0900..57bead357b83 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -97,7 +97,8 @@ def __init__( mm_registry=mm_registry) # OutputProcessor (convert EngineCoreOutputs --> RequestOutput). 
-        self.output_processor = OutputProcessor(self.tokenizer,
+        self.output_processor = OutputProcessor(vllm_config=vllm_config,
+                                                tokenizer=self.tokenizer,
                                                log_stats=self.log_stats)

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 3be6c4821214..bf6fddf32030 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -8,8 +8,10 @@

 import torch

+from vllm.config import VllmConfig
 from vllm.outputs import (CompletionOutput, PoolingOutput,
                           PoolingRequestOutput, RequestOutput)
+from vllm.plugins.hidden_states_processors import get_hidden_states_processor
 from vllm.sampling_params import RequestOutputKind
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@@ -164,6 +166,7 @@ def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: Optional[torch.Tensor],
+        processed_hidden_states: Optional[Any],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
        kv_transfer_params: Optional[dict[str, Any]] = None,
@@ -179,9 +182,12 @@ def make_request_output(
        request_id = self.request_id
        if pooling_output is not None:
-            return self._new_request_output(
-                request_id, [self._new_pooling_output(pooling_output)],
-                finished)
+            output = self._new_pooling_output(
+                pooling_output,
+                processed_hidden_states=processed_hidden_states)
+            return self._new_request_output(request_id=request_id,
+                                            outputs=[output],
+                                            finished=finished)

        output = self._new_completion_output(new_token_ids, finish_reason,
                                             stop_reason)
@@ -266,9 +272,11 @@ def _new_completion_output(
    def _new_pooling_output(
        self,
        pooling_output: torch.Tensor,
+        processed_hidden_states: Optional[Any],
    ) -> PoolingOutput:
-        return PoolingOutput(data=pooling_output)
+        return PoolingOutput(data=pooling_output,
+                             processed_hidden_states=processed_hidden_states)


 class OutputProcessor:

    def __init__(
        self,
+        vllm_config: VllmConfig,
        tokenizer: TokenizerGroup,
        log_stats: bool,
    ):
@@ -284,6 +293,11 @@ def __init__(
        self.request_states: dict[str, RequestState] = {}
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates()
+        if vllm_config.model_config.process_hidden_states:
+            if not (processor := get_hidden_states_processor(vllm_config)):
+                raise ValueError(
+                    "process_hidden_states is set but no processor plugin is available")
+            self.hidden_states_processor = processor

    def get_num_unfinished_requests(self):
        return len(self.request_states)
@@ -391,6 +405,7 @@ def process_outputs(
                stop_reason = engine_core_output.stop_reason
                kv_transfer_params = engine_core_output.kv_transfer_params
                num_cached_tokens = engine_core_output.num_cached_tokens
+                hidden_states = engine_core_output.hidden_states
                req_state.is_prefilling = False

                if pooling_output is None:
@@ -408,10 +423,18 @@ def process_outputs(
                    req_state.logprobs_processor.update_from_output(
                        engine_core_output)

+                if pooling_output is not None and hidden_states is not None:
+                    # Currently we process hidden states only for pooling models
+                    processed_hidden_states = \
+                        self.hidden_states_processor.apply(hidden_states)
+                else:
+                    processed_hidden_states = None
+
                # 4) Create and handle RequestOutput objects.
                if request_output := req_state.make_request_output(
-                        new_token_ids, pooling_output, finish_reason, stop_reason,
-                        kv_transfer_params, num_cached_tokens):
+                        new_token_ids, pooling_output, processed_hidden_states,
+                        finish_reason, stop_reason, kv_transfer_params,
+                        num_cached_tokens):
                    if req_state.queue is not None:
                        # AsyncLLM: put into queue for handling by generate().
                        req_state.queue.put(request_output)
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index f78623f571b2..5328aad22e49 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -111,6 +111,9 @@ class ModelRunnerOutput:
    # req_id -> num_nans_in_logits
    num_nans_in_logits: Optional[dict[str, int]] = None

+    # This is used for pooling models that install a hidden states processor
+    hidden_states: Optional[list[torch.Tensor]] = None
+

 EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
                                               req_id_to_index={},
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 5fe594db667a..d1986f95c053 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1152,6 +1152,18 @@ def _gather_mm_embeddings(
                mm_embeds.append(mm_embeds_item)
        return mm_embeds

+    def _maybe_return_hidden_states(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> list[torch.Tensor]:
+        # Move per-request hidden states to CPU for the frontend processor.
+        final_hidden_states: list[torch.Tensor] = []
+        if self.vllm_config.model_config.process_hidden_states:
+            for hidden_state in hidden_states:
+                final_hidden_states.append(hidden_state.cpu())
+
+        return final_hidden_states
+
    def get_model(self) -> nn.Module:
        return self.model

@@ -1359,6 +1371,9 @@ def _pool(
        else:
            pooler_output.append(None)

+        return_hidden_states = self._maybe_return_hidden_states(
+            extracted_hidden_states)
+
        return ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
@@ -1367,8 +1382,7 @@ def _pool(
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=pooler_output,
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
+            hidden_states=return_hidden_states,
        )

    @torch.inference_mode()

From 3f81e8252fd52b10855903fcc34febfbb9e05d86 Mon Sep 17 00:00:00 2001
From: Christian Pinto
Date: Fri, 25 Jul 2025 15:33:58 +0000
Subject: [PATCH 2/2] Some minor fixes

Signed-off-by: Christian Pinto
---
 vllm/plugins/hidden_states_processors/__init__.py | 9 +++++++--
 vllm/v1/worker/gpu_model_runner.py                | 2 ++
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/vllm/plugins/hidden_states_processors/__init__.py b/vllm/plugins/hidden_states_processors/__init__.py
index a1424c68325d..a14ce9e4ec93 100644
--- a/vllm/plugins/hidden_states_processors/__init__.py
+++ b/vllm/plugins/hidden_states_processors/__init__.py
@@ -57,8 +57,13 @@ def get_hidden_states_processor(
    assert num_available_plugins > 0

    if num_available_plugins > 1 and envs.VLLM_USE_HIDDEN_STATES_PROCESSOR:
-        activated_plugin_cls = loadable_plugins[
-            envs.VLLM_USE_HIDDEN_STATES_PROCESSOR]
+        plugin_name = envs.VLLM_USE_HIDDEN_STATES_PROCESSOR
+        if plugin_name not in loadable_plugins:
+            raise ValueError(
+                f"Hidden states processor plugin '{plugin_name}' not found. 
" + f"Available plugins: {list(loadable_plugins.keys())}") + + activated_plugin_cls = loadable_plugins[plugin_name] activated_plugin_name = envs.VLLM_USE_HIDDEN_STATES_PROCESSOR else: activated_plugin_name = list(loadable_plugins.keys())[0] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d1986f95c053..98997a7c006b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1383,6 +1383,8 @@ def _pool( prompt_logprobs_dict={}, pooler_output=pooler_output, hidden_states=return_hidden_states, + finished_sending=finished_sending, + finished_recving=finished_recving, ) @torch.inference_mode()