feat: powv (vllm-project#1)
* feat: powv per token

* feat: add justfile

* fix: justfile

* fix: missing link in powv pass

* fix: powv calculation

* ref: powv to separate function

* fix: move to parent class

* feat: initial verify endpoint

* feat: initial verify endpoint

* fix: actually add as route

* feat(WIP): verify endpoint

* fix: sequence of ints instead of list for chat completion

* fix: loosen restrictions on verify chat completion

* fix: verifychatcompletion for get_powv

* fix: using wrong field

* fix: add verify into rpc layer

* fix: await verify

* fix: non-async fields

* fix: async handling

* fix: no more destruct

* feat: return powv to the top

* fix: send back via socket

* feat: add endpoint for completion

* feat: add version guards
GentikSolm authored Sep 9, 2024
1 parent 77d9e51 commit 76e79c9
Showing 18 changed files with 241 additions and 15 deletions.
8 changes: 8 additions & 0 deletions justfile
@@ -0,0 +1,8 @@
default:
    just --list

install:
    CUDACXX=/usr/local/cuda-12/bin/nvcc pip install -e .

vllm:
    CUDA_VISIBLE_DEVICES=3 vllm serve NousResearch/Meta-Llama-3.1-8B-Instruct --dtype auto
8 changes: 8 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -6,6 +6,7 @@

from typing_extensions import assert_never

from vllm.entrypoints.openai.protocol import VerifyChatCompletion
import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
@@ -997,6 +998,13 @@ async def add_request(

        return stream.generator()

    def verify(
        self,
        inputs: VerifyChatCompletion,
    ) -> Optional[int]:
        """Verifies outputs for a request."""
        return self.engine.verify_chat_completion(inputs)

    async def generate(
        self,
        inputs: PromptInputs,
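A minimal usage sketch of the verify() hook added above. The engine instance, token ids, and powv value are illustrative assumptions, not part of the commit:

from vllm.entrypoints.openai.protocol import VerifyChatCompletion

# `engine` is assumed to be an already-initialized AsyncLLMEngine.
request = VerifyChatCompletion(
    model="NousResearch/Meta-Llama-3.1-8B-Instruct",
    input_tokens=[128000, 882, 25],   # made-up prompt token ids
    response_tokens=[9906, 0],        # made-up response token ids
    powv=1234,
)
recomputed = engine.verify(request)   # synchronous; returns Optional[int]
print(recomputed == request.powv)
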
7 changes: 7 additions & 0 deletions vllm/engine/llm_engine.py
@@ -11,6 +11,7 @@
import torch
from typing_extensions import TypeVar, assert_never

from vllm.entrypoints.openai.protocol import VerifyChatCompletion
import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
@@ -1036,6 +1037,11 @@ def process_model_inputs(

        return self.input_processor(model_inputs)

    def verify_chat_completion(
            self,
            inputs: VerifyChatCompletion):
        return self.model_executor.verify_output(inputs)

    def add_request(
        self,
        request_id: str,
@@ -1308,6 +1314,7 @@ def _process_model_outputs(self, ctx: SchedulerContext) -> None:
                for o in outputs:
                    if (isinstance(o, SamplerOutput)
                            and seq_group.metrics is not None):
                        seq_group.powv = o.powv
                        if seq_group.metrics.model_forward_time is not None:
                            seq_group.metrics.model_forward_time += (
                                o.model_forward_time)
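LLMEngine.verify_chat_completion simply forwards to the model executor. The executor-side verify_output lives in one of the changed files not shown on this page; judging from how its result is compared against the submitted powv upstream, its expected shape is presumably something like the following (class name, signature, and docstring are assumptions, not the commit's code):

from typing import Optional

from vllm.entrypoints.openai.protocol import VerifyChatCompletion


class ExecutorWithVerify:
    """Assumed shape of the executor-side hook."""

    def verify_output(self, inputs: VerifyChatCompletion) -> Optional[int]:
        # Recompute the proof-of-work value from inputs.input_tokens and
        # inputs.response_tokens, or return None if it cannot be computed.
        ...
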
11 changes: 10 additions & 1 deletion vllm/engine/protocol.py
@@ -1,8 +1,10 @@
from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
import types
from typing import (AsyncGenerator, Coroutine, List, Mapping, Optional, Protocol,
runtime_checkable)

from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.entrypoints.openai.protocol import VerifyChatCompletion
from vllm.inputs.data import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
@@ -33,6 +35,13 @@ def errored(self) -> bool:
    def limit_concurrency(self) -> Optional[int]:
        """Maximum number of concurrently running requests."""

    def verify(
        self,
        inputs: VerifyChatCompletion,
    ) -> Coroutine[None, None, int]:
        """Verify outputs for a request"""
        ...

    def generate(
        self,
        inputs: PromptInputs,
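A caller-side sketch against this protocol; `client` stands for any object implementing it (for example the RPC client further down), and the helper name is illustrative:

from typing import Optional

from vllm.entrypoints.openai.protocol import VerifyChatCompletion


async def powv_matches(client, request: VerifyChatCompletion) -> bool:
    # The protocol declares verify() as returning an awaitable int, so the
    # caller awaits it and compares the result with the submitted powv.
    recomputed: Optional[int] = await client.verify(request)
    return recomputed is not None and recomputed == request.powv
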
63 changes: 61 additions & 2 deletions vllm/entrypoints/openai/api_server.py
@@ -34,9 +34,9 @@
DetokenizeRequest,
DetokenizeResponse,
EmbeddingRequest,
EmbeddingResponse, ErrorResponse,
EmbeddingResponse, ErrorResponse, TokenizeChatRequest, TokenizeCompletionRequest,
TokenizeRequest,
TokenizeResponse)
TokenizeResponse, VerifyChatCompletion, VerifyChatCompletionResponse, VerifyCompletionResponse)
# yapf: enable
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
@@ -280,6 +280,36 @@ async def show_version():
    return JSONResponse(content=ver)


@router.post("/v1/chat/completions/verify")
async def verify_chat_completion(req: VerifyChatCompletionResponse):
    version = '0.0.0'
    if req.version != version:
        return JSONResponse(content=f"Bad version. Got {req.version}, need {version}.")
    tokenize_request = TokenizeChatRequest(messages=req.messages, model=req.model)
    generator = await openai_serving_tokenization.create_tokenize(tokenize_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    elif not isinstance(generator, TokenizeResponse):
        return JSONResponse(content=generator.model_dump(), status_code=500)
    (
        lora_request,
        _,
    ) = openai_serving_tokenization._maybe_get_adapters(tokenize_request)

    tokenizer = await openai_serving_tokenization.async_engine_client.get_tokenizer(lora_request)
    prompt_tokens = generator.tokens
    response_tokens = openai_serving_tokenization._tokenize_prompt_input(
        tokenize_request,
        tokenizer,
        req.response,
        add_special_tokens=False,
    )['prompt_token_ids']
    res = await openai_serving_chat.verify_chat_completion(VerifyChatCompletion(model=req.model, input_tokens=prompt_tokens, response_tokens=response_tokens, powv=req.powv))
    return JSONResponse(content=res == req.powv and req.powv is not None)
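
A hedged client-side example of calling the new route; the host, port, and powv value are illustrative, and version must match the guard above:

import requests

payload = {
    "model": "NousResearch/Meta-Llama-3.1-8B-Instruct",
    "messages": [{"role": "user", "content": "Hello"}],
    "response": "Hi there!",
    "powv": 1234,        # value returned alongside the original completion
    "version": "0.0.0",  # must match the server-side version guard
}
resp = requests.post("http://localhost:8000/v1/chat/completions/verify",
                     json=payload)
print(resp.json())  # True only if the recomputed powv equals the one sent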



@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
@@ -309,6 +339,35 @@ async def create_completion(request: CompletionRequest, raw_request: Request):

    return StreamingResponse(content=generator, media_type="text/event-stream")

@router.post("/v1/completions/verify")
async def verify_completion(req: VerifyCompletionResponse):
version = '0.0.0'
if req.version != version:
return JSONResponse(content=f"Bad version. Got {req.version}, need {version}.")
tokenize_request = TokenizeChatRequest(prompt=req.prompt, model=req.model)
generator = await openai_serving_tokenization.create_tokenize(tokenize_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif not isinstance(generator, TokenizeResponse):
return JSONResponse(content=generator.model_dump(), status_code=500)
(
lora_request,
_,
) = openai_serving_tokenization._maybe_get_adapters(tokenize_request)

tokenizer = await openai_serving_tokenization.async_engine_client.get_tokenizer(lora_request)
prompt_tokens = generator.tokens
response_tokens = openai_serving_tokenization._tokenize_prompt_input(
tokenize_request,
tokenizer,
req.response,
add_special_tokens=False,
)['prompt_token_ids']
res = await openai_serving_chat.verify_chat_completion(VerifyChatCompletion(model=req.model, input_tokens=prompt_tokens, response_tokens=response_tokens, powv=req.powv))
return JSONResponse(content=res == req.powv and req.powv is not None)
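
The same kind of call for the plain-completion variant; again the host and powv value are illustrative:

import requests

payload = {
    "model": "NousResearch/Meta-Llama-3.1-8B-Instruct",
    "prompt": "Once upon a time",
    "response": " there was a very small GPU cluster.",
    "powv": 5678,
    "version": "0.0.0",
}
resp = requests.post("http://localhost:8000/v1/completions/verify",
                     json=payload)
print(resp.json())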



@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
23 changes: 22 additions & 1 deletion vllm/entrypoints/openai/protocol.py
@@ -2,7 +2,7 @@
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from argparse import Namespace
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Sequence, Union

import torch
from openai.types.chat import ChatCompletionContentPartParam
@@ -784,6 +784,7 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None
    powv: Optional[int] = None


class ChatCompletionStreamResponse(OpenAIBaseModel):
@@ -861,6 +862,26 @@ class TokenizeChatRequest(OpenAIBaseModel):
    add_generation_prompt: bool = Field(default=True)
    add_special_tokens: bool = Field(default=False)

class VerifyChatCompletion(OpenAIBaseModel):
    model: str
    input_tokens: Sequence[int]
    response_tokens: Sequence[int]
    powv: Optional[int] = None

class VerifyChatCompletionResponse(OpenAIBaseModel):
    model: str
    messages: List[ChatCompletionMessageParam]
    response: str
    powv: int
    version: str

class VerifyCompletionResponse(OpenAIBaseModel):
    model: str
    prompt: str
    response: str
    powv: int
    version: str


TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]

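To make the relationship between the new models concrete, here is an illustrative construction; the token ids are made up, and the tokenization step is what the verify endpoints above actually perform:

from vllm.entrypoints.openai.protocol import (VerifyChatCompletion,
                                              VerifyChatCompletionResponse)

# What a client posts to /v1/chat/completions/verify.
body = VerifyChatCompletionResponse(
    model="NousResearch/Meta-Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Hello"}],
    response="Hi there!",
    powv=1234,
    version="0.0.0",
)

# The API layer tokenizes `messages` and `response`, then builds the
# engine-level request that travels down to the model executor.
engine_req = VerifyChatCompletion(
    model=body.model,
    input_tokens=[128000, 882, 25],   # illustrative prompt token ids
    response_tokens=[9906, 0],        # illustrative response token ids
    powv=body.powv,
)
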
11 changes: 9 additions & 2 deletions vllm/entrypoints/openai/rpc/__init__.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from enum import Enum
from typing import Mapping, Optional, Union
from typing import Mapping, Optional, Sequence, Union

from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
@@ -27,6 +27,13 @@ class RPCGenerateRequest:
    prompt_adapter_request: Optional[PromptAdapterRequest] = None


@dataclass
class RPCVerifyResponse:
    model: str
    input_tokens: Sequence[int]
    response_tokens: Sequence[int]


@dataclass
class RPCAbortRequest:
    request_id: str
@@ -47,4 +54,4 @@ class RPCUtilityRequest(Enum):


RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
                         RPCUtilityRequest]
                         RPCUtilityRequest, RPCVerifyResponse]
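
A small sketch of how the new RPC message is built and routed; the token ids are illustrative, and the comments mirror the client and server changes below:

from vllm.entrypoints.openai.rpc import RPCVerifyResponse

# Client side: AsyncEngineRPCClient.verify() (see rpc/client.py) wraps the
# token ids in this dataclass and cloudpickles it onto the ZMQ socket.
msg = RPCVerifyResponse(
    model="NousResearch/Meta-Llama-3.1-8B-Instruct",
    input_tokens=[128000, 882, 25],
    response_tokens=[9906, 0],
)
# Server side: _make_handler_coro (see rpc/server.py) recognises the
# instance and routes it to verify_response, which replies with the
# recomputed powv. The powv itself is not sent over RPC; the comparison
# happens back in api_server.py.
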
14 changes: 12 additions & 2 deletions vllm/entrypoints/openai/rpc/client.py
@@ -5,6 +5,7 @@
from uuid import uuid4

import cloudpickle
from vllm.entrypoints.openai.protocol import VerifyChatCompletion
import zmq
import zmq.asyncio
from zmq import Frame # type: ignore[attr-defined]
@@ -17,7 +18,7 @@
VLLM_RPC_SOCKET_LIMIT_CUTOFF,
VLLM_RPC_SUCCESS_STR,
VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
RPCGenerateRequest, RPCUtilityRequest, RPCVerifyResponse)
# yapf: enable
from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS
from vllm.inputs import PromptInputs
@@ -198,7 +199,7 @@ def to_proxy_socket(self) -> Iterator[Socket]:
        finally:
            socket.close(linger=0)

    async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
    async def _send_get_data_rpc_request(self, request: RPC_REQUEST_TYPE,
                                         expected_type: Any,
                                         error_message: str) -> Any:
        """Send an RPC request that is expecting data back."""
@@ -371,6 +372,15 @@ def is_stopped(self) -> bool:
    def errored(self) -> bool:
        return self._errored

    async def verify(self, inputs: VerifyChatCompletion):
        return await self._send_get_data_rpc_request(
            RPCVerifyResponse(
                model=inputs.model,
                input_tokens=inputs.input_tokens,
                response_tokens=inputs.response_tokens),
            expected_type=Optional[int],
            error_message="Failed to verify response")

    async def generate(
        self,
        inputs: PromptInputs,
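Putting the client pieces together, a hedged round-trip sketch; it assumes a running RPC server and an already-constructed AsyncEngineRPCClient named client:

from typing import Optional

from vllm.entrypoints.openai.protocol import VerifyChatCompletion


async def rpc_verify(client, req: VerifyChatCompletion) -> Optional[int]:
    # Internally this cloudpickles an RPCVerifyResponse onto the proxy
    # socket and unpickles the server's reply: the recomputed powv or None.
    return await client.verify(req)
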
11 changes: 10 additions & 1 deletion vllm/entrypoints/openai/rpc/server.py
@@ -5,6 +5,7 @@

import cloudpickle
import uvloop
from vllm.entrypoints.openai.protocol import VerifyChatCompletion
import zmq
import zmq.asyncio
from typing_extensions import Never
@@ -16,7 +17,7 @@
ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR,
VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
RPCGenerateRequest, RPCUtilityRequest, RPCVerifyResponse)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext

@@ -150,6 +151,11 @@ async def stop_profile(self, identity):
            pickle.dumps(VLLM_RPC_SUCCESS_STR),
        ))

    async def verify_response(self, identity, input: VerifyChatCompletion):
        output = self.engine.verify(input)
        await self.socket.send_multipart((identity, pickle.dumps(output)),
                                         copy=False)

    def _make_handler_coro(self, identity,
                           message: Frame) -> Coroutine[Any, Any, Never]:
        """Route the zmq message to the handler coroutine."""
@@ -159,6 +165,9 @@ def _make_handler_coro(self, identity,
        if isinstance(request, RPCGenerateRequest):
            return self.generate(identity, request)

        elif isinstance(request, RPCVerifyResponse):
            return self.verify_response(identity, request)

        elif isinstance(request, RPCAbortRequest):
            return self.abort(identity, request)

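The reply path is plain pickle over ZMQ: verify_response pickles whatever AsyncLLMEngine.verify returned, and the client's verify unpickles it. A tiny self-contained illustration of just that serialization step (the value is a stand-in):

import pickle

recomputed_powv = 1234                          # stand-in for engine.verify(...)
frame = pickle.dumps(recomputed_powv)           # what verify_response sends back
assert pickle.loads(frame) == recomputed_powv   # what the client's verify receives
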
15 changes: 13 additions & 2 deletions vllm/entrypoints/openai/serving_chat.py
@@ -1,7 +1,7 @@
import asyncio
import json
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List,
from typing import (AsyncGenerator, AsyncIterator, Callable, Coroutine, Dict, Final, List,
Optional)
from typing import Sequence as GenericSequence
from typing import Union
@@ -21,7 +21,7 @@
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo)
DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo, VerifyChatCompletion)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing,
PromptAdapterPath,
@@ -86,6 +86,13 @@ def __init__(self,
raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser")

def verify_chat_completion(self, request: VerifyChatCompletion):
verified = self.async_engine_client.verify(
request,
)
return verified


    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
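
Note that verify_chat_completion does not await the engine client's verify(); it hands the result straight back, so with the RPC client the caller receives an awaitable and awaits it (as api_server.py does above). A usage sketch, with the serving object and request assumed to exist:

async def verify_and_compare(serving_chat, engine_req) -> bool:
    # serving_chat is an OpenAIServingChat instance; engine_req is a
    # VerifyChatCompletion built from the tokenized prompt and response.
    recomputed = await serving_chat.verify_chat_completion(engine_req)
    return recomputed is not None and recomputed == engine_req.powv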
@@ -274,6 +281,7 @@ async def chat_completion_stream_generator(
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=i,
                        delta=DeltaMessage(role=role),
                        powv=res.powv,
                        logprobs=None,
                        finish_reason=None)
                    chunk = ChatCompletionStreamResponse(
@@ -316,6 +324,7 @@ async def chat_completion_stream_generator(
                                index=i,
                                delta=DeltaMessage(
                                    content=last_msg_content),
                                powv=res.powv,
                                logprobs=None,
                                finish_reason=None))
                        chunk = ChatCompletionStreamResponse(
@@ -417,6 +426,7 @@ async def chat_completion_stream_generator(
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=i,
                        delta=delta_message,
                        powv=res.powv,
                        logprobs=logprobs,
                        finish_reason=None)
                    chunk = ChatCompletionStreamResponse(
@@ -488,6 +498,7 @@ async def chat_completion_stream_generator(
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=i,
                        delta=delta_message,
                        powv=res.powv,
                        logprobs=logprobs,
                        finish_reason=output.finish_reason
                        if not (tool_parser
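With powv attached to every ChatCompletionResponseStreamChoice, each streamed chunk now carries the value needed for later verification. An illustrative (made-up) shape of one streamed choice:

streamed_choice = {
    "index": 0,
    "delta": {"role": "assistant", "content": "Hello"},
    "powv": 1234,          # proof-of-work value reported for this request
    "logprobs": None,
    "finish_reason": None,
}
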
(Diffs for the remaining changed files in this commit are not shown in this view.)
