Add function invoke call for underlying models (vllm-project#1)
gc-fu authored Oct 17, 2023
1 parent ac5cf86 commit 9895bbd
Showing 1 changed file with 95 additions and 0 deletions.
95 changes: 95 additions & 0 deletions tests/under_models/send_mock_request.py
@@ -0,0 +1,95 @@
"""Try sending a mocked request to the underlying model execute stage"""

import asyncio
from functools import partial
from typing import Any, List

import pytest

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceData, SequenceGroupMetadata

# Local path of the model that the workers will load.
MODEL_PATH = "/models/vicuna-7b/"


"""
1. Prepare a faked sequencegroup meta data
2. Start a mocked AsyncLLMEngine, and modify its step_async function
3. invoke the step_async function manually
"""

class UglyAsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods."""

async def step_async(self) -> List[RequestOutput]:
        # Fake sampling parameters and prompt token ids for a single request
        # (token id 1 is the BOS token for LLaMA-family tokenizers).
        sampling_params = SamplingParams(n=2, best_of=5, temperature=0.8,
                                         top_p=0.95, max_tokens=7)
        seq_data = {0: SequenceData(prompt_token_ids=[1, 3087, 8970, 338, 263])}
        request_id = "cmpl-7bef75eaa4394a3d895b5508dd5f69f6"

        # block_tables is left empty since this is only a mocked call.
        seq_group_meta_data = SequenceGroupMetadata(
            request_id=request_id, is_prompt=True, seq_data=seq_data,
            sampling_params=sampling_params, block_tables={})
        seq_group_meta_data_lists = [seq_group_meta_data]

        # Ask every worker to run execute_model on the mocked metadata; the
        # swap/copy maps are empty because no KV-cache blocks are involved.
        output = await self._run_workers_async(
            "execute_model",
            seq_group_metadata_list=seq_group_meta_data_lists,
            blocks_to_swap_in={},
            blocks_to_swap_out={},
            blocks_to_copy={},
        )

        # TODO: build the real RequestOutput from `output` instead of this stub.
        return RequestOutput(request_id=request_id, prompt="",
                             prompt_token_ids=[1, 3087, 8970, 338, 263],
                             outputs=[], finished=False)


async def _run_workers_async(
self,
method: str,
*args,
get_all_outputs: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
all_outputs = []
for worker in self.workers:
if self.parallel_config.worker_use_ray:
executor = partial(worker.execute_method.remote, method)
else:
executor = getattr(worker, method)

output = executor(*args, **kwargs)
all_outputs.append(output)

if self.parallel_config.worker_use_ray:
all_outputs = await asyncio.gather(*all_outputs)

if get_all_outputs:
return all_outputs

# Make sure all workers have the same results.
output = all_outputs[0]
for other_output in all_outputs[1:]:
assert output == other_output
return output

# Monkey-patch AsyncLLMEngine so that from_engine_args() builds our
# UglyAsyncLLMEngine instead of the default engine class.
setattr(AsyncLLMEngine, "_engine_class", UglyAsyncLLMEngine)


@pytest.mark.asyncio
async def test_model_execution():
    # Build the engine args (mostly vLLM defaults, with Ray disabled).
    engine_args = AsyncEngineArgs(
        model=MODEL_PATH, tokenizer=MODEL_PATH, tokenizer_mode='auto',
        trust_remote_code=False, download_dir=None, load_format='auto',
        dtype='auto', seed=0, max_model_len=None, worker_use_ray=False,
        pipeline_parallel_size=1, tensor_parallel_size=1, block_size=16,
        swap_space=16, gpu_memory_utilization=0.9,
        max_num_batched_tokens=None, max_num_seqs=256,
        disable_log_stats=False, revision=None, tokenizer_revision=None,
        quantization=None, engine_use_ray=False, disable_log_requests=True,
        max_log_len=None)
    # Start the engine.
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    # Give the background loop a moment to spin up, then drive one mocked
    # step manually.
    engine.start_background_loop()
    await asyncio.sleep(2)
    await engine.engine.step_async()
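
# A minimal sketch for running this file directly without pytest; assumes no
# asyncio event loop is already running in this process.
if __name__ == "__main__":
    asyncio.run(test_model_execution())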

