diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml
index 99afe503ba..4ecc7dc551 100644
--- a/.github/workflows/vllm_ascend_test.yaml
+++ b/.github/workflows/vllm_ascend_test.yaml
@@ -136,12 +136,14 @@ jobs:
             pytest -sv tests/singlecard/test_camem.py
             # test_ascend_config.py should be ran separately because it will regenerate the global config many times.
             pytest -sv tests/singlecard/test_ascend_config.py
+            pytest -sv tests/singlecard/test_prompt_embedding.py
             pytest -sv tests/singlecard/ \
               --ignore=tests/singlecard/test_offline_inference.py \
               --ignore=tests/singlecard/test_scheduler.py \
               --ignore=tests/singlecard/test_guided_decoding.py \
               --ignore=tests/singlecard/test_camem.py \
-              --ignore=tests/singlecard/test_ascend_config.py
+              --ignore=tests/singlecard/test_ascend_config.py \
+              --ignore=tests/singlecard/test_prompt_embedding.py
           else
             pytest -sv tests/multicard/test_ilama_lora_tp2.py
             # Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py will raise error.
diff --git a/examples/prompt_embedding_inference.py b/examples/prompt_embedding_inference.py
new file mode 100644
index 0000000000..e375a8b4f9
--- /dev/null
+++ b/examples/prompt_embedding_inference.py
@@ -0,0 +1,83 @@
+import torch
+from transformers import (AutoModelForCausalLM, AutoTokenizer,
+                          PreTrainedTokenizer)
+
+from vllm import LLM
+
+
+def init_tokenizer_and_llm(model_name: str):
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    transformers_model = AutoModelForCausalLM.from_pretrained(model_name)
+    embedding_layer = transformers_model.get_input_embeddings()
+    llm = LLM(model=model_name, enable_prompt_embeds=True)
+    return tokenizer, embedding_layer, llm
+
+
+def get_prompt_embeds(chat: list[dict[str, str]],
+                      tokenizer: PreTrainedTokenizer,
+                      embedding_layer: torch.nn.Module):
+    token_ids = tokenizer.apply_chat_template(chat,
+                                              add_generation_prompt=True,
+                                              return_tensors='pt')
+    prompt_embeds = embedding_layer(token_ids).squeeze(0)
+    return prompt_embeds
+
+
+def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
+                            embedding_layer: torch.nn.Module):
+    chat = [{
+        "role": "user",
+        "content": "Please tell me about the capital of France."
+    }]
+    prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)
+
+    outputs = llm.generate({
+        "prompt_embeds": prompt_embeds,
+    })
+
+    print("\n[Single Inference Output]")
+    print("-" * 30)
+    for o in outputs:
+        print(o.outputs[0].text)
+    print("-" * 30)
+
+
+def batch_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
+                           embedding_layer: torch.nn.Module):
+    chats = [[{
+        "role": "user",
+        "content": "Please tell me about the capital of France."
+    }],
+             [{
+                 "role": "user",
+                 "content": "When is the day longest during the year?"
+             }],
+             [{
+                 "role": "user",
+                 "content": "Which is bigger, the moon or the sun?"
+             }]]
+
+    prompt_embeds_list = [
+        get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats
+    ]
+
+    outputs = llm.generate([{
+        "prompt_embeds": embeds
+    } for embeds in prompt_embeds_list])
+
+    print("\n[Batch Inference Outputs]")
+    print("-" * 30)
+    for i, o in enumerate(outputs):
+        print(f"Q{i+1}: {chats[i][0]['content']}")
+        print(f"A{i+1}: {o.outputs[0].text}\n")
+    print("-" * 30)
+
+
+def main():
+    model_name = "meta-llama/Llama-3.2-1B-Instruct"
+    tokenizer, embedding_layer, llm = init_tokenizer_and_llm(model_name)
+    single_prompt_inference(llm, tokenizer, embedding_layer)
+    batch_prompt_inference(llm, tokenizer, embedding_layer)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/singlecard/test_prompt_embedding.py b/tests/singlecard/test_prompt_embedding.py
new file mode 100644
index 0000000000..47538173a5
--- /dev/null
+++ b/tests/singlecard/test_prompt_embedding.py
@@ -0,0 +1,259 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm/tests/entrypoints/openai/test_completion_with_prompt_embeds.py
+#
+import base64
+import io
+import os
+
+import openai  # use the official client for correctness check
+import pytest
+import pytest_asyncio
+import torch
+from modelscope import snapshot_download  # type: ignore
+from openai import BadRequestError
+from transformers import AutoConfig
+from vllm.engine.arg_utils import EngineArgs
+
+from tests.utils import RemoteOpenAIServer
+
+if not hasattr(EngineArgs, "enable_prompt_embeds"):
+    pytest.skip("Not supported vllm version", allow_module_level=True)
+
+# any model with a chat template should work here
+MODEL_NAME = snapshot_download("LLM-Research/Llama-3.2-1B-Instruct")
+
+CONFIG = AutoConfig.from_pretrained(MODEL_NAME)
+
+
+@pytest.fixture(scope="module")
+def default_server_args() -> list[str]:
+    return [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "8192",
+        "--max-num-seqs",
+        "128",
+        "--enforce-eager",
+        # Prompt Embeds server args
+        "--enable-prompt-embeds",
+        "--no-enable-chunked-prefill",
+    ]
+
+
+@pytest.fixture(scope="module",
+                params=["", "--disable-frontend-multiprocessing"])
+def server_with_prompt_embeds(default_server_args, request):
+    if request.param:
+        default_server_args.append(request.param)
+
+    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client_with_prompt_embeds(server_with_prompt_embeds):
+    async with server_with_prompt_embeds.get_async_client() as async_client:
+        yield async_client
+
+
+def create_dummy_embeds(num_tokens: int = 5) -> str:
+    """Create dummy embeddings and return them as base64 encoded string."""
+    dummy_embeds = torch.randn(num_tokens, CONFIG.hidden_size)
+    buffer = io.BytesIO()
+    torch.save(dummy_embeds, buffer)
+    return base64.b64encode(buffer.getvalue()).decode('utf-8')
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.skipif(
+    os.getenv("VLLM_USE_V1") == "1",
+    reason="Enable embedding input will fallback to v0, skip it")
+async def test_completions_with_prompt_embeds(
+        client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
+    # Test case: Single prompt embeds input
+    encoded_embeds = create_dummy_embeds()
+    completion = await client_with_prompt_embeds.completions.create(
+        model=model_name,
+        prompt="",  # Add empty prompt as required parameter
+        max_tokens=5,
+        temperature=0.0,
+        extra_body={"prompt_embeds": encoded_embeds})
+    assert len(completion.choices[0].text) >= 1
+    assert completion.choices[0].prompt_logprobs is None
+
+    # Test case: batch completion with prompt_embeds
+    encoded_embeds2 = create_dummy_embeds()
+    completion = await client_with_prompt_embeds.completions.create(
+        model=model_name,
+        prompt="",  # Add empty prompt as required parameter
+        max_tokens=5,
+        temperature=0.0,
+        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
+    assert len(completion.choices) == 2
+    assert len(completion.choices[0].text) >= 1
+    assert len(completion.choices[1].text) >= 1
+
+    # Test case: streaming with prompt_embeds
+    encoded_embeds = create_dummy_embeds()
+    single_completion = await client_with_prompt_embeds.completions.create(
+        model=model_name,
+        prompt="",  # Add empty prompt as required parameter
+        max_tokens=5,
+        temperature=0.0,
+        extra_body={"prompt_embeds": encoded_embeds})
+    single_output = single_completion.choices[0].text
+
+    stream = await client_with_prompt_embeds.completions.create(
+        model=model_name,
+        prompt="",  # Add empty prompt as required parameter
+        max_tokens=5,
+        temperature=0.0,
+        stream=True,
+        extra_body={"prompt_embeds": encoded_embeds})
+    chunks = []
+    finish_reason_count = 0
+    async for chunk in stream:
+        chunks.append(chunk.choices[0].text)
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason_count += 1
+    assert finish_reason_count == 1
+    assert chunk.choices[0].finish_reason == "length"
+    assert chunk.choices[0].text
+    assert "".join(chunks) == single_output
+
+    # Test case: batch streaming with prompt_embeds
+    encoded_embeds2 = create_dummy_embeds()
+    stream = await client_with_prompt_embeds.completions.create(
+        model=model_name,
+        prompt="",  # Add empty prompt as required parameter
+        max_tokens=5,
+        temperature=0.0,
+        stream=True,
+        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
+    chunks_stream_embeds: list[list[str]] = [[], []]
+    finish_reason_count = 0
+    async for chunk in stream:
+        chunks_stream_embeds[chunk.choices[0].index].append(
+            chunk.choices[0].text)
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason_count += 1
+    assert finish_reason_count == 2
+    assert chunk.choices[0].finish_reason == "length"
+    assert chunk.choices[0].text
+    assert len(chunks_stream_embeds[0]) > 0
+    assert len(chunks_stream_embeds[1]) > 0
+
+    # Test case: mixed text and prompt_embeds
+    encoded_embeds = create_dummy_embeds()
+    completion_mixed = await client_with_prompt_embeds.completions.create(
+        model=model_name,
+        prompt="This is a prompt",
+        max_tokens=5,
+        temperature=0.0,
+        extra_body={"prompt_embeds": encoded_embeds})
+    assert len(completion_mixed.choices) == 2
+    completion_text_only = await client_with_prompt_embeds.completions.create(
+        model=model_name,
+        prompt="This is a prompt",
+        max_tokens=5,
+        temperature=0.0,
+    )
+    completion_embeds_only = await client_with_prompt_embeds.completions.create(
+        model=model_name,
+        prompt="",
+        max_tokens=5,
+        temperature=0.0,
+        extra_body={"prompt_embeds": encoded_embeds})
+    # Embeddings responses should be handled first
+    assert completion_mixed.choices[0].text == completion_embeds_only.choices[
+        0].text
+    assert completion_mixed.choices[1].text == completion_text_only.choices[
+        0].text
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.skipif(
+    os.getenv("VLLM_USE_V1") == "1",
+    reason="Enable embedding input will fallback to v0, skip it")
+async def test_completions_errors_with_prompt_embeds(
+        client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
+    # Test error case: invalid prompt_embeds
+    with pytest.raises(BadRequestError):
+        await client_with_prompt_embeds.completions.create(
+            prompt="",
+            model=model_name,
+            max_tokens=5,
+            temperature=0.0,
+            extra_body={"prompt_embeds": "invalid_base64"})
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("logprobs_arg", [1, 0])
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.skipif(
+    os.getenv("VLLM_USE_V1") == "1",
+    reason="Enable embedding input will fallback to v0, skip it")
+async def test_completions_with_logprobs_and_prompt_embeds(
+        client_with_prompt_embeds: openai.AsyncOpenAI, logprobs_arg: int,
+        model_name: str):
+    # Test case: Logprobs using prompt_embeds
+    encoded_embeds = create_dummy_embeds()
+    completion = await client_with_prompt_embeds.completions.create(
+        model=model_name,
+        prompt="",  # Add empty prompt as required parameter
+        max_tokens=5,
+        temperature=0.0,
+        echo=False,
+        logprobs=logprobs_arg,
+        extra_body={"prompt_embeds": encoded_embeds})
+
+    logprobs = completion.choices[0].logprobs
+    assert logprobs is not None
+    assert len(logprobs.text_offset) == 5
+    assert len(logprobs.token_logprobs) == 5
+    assert len(logprobs.top_logprobs) == 5
+    for top_logprobs in logprobs.top_logprobs[1:]:
+        assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1
+    assert len(logprobs.tokens) == 5
+
+    # Test case: Log probs with batch completion and prompt_embeds
+    encoded_embeds2 = create_dummy_embeds()
+    completion = await client_with_prompt_embeds.completions.create(
+        model=model_name,
+        prompt="",  # Add empty prompt as required parameter
+        max_tokens=5,
+        temperature=0.0,
+        echo=False,
+        logprobs=logprobs_arg,
+        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
+
+    assert len(completion.choices) == 2
+    for choice in completion.choices:
+        logprobs = choice.logprobs
+        assert logprobs is not None
+        assert len(logprobs.text_offset) == 5
+        assert len(logprobs.token_logprobs) == 5
+        assert len(logprobs.top_logprobs) == 5
+        for top_logprobs in logprobs.top_logprobs[1:]:
+            assert max(logprobs_arg,
+                       1) <= len(top_logprobs) <= logprobs_arg + 1
+        assert len(logprobs.tokens) == 5
diff --git a/tests/utils.py b/tests/utils.py
index f8b6f345a0..ced7d9a1b1 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,13 +20,143 @@
 import functools
 import os
 import signal
-from typing import Callable
+import subprocess
+import sys
+import time
+from typing import Callable, Optional
 
+import openai
+import requests
 from typing_extensions import ParamSpec
 
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.model_executor.model_loader import get_model_loader
+from vllm.utils import FlexibleArgumentParser, get_open_port
+
 _P = ParamSpec("_P")
 
 
+class RemoteOpenAIServer:
"token-abc123" # vLLM's OpenAI server does not need API key + + def __init__(self, + model: str, + vllm_serve_args: list[str], + *, + env_dict: Optional[dict[str, str]] = None, + seed: Optional[int] = 0, + auto_port: bool = True, + max_wait_seconds: Optional[float] = None) -> None: + if auto_port: + if "-p" in vllm_serve_args or "--port" in vllm_serve_args: + raise ValueError("You have manually specified the port " + "when `auto_port=True`.") + + # Don't mutate the input args + vllm_serve_args = vllm_serve_args + [ + "--port", str(get_open_port()) + ] + if seed is not None: + if "--seed" in vllm_serve_args: + raise ValueError("You have manually specified the seed " + f"when `seed={seed}`.") + + vllm_serve_args = vllm_serve_args + ["--seed", str(seed)] + + parser = FlexibleArgumentParser( + description="vLLM's remote OpenAI server.") + parser = make_arg_parser(parser) + args = parser.parse_args(["--model", model, *vllm_serve_args]) + self.host = str(args.host or 'localhost') + self.port = int(args.port) + + self.show_hidden_metrics = \ + args.show_hidden_metrics_for_version is not None + + # download the model before starting the server to avoid timeout + is_local = os.path.isdir(model) + if not is_local: + engine_args = AsyncEngineArgs.from_cli_args(args) + model_config = engine_args.create_model_config() + load_config = engine_args.create_load_config() + + model_loader = get_model_loader(load_config) + model_loader.download_model(model_config) + + env = os.environ.copy() + # the current process might initialize cuda, + # to be safe, we should use spawn method + env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + if env_dict is not None: + env.update(env_dict) + self.proc = subprocess.Popen( + ["vllm", "serve", model, *vllm_serve_args], + env=env, + stdout=sys.stdout, + stderr=sys.stderr, + ) + max_wait_seconds = max_wait_seconds or 240 + self._wait_for_server(url=self.url_for("health"), + timeout=max_wait_seconds) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.proc.terminate() + try: + self.proc.wait(8) + except subprocess.TimeoutExpired: + # force kill if needed + self.proc.kill() + + def _wait_for_server(self, *, url: str, timeout: float): + # run health check + start = time.time() + while True: + try: + if requests.get(url).status_code == 200: + break + except Exception: + # this exception can only be raised by requests.get, + # which means the server is not ready yet. + # the stack trace is not useful, so we suppress it + # by using `raise from None`. + result = self.proc.poll() + if result is not None and result != 0: + raise RuntimeError("Server exited unexpectedly.") from None + + time.sleep(0.5) + if time.time() - start > timeout: + raise RuntimeError( + "Server failed to start in time.") from None + + @property + def url_root(self) -> str: + return f"http://{self.host}:{self.port}" + + def url_for(self, *parts: str) -> str: + return self.url_root + "/" + "/".join(parts) + + def get_client(self, **kwargs): + if "timeout" not in kwargs: + kwargs["timeout"] = 600 + return openai.OpenAI( + base_url=self.url_for("v1"), + api_key=self.DUMMY_API_KEY, + max_retries=0, + **kwargs, + ) + + def get_async_client(self, **kwargs): + if "timeout" not in kwargs: + kwargs["timeout"] = 600 + return openai.AsyncOpenAI(base_url=self.url_for("v1"), + api_key=self.DUMMY_API_KEY, + max_retries=0, + **kwargs) + + def fork_new_process_for_each_test( f: Callable[_P, None]) -> Callable[_P, None]: """Decorator to fork a new process for each test function. 
diff --git a/vllm_ascend/worker/model_runner.py b/vllm_ascend/worker/model_runner.py
index 2f3b872f0b..43059b82cc 100644
--- a/vllm_ascend/worker/model_runner.py
+++ b/vllm_ascend/worker/model_runner.py
@@ -33,7 +33,7 @@
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
-from vllm.distributed import get_dp_group, get_pp_group
+from vllm.distributed import broadcast_tensor_dict, get_dp_group, get_pp_group
 from vllm.distributed.kv_transfer import get_kv_transfer_group
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
@@ -43,7 +43,8 @@
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
+                                                get_sampler)
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 from vllm.model_executor.models import supports_lora, supports_multimodal
@@ -84,6 +85,7 @@ class ModelInputForNPU(ModelRunnerInputBase):
     additional fields.
     """
     input_tokens: Optional[torch.Tensor] = None
+    inputs_embeds: Optional[torch.Tensor] = None
     input_positions: Optional[torch.Tensor] = None
     token_types: Optional[torch.Tensor] = None
     seq_lens: Optional[List[int]] = None
@@ -103,6 +105,7 @@ class ModelInputForNPU(ModelRunnerInputBase):
     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
             "input_tokens": self.input_tokens,
+            "inputs_embeds": self.inputs_embeds,
             "input_positions": self.input_positions,
             "lora_requests": self.lora_requests,
             "lora_mapping": self.lora_mapping,
@@ -151,6 +154,7 @@ class ModelInputForNPUWithSamplingMetadata(ModelInputForNPU):
     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
             "input_tokens": self.input_tokens,
+            "inputs_embeds": self.inputs_embeds,
             "input_positions": self.input_positions,
             "lora_requests": self.lora_requests,
             "lora_mapping": self.lora_mapping,
@@ -188,6 +192,7 @@ class InterDataForSeqGroup:
 
         def simple_reinit(self):
             self.input_tokens[0].clear()  # type: ignore
+            self.inputs_embeds = None  # type: ignore
             self.input_positions[0].clear()  # type: ignore
             self.token_types[0].clear()  # type: ignore
             self.mrope_input_positions = None  # type: ignore
@@ -213,6 +218,7 @@ def __init__(
 
             # Input tokens and positions.
             input_tokens: Optional[List[List[int]]] = None,
+            inputs_embeds: Optional[torch.Tensor] = None,
             input_positions: Optional[List[List[int]]] = None,
             token_types: Optional[List[List[int]]] = None,
             mrope_input_positions: Optional[List[List[List[int]]]] = None,
@@ -268,6 +274,7 @@ def __init__(
                 else:
                     for seq_id in range(len(self.seq_ids)):
                         self.input_tokens[seq_id].clear()
+                self.inputs_embeds = inputs_embeds
 
                 if input_positions:
                     self.input_positions = input_positions
@@ -329,6 +336,7 @@ def __init__(
         else:
             self.input_tokens = input_tokens or []
+            self.inputs_embeds = inputs_embeds
             self.input_positions = input_positions or []
             self.token_types = token_types or []
             self.mrope_input_positions = mrope_input_positions or None
@@ -368,6 +376,26 @@ def __post_init__(self):
             self.lora_index_mapping = []
             self.lora_prompt_mapping = []
 
+        def __repr__(self) -> str:
+            return (f"InterDataForSeqGroup("
+                    f"request_id={self.request_id}, "
+                    f"seq_ids={self.seq_ids}, "
+                    f"is_prompt={self.is_prompt}, "
+                    f"block_tables={self.block_tables}, "
+                    f"computed_block_nums={self.computed_block_nums}, "
+                    f"n_seqs={self.n_seqs}, "
+                    f"input_tokens={self.input_tokens}, "
+                    f"inputs_embeds.shape="
+                    f"{getattr(self.inputs_embeds, 'shape', None)}, "
+                    f"input_positions={self.input_positions}, "
+                    f"token_types={self.token_types}, "
+                    f"mrope_input_positions={self.mrope_input_positions}, "
+                    f"seq_lens={self.seq_lens}, "
+                    f"orig_seq_lens={self.orig_seq_lens}, "
+                    f"query_lens={self.query_lens}, "
+                    f"context_lens={self.context_lens}, "
+                    f"multi_modal_kwargs={self.multi_modal_kwargs}")
+
     def __init__(self,
                  runner,
                  finished_requests_ids: Optional[List[str]] = None):
@@ -492,11 +520,30 @@ def build(self) -> ModelInputForNPU:
         create on-device tensors.
         """
         # Combine and flatten intermediate data.
-        input_tokens = [
-            flatten_2d_lists(inter_data.input_tokens)
-            for inter_data in self.inter_data_list
-        ]
-        if not input_tokens:
+        input_tokens = list[int]()
+        inputs_embeds_list = list[torch.Tensor]()
+        token_types = list[int]()
+        for inter_data in self.inter_data_list:
+            for cur_input_tokens in inter_data.input_tokens:
+                input_tokens.extend(cur_input_tokens)
+            for cur_token_types in inter_data.token_types:
+                token_types.extend(cur_token_types)
+            if inter_data.inputs_embeds is not None:
+                inputs_embeds_list.append(
+                    inter_data.inputs_embeds.to(
+                        dtype=self.runner.model_config.dtype,
+                        device=self.runner.device))
+
+        inputs_embeds: Optional[torch.Tensor]
+        if len(inputs_embeds_list) == 0:
+            inputs_embeds = None
+        else:
+            inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to(
+                dtype=self.runner.model_config.dtype,
+                device=self.runner.device)
+            assert len(inputs_embeds) == len(input_tokens)
+
+        if not input_tokens and inputs_embeds is None:
             # This may happen when all prefill requests hit
             # prefix caching and there is no decode request.
             return self.model_input_cls()
@@ -548,10 +595,6 @@ def build(self) -> ModelInputForNPU:
         else:
             graph_pad_size = -1
 
-        #print(f"before tensor input_tokens: {input_tokens}")
-        #print(f"before tensor input_positions: {input_positions}")
-        #print(f"before list seq_lens: {seq_lens}")
-        input_tokens = flatten_2d_lists(input_tokens)
         if input_positions:
             input_positions = flatten_2d_lists(input_positions)
         if graph_pad_size != -1 and not is_prompt:
@@ -563,6 +606,10 @@ def build(self) -> ModelInputForNPU:
         input_tokens_tensor = torch.tensor(input_tokens,
                                            dtype=torch.long,
                                            device=self.runner.device)
+        token_types_tensor = torch.tensor(token_types,
+                                          dtype=torch.long,
+                                          device=self.runner.device) \
+                                          if token_types else None
         if mrope_input_positions is not None:
             input_positions_tensor = torch.tensor(mrope_input_positions,
                                                   dtype=torch.long,
@@ -613,6 +660,8 @@ def build(self) -> ModelInputForNPU:
         return self.model_input_cls(
             input_tokens=input_tokens_tensor,
+            inputs_embeds=inputs_embeds,
+            token_types=token_types_tensor,
             input_positions=input_positions_tensor,
             attn_metadata=attn_metadata,
             seq_lens=seq_lens,
@@ -645,13 +694,23 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
         context_len = seq_data.get_num_computed_tokens()
 
         # Compute tokens.
-        tokens = seq_data.get_token_ids()[context_len:seq_len]
+        # FIXME: this is for version compatibility; remove it once vLLM v0.8.5 is no longer supported.
+        if not hasattr(seq_data,
+                       "prompt_embeds") or seq_data.prompt_embeds is None:
+            tokens = seq_data.get_token_ids()[context_len:seq_len]
+            prompt_embeds = None
+        else:
+            tokens = [0] * (seq_len - context_len)
+            prompt_embeds = seq_data.get_token_embeddings(
+            )[context_len:seq_len]
+
         token_types = seq_group_metadata.token_type_ids
 
         inter_data.seq_lens[seq_idx] = seq_len
         inter_data.orig_seq_lens[seq_idx] = seq_len
         inter_data.context_lens[seq_idx] = context_len
         inter_data.input_tokens[seq_idx].extend(tokens)
+        inter_data.inputs_embeds = prompt_embeds
         inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
         inter_data.token_types[seq_idx].extend(
             token_types if token_types else [])
@@ -1379,6 +1438,7 @@ def execute_model(
             model_kwargs["attn_metadata"] = model_input.attn_metadata
         hidden_or_intermediate_states = model_executable(
             input_ids=model_input.input_tokens,
+            inputs_embeds=model_input.inputs_embeds,
             positions=model_input.input_positions,
             intermediate_tensors=intermediate_tensors,
             **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
@@ -1422,33 +1482,60 @@ def execute_model(
                 hidden_or_intermediate_states,
             )
 
-        if not self.is_driver_worker:
-            return []
+        if self.is_driver_worker:
+            if model_input.async_callback is not None:
+                model_input.async_callback()
 
-        if model_input.async_callback is not None:
-            model_input.async_callback()
+            # Sample the next token.
+            assert isinstance(self.sampler, Sampler)
+            orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor
+            if model_input.inputs_embeds is not None:
+                self.sampler.include_gpu_probs_tensor = True
 
-        # Sample the next token.
-        output = self.sampler(
-            logits=logits,
-            sampling_metadata=model_input.sampling_metadata,
-        )
-        if (self.observability_config is not None
-                and self.observability_config.collect_model_forward_time
-                and output is not None):
-            model_forward_end.synchronize()
-            model_forward_time = model_forward_start.elapsed_time(
-                model_forward_end)
-            orig_model_forward_time = 0.0
-            if intermediate_tensors is not None:
-                orig_model_forward_time = intermediate_tensors.tensors.get(
-                    "model_forward_time", torch.tensor(0.0)).item()
-            # If there are multiple workers, we are still tracking the latency
-            # from the start time of the driver worker to the end time of the
-            # driver worker. The model forward time will then end up covering
-            # the communication time as well.
-            output.model_forward_time = (orig_model_forward_time +
-                                         model_forward_time)
+            output: SamplerOutput = self.sampler(
+                logits=logits,
+                sampling_metadata=model_input.sampling_metadata,
+            )
+            if (self.observability_config is not None
+                    and self.observability_config.collect_model_forward_time
+                    and output is not None):
+                model_forward_end.synchronize()
+                model_forward_time = model_forward_start.elapsed_time(
+                    model_forward_end)
+                orig_model_forward_time = 0.0
+                if intermediate_tensors is not None:
+                    orig_model_forward_time = intermediate_tensors.tensors.get(
+                        "model_forward_time", torch.tensor(0.0)).item()
+                # If there are multiple workers, we are still tracking the
+                # latency from the start time of the driver worker to the end
+                # time of the driver worker. The model forward time will then
+                # end up covering the communication time as well.
+                output.model_forward_time = (orig_model_forward_time +
+                                             model_forward_time)
+
+        if model_input.inputs_embeds is not None:
+            if self.is_driver_worker:
+                sampled = broadcast_tensor_dict(
+                    {"token_ids": output.sampled_token_ids})
+            else:
+                sampled = broadcast_tensor_dict()
+            if sampled["token_ids"] is not None:
+                sampled_token_embeds = self.model.get_input_embeddings(
+                    sampled["token_ids"].squeeze(1))
+                if self.is_driver_worker:
+                    self.sampler.include_gpu_probs_tensor = \
+                        orig_include_gpu_probs
+
+                    output.sampled_token_embeds = sampled_token_embeds
+
+                    for token_embed, sequence_group_output in zip(
+                            output.sampled_token_embeds, output.outputs):
+                        assert len(sequence_group_output.samples) == 1
+                        sequence_group_output.samples[
+                            0].output_embed = token_embed
+
+        if not self.is_driver_worker:
+            return []
 
         if self.return_hidden_states:
             # we only need to pass hidden states of most recent token