2 changes: 1 addition & 1 deletion examples/offline_inference_embedding.py
@@ -10,7 +10,7 @@

 # Create an LLM.
 model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
-# Generate embedding. The output is a list of EmbeddingRequestOutputs.
+# Generate embedding. The output is a list of PoolingRequestOutputs.
 outputs = model.encode(prompts)
 # Print the outputs.
 for output in outputs:
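For reference, a minimal end-to-end sketch of the renamed API (not part of the diff). The prompt strings are illustrative; only `LLM`, `encode`, `PoolingRequestOutput`, and the `.outputs` field are confirmed by the changes in this PR:

from vllm import LLM, PoolingRequestOutput

llm = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
# encode() returns one PoolingRequestOutput per prompt.
outputs = llm.encode(["Hello, my name is", "The capital of France is"])
for output in outputs:
    assert isinstance(output, PoolingRequestOutput)
    print(output.outputs)  # the pooled result for this prompt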
6 changes: 3 additions & 3 deletions tests/entrypoints/llm/test_encode.py
@@ -3,7 +3,7 @@

 import pytest

-from vllm import LLM, EmbeddingRequestOutput, PoolingParams
+from vllm import LLM, PoolingParams, PoolingRequestOutput
 from vllm.distributed import cleanup_dist_env_and_memory

 MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
@@ -43,8 +43,8 @@ def llm():
     cleanup_dist_env_and_memory()


-def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
-                         o2: List[EmbeddingRequestOutput]):
+def assert_outputs_equal(o1: List[PoolingRequestOutput],
+                         o2: List[PoolingRequestOutput]):
     assert [o.outputs for o in o1] == [o.outputs for o in o2]

4 changes: 2 additions & 2 deletions tests/models/test_registry.py
@@ -3,7 +3,7 @@
 import pytest
 import torch.cuda

-from vllm.model_executor.models import (is_embedding_model,
+from vllm.model_executor.models import (is_pooling_model,
                                         is_text_generation_model,
                                         supports_multimodal)
 from vllm.model_executor.models.adapters import as_embedding_model
@@ -31,7 +31,7 @@ def test_registry_imports(model_arch):

     # All vLLM models should be convertible to an embedding model
     embed_model = as_embedding_model(model_cls)
-    assert is_embedding_model(embed_model)
+    assert is_pooling_model(embed_model)

     if model_arch in _MULTIMODAL_MODELS:
         assert supports_multimodal(model_cls)
4 changes: 2 additions & 2 deletions tests/worker/test_model_input.py
@@ -8,10 +8,10 @@
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.worker.embedding_model_runner import (
-    ModelInputForGPUWithPoolingMetadata)
 from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
 from vllm.worker.multi_step_model_runner import StatefulModelInput
+from vllm.worker.pooling_model_runner import (
+    ModelInputForGPUWithPoolingMetadata)


 class MockAttentionBackend(AttentionBackend):
31 changes: 27 additions & 4 deletions vllm/__init__.py
@@ -7,8 +7,8 @@
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import PromptType, TextPrompt, TokensPrompt
 from vllm.model_executor.models import ModelRegistry
-from vllm.outputs import (CompletionOutput, EmbeddingOutput,
-                          EmbeddingRequestOutput, RequestOutput)
+from vllm.outputs import (CompletionOutput, PoolingOutput,
+                          PoolingRequestOutput, RequestOutput)
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams

@@ -25,12 +25,35 @@
     "SamplingParams",
     "RequestOutput",
     "CompletionOutput",
-    "EmbeddingOutput",
-    "EmbeddingRequestOutput",
+    "PoolingOutput",
+    "PoolingRequestOutput",
     "LLMEngine",
     "EngineArgs",
     "AsyncLLMEngine",
     "AsyncEngineArgs",
     "initialize_ray_cluster",
     "PoolingParams",
 ]
+
+
+def __getattr__(name: str):
+    import warnings
+
+    if name == "EmbeddingOutput":
+        msg = ("EmbeddingOutput has been renamed to PoolingOutput. "
+               "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return PoolingOutput
+
+    if name == "EmbeddingRequestOutput":
+        msg = ("EmbeddingRequestOutput has been renamed to "
+               "PoolingRequestOutput. "
+               "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return PoolingRequestOutput
+
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -359,7 +359,7 @@ def _resolve_task(
             # NOTE: Listed from highest to lowest priority,
             # in case the model supports multiple of them
             "generate": ModelRegistry.is_text_generation_model(architectures),
-            "embedding": ModelRegistry.is_embedding_model(architectures),
+            "embedding": ModelRegistry.is_pooling_model(architectures),
         }
         supported_tasks_lst: List[_Task] = [
             task for task, is_supported in task_support.items() if is_supported
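Callers that previously probed ModelRegistry.is_embedding_model therefore switch to the new name. A hedged sketch of querying task support after this change ("LlamaForCausalLM" is just an illustrative architecture string, not something this PR touches):

from vllm import ModelRegistry

architectures = ["LlamaForCausalLM"]
task_support = {
    # Ordered from highest to lowest priority, mirroring _resolve_task above.
    "generate": ModelRegistry.is_text_generation_model(architectures),
    "embedding": ModelRegistry.is_pooling_model(architectures),
}
supported = [task for task, ok in task_support.items() if ok]
print(supported)  # e.g. ['generate'] for a decoder-only text-generation model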
24 changes: 12 additions & 12 deletions vllm/engine/async_llm_engine.py
@@ -25,7 +25,7 @@
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
@@ -74,7 +74,7 @@ def _log_task_completion(task: asyncio.Task,


 class AsyncStream:
-    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
+    """A stream of RequestOutputs or PoolingRequestOutputs for a request
     that can be iterated over asynchronously via an async generator."""

     def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
@@ -83,7 +83,7 @@ def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
         self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False

-    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
+    def put(self, item: Union[RequestOutput, PoolingRequestOutput,
                               Exception]) -> None:
         if not self._finished:
             self._queue.put_nowait(item)
@@ -103,7 +103,7 @@ def finished(self) -> bool:

     async def generator(
         self
-    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
         try:
             while True:
                 result = await self._queue.get()
@@ -154,7 +154,7 @@ def propagate_exception(self,

     def process_request_output(self,
                                request_output: Union[RequestOutput,
-                                                     EmbeddingRequestOutput],
+                                                     PoolingRequestOutput],
                                *,
                                verbose: bool = False) -> None:
         """Process a request output from the engine."""
@@ -265,7 +265,7 @@ def __init__(self, *args, **kwargs):

     async def step_async(
         self, virtual_engine: int
-    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
+    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.
         The workers are ran asynchronously if possible.

@@ -907,7 +907,7 @@ def add_request(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
-            RequestOutput, EmbeddingRequestOutput], None]]:
+            RequestOutput, PoolingRequestOutput], None]]:
         ...

     @overload
@@ -922,7 +922,7 @@ def add_request(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
-            RequestOutput, EmbeddingRequestOutput], None]]:
+            RequestOutput, PoolingRequestOutput], None]]:
         ...

     @deprecate_kwargs(
@@ -941,7 +941,7 @@ async def add_request(
         priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
-    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
         if inputs is not None:
             prompt = inputs
         assert prompt is not None and params is not None
@@ -1070,7 +1070,7 @@ async def encode(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from an embedding model.

         Generate outputs for a request. This method is a coroutine. It adds the
@@ -1088,7 +1088,7 @@ async def encode(
             Only applicable with priority scheduling.

         Yields:
-            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            The output `PoolingRequestOutput` objects from the LLMEngine
             for the request.

         Details:
@@ -1141,7 +1141,7 @@ async def encode(
             trace_headers=trace_headers,
             priority=priority,
         ):
-            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
+            yield LLMEngine.validate_output(output, PoolingRequestOutput)

     async def abort(self, request_id: str) -> None:
         """Abort a request.
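Since encode() now yields PoolingRequestOutput objects, a caller of the async engine looks roughly like the following. This is a hedged sketch, not code from the PR: the engine construction arguments, prompt, and request id are assumptions, and only the encode() parameters and yielded type are taken from the diff.

import asyncio

from vllm import AsyncEngineArgs, AsyncLLMEngine, PoolingParams


async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="intfloat/e5-mistral-7b-instruct",
                        enforce_eager=True))
    # encode() is an async generator; for pooling models it yields
    # PoolingRequestOutput objects for the given request.
    async for output in engine.encode("Hello, my name is",
                                      PoolingParams(),
                                      request_id="embed-0"):
        print(type(output).__name__)  # PoolingRequestOutput


asyncio.run(main())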
8 changes: 4 additions & 4 deletions vllm/engine/llm_engine.py
@@ -40,7 +40,7 @@
     get_local_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
-from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
+from vllm.outputs import (PoolingRequestOutput, RequestOutput,
                           RequestOutputFactory)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -80,7 +80,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:


 _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
-_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
+_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)


 @dataclass
@@ -112,7 +112,7 @@ class SchedulerContext:
     def __init__(self, multi_step_stream_outputs: bool = False):
         self.output_queue: Deque[OutputData] = deque()
         self.request_outputs: List[Union[RequestOutput,
-                                         EmbeddingRequestOutput]] = []
+                                         PoolingRequestOutput]] = []
         self.seq_group_metadata_list: Optional[
             List[SequenceGroupMetadata]] = None
         self.scheduler_outputs: Optional[SchedulerOutputs] = None
@@ -1314,7 +1314,7 @@ def _advance_to_next_step(
             else:
                 seq.append_token_id(sample.output_token, sample.logprobs)

-    def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
+    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.

         .. figure:: https://i.imgur.com/sv2HssD.png
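step() is the core of the synchronous engine loop, so downstream drivers iterate it until no requests remain. A hedged sketch of that loop, with the model name, request id, and prompt chosen purely for illustration:

from vllm import EngineArgs, LLMEngine, SamplingParams

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
engine.add_request("request-0", "Hello, my name is",
                   SamplingParams(max_tokens=16))

while engine.has_unfinished_requests():
    # Each step() call returns the newly updated outputs; for pooling models
    # the list would contain PoolingRequestOutput instead of RequestOutput.
    for request_output in engine.step():
        if request_output.finished:
            print(request_output)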
14 changes: 7 additions & 7 deletions vllm/engine/multiprocessing/client.py
@@ -35,7 +35,7 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
@@ -495,7 +495,7 @@ def encode(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         ...

     @overload
@@ -507,7 +507,7 @@ def encode(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         ...

     @deprecate_kwargs(
@@ -524,7 +524,7 @@ def encode(
         priority: int = 0,
         *,
         inputs: Optional[PromptType] = None  # DEPRECATED
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from an embedding model.

         Generate outputs for a request. This method is a coroutine. It adds the
@@ -540,7 +540,7 @@ def encode(
             trace_headers: OpenTelemetry trace headers.

         Yields:
-            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            The output `PoolingRequestOutput` objects from the LLMEngine
             for the request.
         """
         if inputs is not None:
@@ -549,7 +549,7 @@ def encode(
                 and request_id is not None)

         return cast(
-            AsyncGenerator[EmbeddingRequestOutput, None],
+            AsyncGenerator[PoolingRequestOutput, None],
             self._process_request(prompt,
                                   pooling_params,
                                   request_id,
@@ -567,7 +567,7 @@ async def _process_request(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
-            EmbeddingRequestOutput, None]]:
+            PoolingRequestOutput, None]]:
         """Send an RPCGenerateRequest to the RPCServer and stream responses."""

         # If already dead, error out.
5 changes: 2 additions & 3 deletions vllm/engine/protocol.py
@@ -11,8 +11,7 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
-                          RequestOutput)
+from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import BeamSearchParams, SamplingParams
@@ -209,7 +208,7 @@ def encode(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from an embedding model."""
         ...
