Skip to content

Commit

Permalink
[Core] Rename PromptInputs and inputs(#8673)
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 committed Sep 24, 2024
1 parent 2529d09 commit 6481cf3
Show file tree
Hide file tree
Showing 18 changed files with 148 additions and 153 deletions.
8 changes: 4 additions & 4 deletions benchmarks/benchmark_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
from vllm.inputs import PromptInputs
from vllm.inputs import PromptType
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser

Expand Down Expand Up @@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_inputs: List[PromptInputs] = [{
dummy_prompts: List[PromptType] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]

Expand All @@ -74,13 +74,13 @@ def run_to_completion(profile_dir: Optional[str] = None):
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p:
llm.generate(dummy_inputs,
llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
print(p.key_averages())
else:
start_time = time.perf_counter()
llm.generate(dummy_inputs,
llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()
Expand Down
2 changes: 1 addition & 1 deletion docs/source/dev/multimodal/multimodal_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Multi-Modality
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.

Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`.

Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
by following :ref:`this guide <adding_multimodal_plugin>`.
Expand Down
2 changes: 1 addition & 1 deletion docs/source/dev/offline_inference/llm_inputs.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
LLM Inputs
==========

.. autodata:: vllm.inputs.PromptInputs
.. autodata:: vllm.inputs.PromptType

.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:
Expand Down
2 changes: 1 addition & 1 deletion docs/source/models/vlm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.

To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:
To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`:

* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
Expand Down
12 changes: 6 additions & 6 deletions tests/mq_llm_engine/test_error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ async def test_evil_forward(tmp_socket):

# Throws an error in first forward pass.
with pytest.raises(RAISED_ERROR):
async for _ in client.generate(inputs="Hello my name is",
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
assert client.errored

# Engine is errored, should get ENGINE_DEAD_ERROR.
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(inputs="Hello my name is",
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
Expand Down Expand Up @@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket):

# Generate call should throw ENGINE_DEAD_ERROR
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(inputs="Hello my name is",
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
Expand Down Expand Up @@ -165,7 +165,7 @@ async def bad_abort_after_2s():
# with reference to the original KeyError("foo")
with pytest.raises(MQEngineDeadError) as execinfo:
async for _ in client.generate(
inputs="Hello my name is",
prompt="Hello my name is",
sampling_params=SamplingParams(max_tokens=2000),
request_id=uuid.uuid4()):
pass
Expand All @@ -190,7 +190,7 @@ async def test_bad_request(tmp_socket):

# Invalid request should fail, but not crash the server.
with pytest.raises(ValueError):
async for _ in client.generate(inputs="Hello my name is",
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-1",
lora_request=LoRARequest(
Expand All @@ -199,7 +199,7 @@ async def test_bad_request(tmp_socket):
pass

# This request should be okay.
async for _ in client.generate(inputs="Hello my name is",
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-2"):
pass
Expand Down
2 changes: 1 addition & 1 deletion tests/mq_llm_engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def generate(
count = 0
async for out in client.generate(
request_id=request_id,
inputs="Hello my name is Robert and",
prompt="Hello my name is Robert and",
sampling_params=SamplingParams(max_tokens=num_tokens,
temperature=0)):

Expand Down
4 changes: 2 additions & 2 deletions vllm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput)
Expand All @@ -19,7 +19,7 @@
"__version_tuple__",
"LLM",
"ModelRegistry",
"PromptInputs",
"PromptType",
"TextPrompt",
"TokensPrompt",
"SamplingParams",
Expand Down
24 changes: 11 additions & 13 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptInputs
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
Expand Down Expand Up @@ -405,7 +405,7 @@ async def stop_remote_worker_execution_loop_async(self) -> None:
async def add_request_async(
self,
request_id: str,
inputs: PromptInputs,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
Expand All @@ -420,7 +420,7 @@ async def add_request_async(
arrival_time = time.time()

preprocessed_inputs = await self.input_preprocessor.preprocess_async(
inputs,
prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
Expand Down Expand Up @@ -777,7 +777,7 @@ async def run_engine_loop(engine_ref: ReferenceType):
async def add_request(
self,
request_id: str,
inputs: PromptInputs,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
Expand All @@ -797,7 +797,7 @@ async def add_request(
stream = self._request_tracker.add_request(
request_id,
verbose=self.log_requests,
inputs=inputs,
prompt=prompt,
params=params,
arrival_time=arrival_time or time.time(),
lora_request=lora_request,
Expand All @@ -808,7 +808,7 @@ async def add_request(

async def generate(
self,
inputs: PromptInputs,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
Expand All @@ -822,8 +822,7 @@ async def generate(
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
Expand Down Expand Up @@ -881,7 +880,7 @@ async def generate(
"""
async for output in await self.add_request(
request_id,
inputs,
prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
Expand All @@ -891,7 +890,7 @@ async def generate(

async def encode(
self,
inputs: PromptInputs,
prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
Expand All @@ -904,8 +903,7 @@ async def encode(
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
Expand Down Expand Up @@ -959,7 +957,7 @@ async def encode(
"""
async for output in await self.add_request(
request_id,
inputs,
prompt,
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,
Expand Down
9 changes: 4 additions & 5 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
InputRegistry, LLMInputs, PromptInputs)
InputRegistry, LLMInputs, PromptType)
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
Expand Down Expand Up @@ -689,7 +689,7 @@ def stop_remote_worker_execution_loop(self) -> None:
def add_request(
self,
request_id: str,
inputs: PromptInputs,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
Expand All @@ -704,8 +704,7 @@ def add_request(
Args:
request_id: The unique ID of the request.
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
params: Parameters for sampling or pooling.
:class:`~vllm.SamplingParams` for text generation.
Expand Down Expand Up @@ -745,7 +744,7 @@ def add_request(
arrival_time = time.time()

preprocessed_inputs = self.input_preprocessor.preprocess(
inputs,
prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
Expand Down
4 changes: 2 additions & 2 deletions vllm/engine/multiprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import List, Mapping, Optional, Union

from vllm import PoolingParams
from vllm.inputs import PromptInputs
from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
Expand All @@ -23,7 +23,7 @@ class MQEngineDeadError(RuntimeError):

@dataclass
class RPCProcessRequest:
inputs: PromptInputs
prompt: PromptType
params: Union[SamplingParams, PoolingParams]
request_id: str
lora_request: Optional[LoRARequest] = None
Expand Down
20 changes: 9 additions & 11 deletions vllm/engine/multiprocessing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
RPCStartupResponse)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptInputs
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
Expand Down Expand Up @@ -375,7 +375,7 @@ def dead_error(self) -> BaseException:

def generate(
self,
inputs: PromptInputs,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
Expand All @@ -389,8 +389,7 @@ def generate(
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
Expand All @@ -399,13 +398,13 @@ def generate(
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
"""
return self._process_request(inputs, sampling_params, request_id,
return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
prompt_adapter_request)

def encode(
self,
inputs: PromptInputs,
prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
Expand All @@ -418,8 +417,7 @@ def encode(
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
Expand All @@ -430,12 +428,12 @@ def encode(
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
"""
return self._process_request(inputs, pooling_params, request_id,
return self._process_request(prompt, pooling_params, request_id,
lora_request, trace_headers)

async def _process_request(
self,
inputs: PromptInputs,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
Expand Down Expand Up @@ -468,7 +466,7 @@ async def _process_request(

request_bytes = pickle.dumps(
RPCProcessRequest(
inputs=inputs,
prompt=prompt,
params=params,
request_id=request_id,
lora_request=lora_request,
Expand Down
2 changes: 1 addition & 1 deletion vllm/engine/multiprocessing/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def _handle_process_request(self, request: RPCProcessRequest):
try:
self.engine.add_request(
request_id=request_id,
inputs=request.inputs,
prompt=request.prompt,
params=request.params,
lora_request=request.lora_request,
trace_headers=request.trace_headers,
Expand Down
Loading

0 comments on commit 6481cf3

Please sign in to comment.