[Core] Rename PromptInputs to PromptType, and inputs to prompt #8673

Merged · 2 commits · Sep 21, 2024
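For orientation, a minimal sketch of what this rename means for caller code. The model name below is a placeholder and not part of this PR; only the `PromptType` name and the prompt argument come from the change.

```python
from typing import List

from vllm import LLM, SamplingParams
from vllm.inputs import PromptType  # previously: from vllm.inputs import PromptInputs

# Placeholder model; any model supported by vLLM illustrates the rename equally well.
llm = LLM(model="facebook/opt-125m")

# A PromptType value may be a plain string or a TextPrompt/TokensPrompt dict.
prompts: List[PromptType] = [{"prompt_token_ids": [1, 2, 3]}]

outputs = llm.generate(prompts, sampling_params=SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)
```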
8 changes: 4 additions & 4 deletions benchmarks/benchmark_latency.py
@@ -11,7 +11,7 @@

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
-from vllm.inputs import PromptInputs
+from vllm.inputs import PromptType
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser

@@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
-dummy_inputs: List[PromptInputs] = [{
+dummy_prompts: List[PromptType] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]

@@ -74,13 +74,13 @@ def run_to_completion(profile_dir: Optional[str] = None):
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p:
-llm.generate(dummy_inputs,
+llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
print(p.key_averages())
else:
start_time = time.perf_counter()
-llm.generate(dummy_inputs,
+llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()
2 changes: 1 addition & 1 deletion docs/source/dev/multimodal/multimodal_index.rst
@@ -8,7 +8,7 @@ Multi-Modality
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.

Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
-via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.
+via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`.

Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
by following :ref:`this guide <adding_multimodal_plugin>`.
2 changes: 1 addition & 1 deletion docs/source/dev/offline_inference/llm_inputs.rst
@@ -1,7 +1,7 @@
LLM Inputs
==========

-.. autodata:: vllm.inputs.PromptInputs
+.. autodata:: vllm.inputs.PromptType

.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:
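As a quick illustration of the two prompt classes documented above, a small sketch (not part of the doc change itself):

```python
from vllm.inputs import PromptType, TextPrompt, TokensPrompt

# Both values below are valid instances of PromptType (the renamed PromptInputs).
text_prompt: PromptType = TextPrompt(prompt="Hello my name is")
tokens_prompt: PromptType = TokensPrompt(prompt_token_ids=[1, 2, 3])
```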
2 changes: 1 addition & 1 deletion docs/source/models/vlm.rst
@@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.

-To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:
+To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`:

* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
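To make the two fields concrete, a hedged sketch of a single image request; the model name and prompt template are typical examples rather than anything prescribed by this diff:

```python
from PIL import Image

from vllm import LLM, SamplingParams

# Example vision-language model; substitute any supported VLM.
llm = LLM(model="llava-hf/llava-1.5-7b-hf")

outputs = llm.generate(
    {
        # ``prompt`` follows the format documented on HuggingFace for the model.
        "prompt": "USER: <image>\nWhat is in this picture? ASSISTANT:",
        # ``multi_modal_data`` follows the vllm.multimodal.MultiModalDataDict schema.
        "multi_modal_data": {"image": Image.open("example.jpg")},
    },
    sampling_params=SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```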
12 changes: 6 additions & 6 deletions tests/mq_llm_engine/test_error_handling.py
@@ -61,15 +61,15 @@ async def test_evil_forward(tmp_socket):

# Throws an error in first forward pass.
with pytest.raises(RAISED_ERROR):
-async for _ in client.generate(inputs="Hello my name is",
+async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
assert client.errored

# Engine is errored, should get ENGINE_DEAD_ERROR.
with pytest.raises(MQEngineDeadError):
-async for _ in client.generate(inputs="Hello my name is",
+async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
@@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket):

# Generate call should throw ENGINE_DEAD_ERROR
with pytest.raises(MQEngineDeadError):
-async for _ in client.generate(inputs="Hello my name is",
+async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
@@ -165,7 +165,7 @@ async def bad_abort_after_2s():
# with reference to the original KeyError("foo")
with pytest.raises(MQEngineDeadError) as execinfo:
async for _ in client.generate(
-inputs="Hello my name is",
+prompt="Hello my name is",
sampling_params=SamplingParams(max_tokens=2000),
request_id=uuid.uuid4()):
pass
@@ -190,7 +190,7 @@ async def test_bad_request(tmp_socket):

# Invalid request should fail, but not crash the server.
with pytest.raises(ValueError):
-async for _ in client.generate(inputs="Hello my name is",
+async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-1",
lora_request=LoRARequest(
@@ -199,7 +199,7 @@ async def test_bad_request(tmp_socket):
pass

# This request should be okay.
-async for _ in client.generate(inputs="Hello my name is",
+async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-2"):
pass
2 changes: 1 addition & 1 deletion tests/mq_llm_engine/utils.py
@@ -20,7 +20,7 @@ async def generate(
count = 0
async for out in client.generate(
request_id=request_id,
-inputs="Hello my name is Robert and",
+prompt="Hello my name is Robert and",
sampling_params=SamplingParams(max_tokens=num_tokens,
temperature=0)):

4 changes: 2 additions & 2 deletions vllm/__init__.py
@@ -5,7 +5,7 @@
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
+from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput)
@@ -19,7 +19,7 @@
"__version__",
"LLM",
"ModelRegistry",
"PromptInputs",
"PromptType",
"TextPrompt",
"TokensPrompt",
"SamplingParams",
24 changes: 11 additions & 13 deletions vllm/engine/async_llm_engine.py
@@ -17,7 +17,7 @@
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptInputs
+from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
@@ -405,7 +405,7 @@ async def stop_remote_worker_execution_loop_async(self) -> None:
async def add_request_async(
self,
request_id: str,
-inputs: PromptInputs,
+prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
@@ -420,7 +420,7 @@ async def add_request_async(
arrival_time = time.time()

preprocessed_inputs = await self.input_preprocessor.preprocess_async(
-inputs,
+prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
@@ -777,7 +777,7 @@ async def run_engine_loop(engine_ref: ReferenceType):
async def add_request(
self,
request_id: str,
-inputs: PromptInputs,
+prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
@@ -797,7 +797,7 @@ async def add_request(
stream = self._request_tracker.add_request(
request_id,
verbose=self.log_requests,
-inputs=inputs,
+prompt=prompt,
params=params,
arrival_time=arrival_time or time.time(),
lora_request=lora_request,
@@ -808,7 +808,7 @@

async def generate(
self,
-inputs: PromptInputs,
+prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
@@ -822,8 +822,7 @@ async def generate(
from the LLMEngine to the caller.

Args:
-inputs: The inputs to the LLM. See
-:class:`~vllm.inputs.PromptInputs`
+prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
@@ -881,7 +880,7 @@ async def generate(
"""
async for output in await self.add_request(
request_id,
-inputs,
+prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
@@ -891,7 +890,7 @@

async def encode(
self,
-inputs: PromptInputs,
+prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
@@ -904,8 +903,7 @@ async def encode(
from the LLMEngine to the caller.

Args:
-inputs: The inputs to the LLM. See
-:class:`~vllm.inputs.PromptInputs`
+prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
@@ -959,7 +957,7 @@ async def encode(
"""
async for output in await self.add_request(
request_id,
-inputs,
+prompt,
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,
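For reference, a hedged sketch of calling the renamed keyword on `AsyncLLMEngine.generate`; engine construction details and the model name are assumptions, not part of this diff:

```python
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def main() -> None:
    # Placeholder engine arguments; any supported model works the same way.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))

    final_output = None
    # The first argument is now ``prompt`` (a PromptType); it used to be ``inputs``.
    async for request_output in engine.generate(
            prompt="Hello my name is",
            sampling_params=SamplingParams(max_tokens=16),
            request_id="request-0"):
        final_output = request_output

    if final_output is not None:
        print(final_output.outputs[0].text)


asyncio.run(main())
```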
9 changes: 4 additions & 5 deletions vllm/engine/llm_engine.py
@@ -29,7 +29,7 @@
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
-InputRegistry, LLMInputs, PromptInputs)
+InputRegistry, LLMInputs, PromptType)
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -680,7 +680,7 @@ def stop_remote_worker_execution_loop(self) -> None:
def add_request(
self,
request_id: str,
-inputs: PromptInputs,
+prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
@@ -695,8 +695,7 @@ def add_request(

Args:
request_id: The unique ID of the request.
-inputs: The inputs to the LLM. See
-:class:`~vllm.inputs.PromptInputs`
+prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
params: Parameters for sampling or pooling.
:class:`~vllm.SamplingParams` for text generation.
@@ -736,7 +735,7 @@ def add_request(
arrival_time = time.time()

preprocessed_inputs = self.input_preprocessor.preprocess(
-inputs,
+prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
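Similarly, a hedged sketch of driving `LLMEngine.add_request` directly with the renamed argument; the model name and engine construction are placeholders:

```python
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine

# Placeholder model; construction details are not part of this diff.
engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

# ``prompt`` is the renamed keyword (previously ``inputs``); ``params`` is unchanged.
engine.add_request(request_id="request-0",
                   prompt="Hello my name is",
                   params=SamplingParams(max_tokens=16))

while engine.has_unfinished_requests():
    for request_output in engine.step():
        if request_output.finished:
            print(request_output.outputs[0].text)
```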
4 changes: 2 additions & 2 deletions vllm/engine/multiprocessing/__init__.py
@@ -3,7 +3,7 @@
from typing import List, Mapping, Optional, Union

from vllm import PoolingParams
-from vllm.inputs import PromptInputs
+from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -23,7 +23,7 @@ class MQEngineDeadError(RuntimeError):

@dataclass
class RPCProcessRequest:
-inputs: PromptInputs
+prompt: PromptType
params: Union[SamplingParams, PoolingParams]
request_id: str
lora_request: Optional[LoRARequest] = None
20 changes: 9 additions & 11 deletions vllm/engine/multiprocessing/client.py
@@ -25,7 +25,7 @@
RPCStartupResponse)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
-from vllm.inputs import PromptInputs
+from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
@@ -375,7 +375,7 @@ def dead_error(self) -> BaseException:

def generate(
self,
-inputs: PromptInputs,
+prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
@@ -389,8 +389,7 @@ def generate(
from the LLMEngine to the caller.

Args:
-inputs: The inputs to the LLM. See
-:class:`~vllm.inputs.PromptInputs`
+prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
Expand All @@ -399,13 +398,13 @@ def generate(
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
"""
-return self._process_request(inputs, sampling_params, request_id,
+return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
prompt_adapter_request)

def encode(
self,
-inputs: PromptInputs,
+prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
@@ -418,8 +417,7 @@ def encode(
from the LLMEngine to the caller.

Args:
-inputs: The inputs to the LLM. See
-:class:`~vllm.inputs.PromptInputs`
+prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
Expand All @@ -430,12 +428,12 @@ def encode(
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
"""
-return self._process_request(inputs, pooling_params, request_id,
+return self._process_request(prompt, pooling_params, request_id,
lora_request, trace_headers)

async def _process_request(
self,
-inputs: PromptInputs,
+prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
@@ -468,7 +466,7 @@ async def _process_request(

request_bytes = pickle.dumps(
RPCProcessRequest(
-inputs=inputs,
+prompt=prompt,
params=params,
request_id=request_id,
lora_request=lora_request,
2 changes: 1 addition & 1 deletion vllm/engine/multiprocessing/engine.py
@@ -245,7 +245,7 @@ def _handle_process_request(self, request: RPCProcessRequest):
try:
self.engine.add_request(
request_id=request_id,
-inputs=request.inputs,
+prompt=request.prompt,
params=request.params,
lora_request=request.lora_request,
trace_headers=request.trace_headers,