From 76e79c9c67ebf9e5243ece40c1992df73cbb04df Mon Sep 17 00:00:00 2001
From: Josh
Date: Mon, 9 Sep 2024 15:40:07 -0500
Subject: [PATCH] feat: powv (#1)

* feat: powv per token
* feat: add justfile
* fix: justfile
* fix: missing link in powv pass
* fix: powv calculation
* ref: powv to separate function
* fix: move to parent class
* feat: initial verify endpoint
* feat: initial verify endpoint
* fix: actually add as route
* feat(WIP): verify endpoint
* fix: sequence of ints instead of list for chat completion
* fix: loosen restrictions on verify chat completion
* fix: verifychatcompletion for get_powv
* fix: using wrong field
* fix: add verify into rpc layer
* fix: await verify
* fix: non-async fields
* fix: async handling
* fix: no more destruct
* feat: return powv to the top
* fix: send back via socket
* feat: add endpoint for completion
* feat: add version guards
---
 justfile                                |  8 ++++
 vllm/engine/async_llm_engine.py         |  8 ++++
 vllm/engine/llm_engine.py               |  7 +++
 vllm/engine/protocol.py                 | 11 ++++-
 vllm/entrypoints/openai/api_server.py   | 63 ++++++++++++++++++++++++-
 vllm/entrypoints/openai/protocol.py     | 23 ++++++++-
 vllm/entrypoints/openai/rpc/__init__.py | 11 ++++-
 vllm/entrypoints/openai/rpc/client.py   | 14 +++++-
 vllm/entrypoints/openai/rpc/server.py   | 11 ++++-
 vllm/entrypoints/openai/serving_chat.py | 15 +++++-
 vllm/executor/executor_base.py          |  8 ++++
 vllm/executor/gpu_executor.py           |  9 ++++
 vllm/model_executor/layers/sampler.py   |  2 +
 vllm/outputs.py                         | 10 ++--
 vllm/sequence.py                        |  1 +
 vllm/worker/model_runner.py             | 44 +++++++++++++++++
 vllm/worker/worker.py                   |  4 ++
 vllm/worker/worker_base.py              |  7 ++-
 18 files changed, 241 insertions(+), 15 deletions(-)
 create mode 100644 justfile

diff --git a/justfile b/justfile
new file mode 100644
index 0000000000000..407736f9ad892
--- /dev/null
+++ b/justfile
@@ -0,0 +1,8 @@
+default:
+    just --list
+
+install:
+    CUDACXX=/usr/local/cuda-12/bin/nvcc pip install -e .
+
+vllm:
+    CUDA_VISIBLE_DEVICES=3 vllm serve NousResearch/Meta-Llama-3.1-8B-Instruct --dtype auto
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 7fe8053fffb7b..0a080d89cd85d 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -6,6 +6,7 @@
 
 from typing_extensions import assert_never
 
+from vllm.entrypoints.openai.protocol import VerifyChatCompletion
 import vllm.envs as envs
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
@@ -997,6 +998,13 @@ async def add_request(
 
         return stream.generator()
 
+    def verify(
+        self,
+        inputs: VerifyChatCompletion,
+    ) -> Optional[int]:
+        """Verifies outputs for a request."""
+        return self.engine.verify_chat_completion(inputs)
+
     async def generate(
         self,
         inputs: PromptInputs,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 7da4f7b25db9e..e6084c7b12948 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -11,6 +11,7 @@
 import torch
 from typing_extensions import TypeVar, assert_never
 
+from vllm.entrypoints.openai.protocol import VerifyChatCompletion
 import vllm.envs as envs
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                          EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
@@ -1036,6 +1037,11 @@ def process_model_inputs(
 
         return self.input_processor(model_inputs)
 
+    def verify_chat_completion(
+        self,
+        inputs: VerifyChatCompletion):
+        return self.model_executor.verify_output(inputs)
+
     def add_request(
         self,
         request_id: str,
@@ -1308,6 +1314,7 @@ def _process_model_outputs(self, ctx: SchedulerContext) -> None:
             for o in outputs:
                 if (isinstance(o, SamplerOutput)
                         and seq_group.metrics is not None):
+                    seq_group.powv = o.powv
                     if seq_group.metrics.model_forward_time is not None:
                         seq_group.metrics.model_forward_time += (
                             o.model_forward_time)
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index 34ae79f5fa8df..b108f8154eb48 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -1,8 +1,10 @@
-from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
+import types
+from typing import (AsyncGenerator, Coroutine, List, Mapping, Optional, Protocol,
                     runtime_checkable)
 
 from vllm.config import DecodingConfig, ModelConfig
 from vllm.core.scheduler import SchedulerOutputs
+from vllm.entrypoints.openai.protocol import VerifyChatCompletion
 from vllm.inputs.data import PromptInputs
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -33,6 +35,13 @@ def errored(self) -> bool:
     def limit_concurrency(self) -> Optional[int]:
         """Maximum number of concurrently running requests."""
 
+    def verify(
+        self,
+        inputs: VerifyChatCompletion,
+    ) -> Coroutine[None, None, int]:
+        """Verify outputs for a request."""
+        ...
+
     def generate(
         self,
         inputs: PromptInputs,
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 728a2e5232d9b..2248a7f020038 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -34,9 +34,9 @@
                                               DetokenizeRequest,
                                               DetokenizeResponse,
                                               EmbeddingRequest,
-                                              EmbeddingResponse, ErrorResponse,
+                                              EmbeddingResponse, ErrorResponse, TokenizeChatRequest, TokenizeCompletionRequest,
                                               TokenizeRequest,
-                                              TokenizeResponse)
+                                              TokenizeResponse, VerifyChatCompletion, VerifyChatCompletionResponse, VerifyCompletionResponse)
 # yapf: enable
 from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
 from vllm.entrypoints.openai.rpc.server import run_rpc_server
@@ -280,6 +280,36 @@ async def show_version():
     return JSONResponse(content=ver)
 
 
+
+@router.post("/v1/chat/completions/verify")
+async def verify_chat_completion(req: VerifyChatCompletionResponse):
+    version = '0.0.0'
+    if req.version != version:
+        return JSONResponse(content=f"Bad version. Got {req.version}, need {version}.")
+    tokenize_request = TokenizeChatRequest(messages=req.messages, model=req.model)
+    generator = await openai_serving_tokenization.create_tokenize(tokenize_request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    elif not isinstance(generator, TokenizeResponse):
+        return JSONResponse(content=generator.model_dump(), status_code=500)
+    (
+        lora_request,
+        _,
+    ) = openai_serving_tokenization._maybe_get_adapters(tokenize_request)
+
+    tokenizer = await openai_serving_tokenization.async_engine_client.get_tokenizer(lora_request)
+    prompt_tokens = generator.tokens
+    response_tokens = openai_serving_tokenization._tokenize_prompt_input(
+        tokenize_request,
+        tokenizer,
+        req.response,
+        add_special_tokens=False,
+    )['prompt_token_ids']
+    res = await openai_serving_chat.verify_chat_completion(VerifyChatCompletion(model=req.model, input_tokens=prompt_tokens, response_tokens=response_tokens, powv=req.powv))
+    return JSONResponse(content=res == req.powv and req.powv is not None)
+
+
 @router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
@@ -309,6 +339,35 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     return StreamingResponse(content=generator,
                              media_type="text/event-stream")
 
 
+@router.post("/v1/completions/verify")
+async def verify_completion(req: VerifyCompletionResponse):
+    version = '0.0.0'
+    if req.version != version:
+        return JSONResponse(content=f"Bad version. Got {req.version}, need {version}.")
+    tokenize_request = TokenizeCompletionRequest(prompt=req.prompt, model=req.model)
+    generator = await openai_serving_tokenization.create_tokenize(tokenize_request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    elif not isinstance(generator, TokenizeResponse):
+        return JSONResponse(content=generator.model_dump(), status_code=500)
+    (
+        lora_request,
+        _,
+    ) = openai_serving_tokenization._maybe_get_adapters(tokenize_request)
+
+    tokenizer = await openai_serving_tokenization.async_engine_client.get_tokenizer(lora_request)
+    prompt_tokens = generator.tokens
+    response_tokens = openai_serving_tokenization._tokenize_prompt_input(
+        tokenize_request,
+        tokenizer,
+        req.response,
+        add_special_tokens=False,
+    )['prompt_token_ids']
+    res = await openai_serving_chat.verify_chat_completion(VerifyChatCompletion(model=req.model, input_tokens=prompt_tokens, response_tokens=response_tokens, powv=req.powv))
+    return JSONResponse(content=res == req.powv and req.powv is not None)
+
+
 @router.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index ff9c3690672b6..285046d571494 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -2,7 +2,7 @@
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
 import time
 from argparse import Namespace
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Sequence, Union
 
 import torch
 from openai.types.chat import ChatCompletionContentPartParam
@@ -784,6 +784,7 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
     logprobs: Optional[ChatCompletionLogProbs] = None
     finish_reason: Optional[str] = None
     stop_reason: Optional[Union[int, str]] = None
+    powv: Optional[int] = None
 
 
 class ChatCompletionStreamResponse(OpenAIBaseModel):
@@ -861,6 +862,26 @@ class TokenizeChatRequest(OpenAIBaseModel):
     messages: List[ChatCompletionMessageParam]
     model: str
     add_generation_prompt: bool = Field(default=True)
     add_special_tokens: bool = Field(default=False)
+class VerifyChatCompletion(OpenAIBaseModel):
+    model: str
+    input_tokens: Sequence[int]
+    response_tokens: Sequence[int]
+    powv: Optional[int] = None
+
+class VerifyChatCompletionResponse(OpenAIBaseModel):
+    model: str
+    messages: List[ChatCompletionMessageParam]
+    response: str
+    powv: int
+    version: str
+
+class VerifyCompletionResponse(OpenAIBaseModel):
+    model: str
+    prompt: str
+    response: str
+    powv: int
+    version: str
+
 
 TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py
index efc7e43afdcc9..fa3797f1191fd 100644
--- a/vllm/entrypoints/openai/rpc/__init__.py
+++ b/vllm/entrypoints/openai/rpc/__init__.py
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from enum import Enum
-from typing import Mapping, Optional, Union
+from typing import Mapping, Optional, Sequence, Union
 
 from vllm.inputs import PromptInputs
 from vllm.lora.request import LoRARequest
@@ -27,6 +27,13 @@ class RPCGenerateRequest:
     prompt_adapter_request: Optional[PromptAdapterRequest] = None
 
 
+@dataclass
+class RPCVerifyResponse:
+    model: str
+    input_tokens: Sequence[int]
+    response_tokens: Sequence[int]
+
+
 @dataclass
 class RPCAbortRequest:
     request_id: str
@@ -47,4 +54,4 @@ class RPCUtilityRequest(Enum):
 
 
 RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
-                         RPCUtilityRequest]
+                         RPCUtilityRequest, RPCVerifyResponse]
diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py
index 9b88db746be5c..275fba0efb1b0 100644
--- a/vllm/entrypoints/openai/rpc/client.py
+++ b/vllm/entrypoints/openai/rpc/client.py
@@ -5,6 +5,7 @@
 from uuid import uuid4
 
 import cloudpickle
+from vllm.entrypoints.openai.protocol import VerifyChatCompletion
 import zmq
 import zmq.asyncio
 from zmq import Frame  # type: ignore[attr-defined]
@@ -17,7 +18,7 @@
                                          VLLM_RPC_SOCKET_LIMIT_CUTOFF,
                                          VLLM_RPC_SUCCESS_STR,
                                          VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
-                                         RPCGenerateRequest, RPCUtilityRequest)
+                                         RPCGenerateRequest, RPCUtilityRequest, RPCVerifyResponse)
 # yapf: enable
 from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS
 from vllm.inputs import PromptInputs
@@ -198,7 +199,7 @@ def to_proxy_socket(self) -> Iterator[Socket]:
         finally:
             socket.close(linger=0)
 
-    async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
+    async def _send_get_data_rpc_request(self, request: RPC_REQUEST_TYPE,
                                          expected_type: Any,
                                          error_message: str) -> Any:
         """Send an RPC request that is expecting data back."""
@@ -371,6 +372,15 @@ def is_stopped(self) -> bool:
     def errored(self) -> bool:
         return self._errored
 
+    async def verify(self, inputs: VerifyChatCompletion):
+        return await self._send_get_data_rpc_request(
+            RPCVerifyResponse(
+                model=inputs.model,
+                input_tokens=inputs.input_tokens,
+                response_tokens=inputs.response_tokens),
+            expected_type=Optional[int],
+            error_message="Failed to verify response")
+
     async def generate(
         self,
         inputs: PromptInputs,
diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py
index bebc2faedb680..297bae338b422 100644
--- a/vllm/entrypoints/openai/rpc/server.py
+++ b/vllm/entrypoints/openai/rpc/server.py
@@ -5,6 +5,7 @@
 
 import cloudpickle
 import uvloop
+from vllm.entrypoints.openai.protocol import VerifyChatCompletion
 import zmq
 import zmq.asyncio
 from typing_extensions import Never
@@ -16,7 +17,7 @@
                          ParallelConfig, SchedulerConfig)
 from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR,
                                          VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
-                                         RPCGenerateRequest, RPCUtilityRequest)
+                                         RPCGenerateRequest, RPCUtilityRequest, RPCVerifyResponse)
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 
@@ -150,6 +151,11 @@ async def stop_profile(self, identity):
                 pickle.dumps(VLLM_RPC_SUCCESS_STR),
             ))
 
+    async def verify_response(self, identity, input: VerifyChatCompletion):
+        output = self.engine.verify(input)
+        await self.socket.send_multipart((identity, pickle.dumps(output)),
+                                         copy=False)
+
     def _make_handler_coro(self, identity,
                            message: Frame) -> Coroutine[Any, Any, Never]:
         """Route the zmq message to the handler coroutine."""
@@ -159,6 +165,9 @@ def _make_handler_coro(self, identity,
         if isinstance(request, RPCGenerateRequest):
             return self.generate(identity, request)
 
+        elif isinstance(request, RPCVerifyResponse):
+            return self.verify_response(identity, request)
+
         elif isinstance(request, RPCAbortRequest):
             return self.abort(identity, request)
 
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 78f355228012f..6b04dcca3eed4 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -1,7 +1,7 @@
 import asyncio
 import json
 import time
-from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List,
+from typing import (AsyncGenerator, AsyncIterator, Callable, Coroutine, Dict, Final, List,
                     Optional)
 from typing import Sequence as GenericSequence
 from typing import Union
@@ -21,7 +21,7 @@
     ChatCompletionRequest, ChatCompletionResponse,
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
-    DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo)
+    DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo, VerifyChatCompletion)
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing,
                                                     PromptAdapterPath,
@@ -86,6 +86,13 @@ def __init__(self,
                 raise TypeError("Error: --enable-auto-tool-choice requires "
                                 "--tool-call-parser")
 
+    def verify_chat_completion(self, request: VerifyChatCompletion):
+        verified = self.async_engine_client.verify(
+            request,
+        )
+        return verified
+
+
     async def create_chat_completion(
         self,
         request: ChatCompletionRequest,
@@ -274,6 +281,7 @@ async def chat_completion_stream_generator(
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(role=role),
+                        powv=res.powv,
                         logprobs=None,
                         finish_reason=None)
                     chunk = ChatCompletionStreamResponse(
@@ -316,6 +324,7 @@ async def chat_completion_stream_generator(
                                     index=i,
                                     delta=DeltaMessage(
                                         content=last_msg_content),
+                                    powv=res.powv,
                                     logprobs=None,
                                     finish_reason=None))
                                 chunk = ChatCompletionStreamResponse(
@@ -417,6 +426,7 @@ async def chat_completion_stream_generator(
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=i,
                     delta=delta_message,
+                    powv=res.powv,
                     logprobs=logprobs,
                     finish_reason=None)
                 chunk = ChatCompletionStreamResponse(
@@ -488,6 +498,7 @@ async def chat_completion_stream_generator(
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=i,
                     delta=delta_message,
+                    powv=res.powv,
                     logprobs=logprobs,
                     finish_reason=output.finish_reason
                     if not (tool_parser
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index c96cb0f2c2981..0cfb9a0484848 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -5,6 +5,7 @@
                          ModelConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig,
                          SpeculativeConfig)
+from vllm.entrypoints.openai.protocol import VerifyChatCompletion
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -80,6 +81,13 @@ def execute_model(
         """Executes at least one model step on the given sequences."""
         raise NotImplementedError
 
+    @abstractmethod
+    def verify_output(
+        self, input: VerifyChatCompletion
+    ) -> Optional[int]:
+        """Compute the proof-of-work value (powv) for a response."""
+        raise NotImplementedError
+
     def stop_remote_worker_execution_loop(self) -> None:
         """Releases parallel workers from model loop."""
         return
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index 947776e5d6ef4..60a492360592f 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -1,5 +1,6 @@
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 
+from vllm.entrypoints.openai.protocol import VerifyChatCompletion
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -130,6 +131,14 @@ def execute_model(
         output = self.driver_worker.execute_model(execute_model_req)
         return output
 
+    def verify_output(
+        self, input: VerifyChatCompletion
+    ) -> Optional[int]:
+        """Compute the proof-of-work value (powv) for a response."""
response""" + assert self.driver_worker is not None + return self.driver_worker.verify_output(input) + + def add_lora(self, lora_request: LoRARequest) -> bool: assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." return self.driver_worker.add_lora(lora_request) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c00da106734ae..272f4b3bee233 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -121,6 +121,8 @@ class SamplerOutput( # block/sync across workers, cpu-gpu sync time and sampling time. model_execute_time: Optional[float] = None + powv: Optional[int] = None + def __getitem__(self, idx: int): return self.outputs[idx] diff --git a/vllm/outputs.py b/vllm/outputs.py index e091b576f5972..02a9c63aeae88 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -100,6 +100,7 @@ def __init__( lora_request: Optional[LoRARequest] = None, encoder_prompt: Optional[str] = None, encoder_prompt_token_ids: Optional[List[int]] = None, + powv: Optional[int] = None ) -> None: self.request_id = request_id self.prompt = prompt @@ -111,6 +112,7 @@ def __init__( self.lora_request = lora_request self.encoder_prompt = encoder_prompt self.encoder_prompt_token_ids = encoder_prompt_token_ids + self.powv: Optional[int] = powv @classmethod def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": @@ -166,7 +168,8 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": seq_group.metrics, lora_request=seq_group.lora_request, encoder_prompt=encoder_prompt, - encoder_prompt_token_ids=encoder_prompt_token_ids) + encoder_prompt_token_ids=encoder_prompt_token_ids, + powv=seq_group.powv) def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " @@ -193,11 +196,12 @@ class EmbeddingRequestOutput: """ def __init__(self, request_id: str, outputs: "EmbeddingOutput", - prompt_token_ids: List[int], finished: bool): + prompt_token_ids: List[int], finished: bool, powv: Optional[int]=None): self.request_id = request_id self.prompt_token_ids = prompt_token_ids self.finished = finished self.outputs = outputs + self.powv = powv @classmethod def from_seq_group(cls, @@ -209,7 +213,7 @@ def from_seq_group(cls, prompt_token_ids = seq_group.prompt_token_ids finished = seq_group.is_finished() - return cls(seq_group.request_id, output, prompt_token_ids, finished) + return cls(seq_group.request_id, output, prompt_token_ids, finished, seq_group.powv) def __repr__(self): """ diff --git a/vllm/sequence.py b/vllm/sequence.py index a5ebf152ce776..1ab1e784ebc5f 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -615,6 +615,7 @@ def __init__( self.prompt_adapter_request = prompt_adapter_request self.encoder_seq = encoder_seq self.trace_headers = trace_headers + self.powv: Optional[int] = None @property def prompt(self) -> Optional[str]: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 74f7d4e0860d3..74d8ceef02a9c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -13,7 +13,9 @@ import torch import torch.distributed import torch.nn as nn +from math import floor +from vllm.entrypoints.openai.protocol import VerifyChatCompletion import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState @@ -899,6 +901,7 @@ def __init__( # Lazy initialization self.model: nn.Module # Set after load_model + self.model_num_params: int # Set after load_model. 
         self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
         self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None
 
@@ -911,6 +914,32 @@
         self.sampling_metadata_cache: SamplingMetadataCache = \
             SamplingMetadataCache()
 
+    def get_powv(
+        self,
+        input: VerifyChatCompletion,
+    ) -> int:
+        """
+        Calculates the proof-of-work value that can be used to verify that the
+        outputs of a model were produced with the model claimed.
+        """
+        powv = 0
+        input_sum = sum(input.input_tokens)
+        output_sum = sum(input.response_tokens)
+        token_sum = input_sum + output_sum
+        param_index = token_sum % self.model_num_params
+        for k, param in enumerate(self.model.parameters()):
+            if k != param_index:
+                continue
+            if param.dim() == 1:
+                weights = param.tolist()
+            else:
+                tensor_index = output_sum % param.size()[0]
+                weights = param[tensor_index].tolist()
+            weight_index = input_sum % len(weights)
+            powv = floor(weights[weight_index] * token_sum)
+            break
+        return powv
+
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with CudaMemoryProfiler() as m:
@@ -921,6 +950,7 @@ def load_model(self) -> None:
                 parallel_config=self.parallel_config,
                 scheduler_config=self.scheduler_config,
                 cache_config=self.cache_config)
+            self.model_num_params = sum(1 for _ in self.model.parameters())
 
         self.model_memory_usage = m.consumed_memory
         logger.info("Loading model weights took %.4f GB",
@@ -1494,6 +1524,20 @@ def execute_model(
             logits=logits,
             sampling_metadata=model_input.sampling_metadata,
         )
+
+        if (model_input.input_positions is not None and model_input.sampling_metadata is not None):
+            seq_id = model_input.sampling_metadata.seq_groups[0].seq_ids[0]
+            input_tokens = (
+                model_input.sampling_metadata.seq_groups[0]
+                .seq_data[seq_id]
+                .get_prompt_token_ids()
+            )
+            output_tokens = (
+                model_input.sampling_metadata.seq_groups[0]
+                .seq_data[seq_id]
+                .get_output_token_ids()
+            )
+            output.powv = self.get_powv(VerifyChatCompletion(input_tokens=input_tokens, response_tokens=output_tokens, model=self.model_config.model))
         if (self.observability_config is not None
                 and self.observability_config.collect_model_forward_time
                 and output is not None):
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 0ff559a9af53e..b3316f0061b37 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -6,6 +6,7 @@
 import torch
 import torch.distributed
 
+from vllm.entrypoints.openai.protocol import VerifyChatCompletion
 import vllm.envs as envs
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
@@ -180,6 +181,9 @@ def init_device(self) -> None:
 
     def load_model(self):
         self.model_runner.load_model()
+
+    def verify_output(self, input: VerifyChatCompletion):
+        return self.model_runner.get_powv(input)
 
     def save_sharded_state(
         self,
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 6ba4f272315ce..546dcec9ceb68 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -9,6 +9,7 @@
 
 from vllm.config import ObservabilityConfig
 from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
+from vllm.entrypoints.openai.protocol import VerifyChatCompletion
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -70,6 +71,10 @@ def start_worker_execution_loop(self) -> None:
             if output is None:
                 return None
 
+    @abstractmethod
+    def verify_output(self, input: VerifyChatCompletion) -> int:
+        raise NotImplementedError
+
     @abstractmethod
     def execute_model(
         self,
@@ -448,7 +453,7 @@ def init_worker(self, *args, **kwargs):
 
         self.worker = worker_class(*args, **kwargs)
         assert self.worker is not None
-        
+
     def execute_method(self, method, *args, **kwargs):
         try:
             target = self if self.worker is None else self.worker
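
A minimal usage sketch of the new verify endpoint (not part of the patch itself). It assumes a server started as in the justfile's vllm recipe and reachable at http://localhost:8000, and that the powv value was captured from the powv field of a streamed /v1/chat/completions chunk; the messages, response text, and powv below are illustrative placeholders.

# verify_sketch.py -- illustrative only; values below are placeholders.
import requests

payload = {
    "model": "NousResearch/Meta-Llama-3.1-8B-Instruct",
    "messages": [{"role": "user", "content": "Name the capital of France."}],
    # Text and powv as captured from a streamed /v1/chat/completions response.
    "response": "The capital of France is Paris.",
    "powv": 1234,
    "version": "0.0.0",  # must match the endpoint's hard-coded version guard
}

r = requests.post("http://localhost:8000/v1/chat/completions/verify",
                  json=payload)
print(r.json())  # True when the recomputed powv matches, otherwise False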