diff --git a/tests/core/test_core_tpu.py b/tests/core/test_core_tpu.py
index 82806f63b..a4b907109 100644
--- a/tests/core/test_core_tpu.py
+++ b/tests/core/test_core_tpu.py
@@ -143,11 +143,15 @@ def test_add_request(self):
         mock_engine_request.mm_inputs = []
         mock_engine_request.use_structured_output = False
         mock_engine_request.kv_transfer_params = None
+        mock_engine_request.pooling_params = None
+        mock_engine_request.sampling_params.guided_decoding = None

         # Mock the prefill engine's scheduler
         mock_prefill_scheduler = self.mock_prefill_engine_instance.scheduler

         # Call the method under test
+        mock_engine_request, _ = proc.preprocess_add_request(
+            mock_engine_request)
         proc.add_request(mock_engine_request)

         # Assert that the request was added to the first prefill engine's scheduler
diff --git a/tpu_commons/core/core_tpu.py b/tpu_commons/core/core_tpu.py
index 61a9246be..076c9ee77 100644
--- a/tpu_commons/core/core_tpu.py
+++ b/tpu_commons/core/core_tpu.py
@@ -13,6 +13,7 @@
 import jax
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.tasks import POOLING_TASKS
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType, UtilityOutput,
                             UtilityResult)
@@ -229,40 +230,34 @@ def _create_engine_cores(
         return engine_cores

-    def _add_request(self, request: EngineCoreRequest) -> Request:
-        if request.mm_hashes is not None:
-            # Here, if hash exists for a multimodal input, then it will be
-            # fetched from the cache, else it will be added to the cache.
-            # Note that the cache here is mirrored with the client cache, so
-            # anything that has a hash must have a HIT cache entry here
-            # as well.
-            assert request.mm_inputs is not None
-            request.mm_inputs = self._prefill_engines[
-                0].mm_input_cache_server.get_and_update_p1(
-                    request.mm_inputs, request.mm_hashes)
+    def add_request(self, request: EngineCoreRequest, request_wave: int = 0):
+        # vllm_request = self._add_request(request)

-        req = Request.from_engine_core_request(request)
-
-        if req.use_structured_output:
-            # Start grammar compilation asynchronously
-            self._prefill_engines[0].structured_output_manager.grammar_init(
-                req)
+        # TODO(fhzhang): support multiple prefill engines.
+        if not isinstance(request.request_id, str):
+            raise TypeError(
+                f"request_id must be a string, got {type(request.request_id)}")

-        return req
+        if pooling_params := request.pooling_params:
+            supported_pooling_tasks = [
+                task for task in self.get_supported_tasks()
+                if task in POOLING_TASKS
+            ]

-    def add_request(self, request: EngineCoreRequest):
-        vllm_request = self._add_request(request)
+            if pooling_params.task not in supported_pooling_tasks:
+                raise ValueError(f"Unsupported task: {pooling_params.task!r} "
+                                 f"Supported tasks: {supported_pooling_tasks}")

-        # TODO(fhzhang): support multiple prefill engines.
-        self._prefill_engines[0].scheduler.add_request(vllm_request)
-        self._requests[request.request_id] = vllm_request
+        self._prefill_engines[0].scheduler.add_request(request)
+        self._requests[request.request_id] = request

     def _handle_client_request(self, request_type: EngineCoreRequestType,
                                request: Any) -> None:
         """Dispatch request from client."""
         if request_type == EngineCoreRequestType.ADD:
-            self.add_request(request)
+            req, request_wave = request
+            self.add_request(req)
         elif request_type == EngineCoreRequestType.ABORT:
             # TODO(fhzhang): we need to keep track of which engine is processing
             # the request and finish it there.
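
Reviewer note: a minimal sketch of the submission flow this change implies, not code from the patch. It assumes preprocess_add_request returns a (request, extra) pair as exercised in the updated test, and that add_request now does its own validation; submit is a hypothetical helper name.

def submit(proc, engine_core_request):
    # `submit` is a hypothetical wrapper; `proc` stands for the engine-core
    # object the test drives.
    # Step 1: preprocess first. Per the updated test, preprocess_add_request
    # returns the (possibly rewritten) request plus a second value that the
    # test discards.
    request, _ = proc.preprocess_add_request(engine_core_request)

    # Step 2: add_request() validates before scheduling. A non-string
    # request_id raises TypeError; a pooling request whose task is not
    # reported by get_supported_tasks() raises ValueError.
    proc.add_request(request)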
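
Likewise, since _handle_client_request now unpacks the ADD payload as a 2-tuple, any client feeding the engine must send (request, request_wave). A hedged sketch, assuming the input queue carries (request_type, payload) pairs; enqueue_add and input_queue are illustrative names:

from vllm.v1.engine import EngineCoreRequestType

def enqueue_add(input_queue, engine_core_request, request_wave=0):
    # The ADD payload is now a (request, wave) tuple; the handler unpacks
    # it and currently discards the wave.
    input_queue.put_nowait(
        (EngineCoreRequestType.ADD, (engine_core_request, request_wave)))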