good performance
alexm-redhat committed Sep 10, 2024
1 parent 1b068cb commit eebf5a8
Showing 6 changed files with 137 additions and 119 deletions.
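
Taken together, the diff moves the multiprocessing frontend to delta outputs: the engine sends one full `RequestOutput` per request up front, then per-step deltas (`is_delta=True`) carrying only the newly appended text, token ids, and finish state, and the client folds each delta into its cached copy before yielding it. A minimal standalone sketch of that accumulation scheme, using simplified stand-in classes rather than vLLM's actual `RequestOutput`/`CompletionOutput`:

```python
# Sketch only: toy stand-ins for vLLM's output types, to show the delta idea.
from dataclasses import dataclass, field
from typing import List, Union

@dataclass
class Output:
    text: str = ""
    # Wire convention from the diff: a bare int when one token was appended,
    # a list for several, None when no new tokens arrived this step.
    token_ids: Union[int, List[int], None] = None

@dataclass
class StepOutput:
    outputs: List[Output] = field(default_factory=list)
    finished: bool = False
    is_delta: bool = False

def merge_delta(prev: StepOutput, delta: StepOutput) -> StepOutput:
    """Fold a per-step delta into the previously received full output."""
    assert delta.is_delta
    assert len(prev.outputs) == len(delta.outputs)
    for p, d in zip(prev.outputs, delta.outputs):
        p.text += d.text                      # text only ever grows
        if d.token_ids is not None:           # p.token_ids is a list on full outputs
            if isinstance(d.token_ids, list):
                p.token_ids.extend(d.token_ids)
            else:
                p.token_ids.append(d.token_ids)
    prev.finished = delta.finished
    return prev
```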
6 changes: 3 additions & 3 deletions tests/entrypoints/openai/test_accuracy.py
@@ -19,7 +19,6 @@
RTOL = 0.03
EXPECTED_VALUE = 0.58


# @pytest.fixture(scope="module")
# def server():
# args = [
@@ -30,11 +29,12 @@
# with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
# yield remote_server


@pytest.fixture(scope="module")
def server():
args = [
"--max-model-len", "4096",
"--disable-log-requests", "--num-scheduler-steps", "8", "--multi-step-stream-outputs"
"--max-model-len", "4096", "--disable-log-requests",
"--num-scheduler-steps", "8", "--multi-step-stream-outputs"
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
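The rewritten fixture starts the server with multi-step scheduling and streamed outputs (`--num-scheduler-steps 8 --multi-step-stream-outputs`), the configuration the delta path targets. A usage sketch, assuming a server already running with those flags (the port and model name are placeholders, not taken from this commit):

```python
# Sketch: stream completions from a vLLM OpenAI-compatible server that was
# launched with --num-scheduler-steps 8 --multi-step-stream-outputs.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
stream = client.completions.create(
    model="placeholder-model",   # substitute the served model name
    prompt="The capital of France is",
    max_tokens=32,
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].text, end="", flush=True)
```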
2 changes: 1 addition & 1 deletion vllm/engine/llm_engine.py
@@ -36,7 +36,7 @@
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory, request_output_builder)
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
57 changes: 33 additions & 24 deletions vllm/engine/multiprocessing/client.py
@@ -403,28 +403,39 @@ def is_stopped(self) -> bool:
def errored(self) -> bool:
return self._errored

def construct_from_delta(self, prev_request_output: RequestOutput,
delta_request_output: RequestOutput):
prev_output = prev_request_output.outputs[0]
delta_output = delta_request_output.outputs[0]

# Handle text
prev_output.text += delta_output.text

# Handle output tokens
if delta_output.token_ids:
if isinstance(delta_output.token_ids, list):
prev_output.token_ids.extend(delta_output.token_ids)
else:
prev_output.token_ids.append(delta_output.token_ids)
def update_request_output_from_delta(self,
prev_request_output: RequestOutput,
delta_request_output: RequestOutput):
# Sanity checks
assert delta_request_output.is_delta
assert len(prev_request_output.outputs) == len(
delta_request_output.outputs)

for i, prev_output in enumerate(prev_request_output.outputs):
delta_output = delta_request_output.outputs[i]

# Update prev_output

# Handle text
prev_output.text += delta_output.text

# Handle output tokens
if delta_output.token_ids:
if isinstance(delta_output.token_ids, list):
prev_output.token_ids.extend(delta_output.token_ids)
else:
prev_output.token_ids.append(delta_output.token_ids)

if prev_output.cumulative_logprob:
prev_output.cumulative_logprob = delta_output.cumulative_logprob

assert prev_output.cumulative_logprob is None
assert prev_output.logprobs is None
if prev_output.logprobs:
prev_output.logprobs.extend(delta_output.logprobs)

prev_output.finish_reason = delta_output.finish_reason
prev_output.stop_reason = delta_output.stop_reason
assert prev_output.lora_request is None
prev_output.finish_reason = delta_output.finish_reason
prev_output.stop_reason = delta_output.stop_reason

# Update prev_request_output
prev_request_output.finished = delta_request_output.finished
prev_request_output.metrics = delta_request_output.metrics

@@ -471,14 +482,12 @@ async def generate(
if isinstance(request_output, BaseException):
raise request_output

if prev_request_output:
prev_request_output = self.construct_from_delta(
if prev_request_output and request_output.is_delta:
prev_request_output = self.update_request_output_from_delta(
prev_request_output, request_output)
else:
assert not request_output.is_delta
prev_request_output = request_output
# TODO: Fix
prev_request_output.outputs[0].token_ids = list(
prev_request_output.outputs[0].token_ids)

finished = prev_request_output.finished
yield prev_request_output
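The `generate()` loop above distinguishes full outputs from deltas via `is_delta`: a request's first output arrives whole and is cached, and each later step arrives as a delta that is folded into the cache before being yielded, so callers always observe cumulative state. A sketch of that consumption pattern, with `recv_output` and `merge_delta` as hypothetical callables standing in for the socket read and the merge method above:

```python
def stream_outputs(recv_output, merge_delta):
    """Yield cumulative outputs from a stream of one full output plus deltas.

    Sketch only: recv_output() returns the next output off the wire;
    merge_delta(prev, delta) folds a delta into the cached full output.
    """
    prev = None
    while True:
        out = recv_output()
        if prev is not None and out.is_delta:
            prev = merge_delta(prev, out)
        else:
            assert not out.is_delta   # a request's first message is always full
            prev = out
        yield prev                    # consumers always see accumulated state
        if prev.finished:
            break
```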
4 changes: 3 additions & 1 deletion vllm/engine/multiprocessing/engine.py
@@ -65,7 +65,9 @@ def __init__(self,
# (Make optional if necessary)
use_delta_outputs = True

self.engine = LLMEngine(*args, **kwargs, use_delta_outputs=use_delta_outputs)
self.engine = LLMEngine(*args,
**kwargs,
use_delta_outputs=use_delta_outputs)
self.log_requests = log_requests

self.use_async_sockets = use_async_sockets
183 changes: 95 additions & 88 deletions vllm/outputs.py
@@ -8,8 +8,6 @@
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceStatus)

from vllm.utils import PyObjectCache


@dataclass
class CompletionOutput:
@@ -68,23 +66,6 @@ def __repr__(self) -> str:
f"embedding={len(self.embedding)})")


def request_output_builder():
outputs = [
CompletionOutput(index=0,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None,
finish_reason=None)
]
return RequestOutput(request_id="",
prompt=None,
prompt_token_ids=[],
prompt_logprobs=None,
outputs=outputs,
finished=False)


class RequestOutput:
"""The output data of a completion request to the LLM.
@@ -130,98 +111,87 @@ def __init__(
self.lora_request = lora_request
self.encoder_prompt = encoder_prompt
self.encoder_prompt_token_ids = encoder_prompt_token_ids
self.is_delta = False

@classmethod
def delta_from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
assert seq_group.sampling_params is not None
seqs = seq_group.get_seqs()
assert len(seqs) == 1
seq = seqs[0]

# TODO: Fix 1 assumption
include_logprobs = seq_group.sampling_params.logprobs is not None
text_buffer_length = seq_group.sampling_params.output_text_buffer_length

prev_request_output = seq_group.request_output
delta_request_output = seq_group.delta_request_output

include_logprobs = seq_group.sampling_params.logprobs is not None
text_buffer_length = seq_group.sampling_params.output_text_buffer_length
num_seqs = len(seq_group.seqs)
assert len(prev_request_output.outputs) == num_seqs
assert len(delta_request_output.outputs) == num_seqs

for i in range(num_seqs):
seq = seq_group.seqs[i]

# TODO: Fix this
assert not include_logprobs
# assert text_buffer_length == 0, "text_buffer_length = {}".format(text_buffer_length)
prev_output: CompletionOutput = prev_request_output.outputs[i]
delta_output: CompletionOutput = delta_request_output.outputs[i]
assert prev_output.index == i and delta_output.index == i

# print("seq_group.request_id = {}".format(seq_group.request_id))
# print(" text_buffer_length = {}".format(text_buffer_length))
# print(" output.text = {}".format(seq.get_output_text_to_return(text_buffer_length)))
# print(" len(seq.data._output_token_ids) = {}".format(len(seq.data._output_token_ids)))
# print(" output.finish_reason = {}".format(SequenceStatus.get_finished_reason(seq.status)))
# print(" output.stop_reason = {}".format(seq.stop_reason))
# Update delta_output

# Init CompletionOutput
prev_output: CompletionOutput = prev_request_output.outputs[0]
delta_output: CompletionOutput = delta_request_output.outputs[0]
# Handle text
prev_text = prev_output.text
cur_text = seq.get_output_text_to_return(text_buffer_length)
delta_text = cur_text[len(prev_text):]
delta_output.text = delta_text

assert prev_output.index == 0
assert delta_output.index == 0
# Handle output tokens
if seq.data.last_appended_tokens:
if len(seq.data.last_appended_tokens) == 1:
delta_output.token_ids = seq.data.last_appended_tokens[0]
else:
delta_output.token_ids = seq.data.last_appended_tokens
seq.data.last_appended_tokens.clear()
else:
delta_output.token_ids = None

# Handle text
prev_text = prev_output.text
cur_text = seq.get_output_text_to_return(text_buffer_length)
delta_text = cur_text[len(prev_text):]
delta_output.cumulative_logprob = seq.get_cumulative_logprob(
) if include_logprobs else None

delta_output.text = delta_text
# Handle logprobs
delta_output.logprobs = seq.output_logprobs[
len(prev_output.logprobs):] if include_logprobs else None

# Handle tokens
if seq.data.last_appended_tokens:
if len(seq.data.last_appended_tokens) == 1:
delta_output.token_ids = seq.data.last_appended_tokens[0]
else:
delta_output.token_ids = seq.data.last_appended_tokens
seq.data.last_appended_tokens.clear()
else:
delta_output.token_ids = None
delta_output.finish_reason = SequenceStatus.get_finished_reason(
seq.status)
delta_output.stop_reason = seq.stop_reason

delta_output.cumulative_logprob = seq.get_cumulative_logprob(
) if include_logprobs else None
# Update prev_output
prev_output.text = cur_text
prev_output.token_ids = seq.data._output_token_ids
prev_output.cumulative_logprob = delta_output.cumulative_logprob

# TODO: Fix this
delta_output.logprobs = seq.output_logprobs if include_logprobs else None
prev_output.logprobs = seq.output_logprobs if include_logprobs else None

delta_output.finish_reason = SequenceStatus.get_finished_reason(
seq.status)
delta_output.stop_reason = seq.stop_reason
delta_output.lora_request = None
prev_output.finish_reason = SequenceStatus.get_finished_reason(
seq.status)
prev_output.stop_reason = seq.stop_reason

# Update finish time
finished_time = time.time() if seq_group.is_finished() else None
seq_group.set_finished_time(finished_time)

# Init delta_request_output
# Update delta_request_output
assert delta_request_output.request_id == seq_group.request_id
delta_request_output.prompt = ""
delta_request_output.prompt_token_ids.clear()
if delta_request_output.prompt_logprobs is not None:
delta_request_output.prompt_logprobs.clear()

# Update finish state
delta_request_output.finished = seq_group.is_finished()
delta_request_output.metrics = seq_group.metrics
delta_request_output.lora_request = seq_group.lora_request
delta_request_output.encoder_prompt = ""
if delta_request_output.encoder_prompt_token_ids is not None:
delta_request_output.encoder_prompt_token_ids.clear()

# Update request_output
prev_output.text = cur_text
prev_output.token_ids = seq.data._output_token_ids
prev_output.cumulative_logprob = seq.get_cumulative_logprob(
) if include_logprobs else None

prev_output.logprobs = seq.output_logprobs if include_logprobs else None

prev_output.finish_reason = SequenceStatus.get_finished_reason(
seq.status)
prev_output.stop_reason = seq.stop_reason
prev_output.lora_request = None
# Take metrics only when finished
# (TODO: Allow metrics with a delta)
if delta_request_output.finished:
delta_request_output.metrics = seq_group.metrics
else:
assert delta_request_output.metrics is None

# Update prev_request_output
prev_request_output.finished = seq_group.is_finished()
prev_request_output.metrics = seq_group.metrics

@@ -283,6 +253,44 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
encoder_prompt=encoder_prompt,
encoder_prompt_token_ids=encoder_prompt_token_ids)

@classmethod
def init_seq_group_request_output(cls, seq_group: SequenceGroup,
request_output: "RequestOutput") -> None:
assert seq_group.request_output is None
seq_group.request_output = request_output

# Create delta request output
num_seqs = len(seq_group.seqs)
outputs = []
for i in range(num_seqs):
outputs.append(
CompletionOutput(index=i,
text="",
token_ids=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None))

delta_request_output = RequestOutput(
request_id=request_output.request_id,
prompt=None,
prompt_token_ids=[],
prompt_logprobs=None,
outputs=outputs,
finished=False)

# Mark as delta request (so the client is aware)
delta_request_output.is_delta = True

assert seq_group.delta_request_output is None
seq_group.delta_request_output = delta_request_output

# Ensure we use list for output tokens
# (More efficient for delta transfers)
for output in seq_group.request_output.outputs:
output.token_ids = list(output.token_ids)

def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, "
f"prompt={self.prompt!r}, "
@@ -355,11 +363,10 @@ def create(seq_group: SequenceGroup, use_delta_outputs: bool = False):
if seq_group.request_output:
return RequestOutput.delta_from_seq_group(seq_group)
else:
seq_group.request_output = RequestOutput.from_seq_group(
seq_group)
seq_group.delta_request_output = RequestOutput.from_seq_group(
seq_group)
return seq_group.request_output
request_output = RequestOutput.from_seq_group(seq_group)
RequestOutput.init_seq_group_request_output(
seq_group, request_output)
return request_output

else:
return RequestOutput.from_seq_group(seq_group)
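
On the engine side, `delta_from_seq_group` derives each step's text delta by slicing off the prefix already sent, and drains `seq.data.last_appended_tokens` for the token delta. A minimal sketch of those two computations, assuming append-only output text; the copy-before-clear in `drain_token_delta` is this sketch's choice, keeping the returned list independent of the reused buffer:

```python
from typing import List, Union

def compute_text_delta(prev_text: str, cur_text: str) -> str:
    # Output text only ever grows, so the delta is the suffix past
    # whatever was already sent to the client.
    return cur_text[len(prev_text):]

def drain_token_delta(buffer: List[int]) -> Union[int, List[int], None]:
    # Mirrors the scalar-or-list wire convention used in the diff.
    if not buffer:
        return None                 # no tokens appended this step
    delta = buffer[0] if len(buffer) == 1 else list(buffer)
    buffer.clear()                  # reset the staging buffer for the next step
    return delta
```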
4 changes: 2 additions & 2 deletions vllm/sequence.py
@@ -151,7 +151,7 @@ class SequenceData(msgspec.Struct,
_prompt_token_ids: array
_output_token_ids: array = msgspec.field(
default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

### The below fields should not be passed as an argument ###
_cumulative_logprob: float = 0.0
_prompt_token_ids_tuple: Tuple[int,
@@ -224,7 +224,7 @@ def output_token_ids_array(self) -> array:

def append_token_id(self, token_id: int, logprob: float) -> None:
self.last_appended_tokens.append(token_id)

self._output_token_ids.append(token_id)
self._new_appended_tokens.append(token_id)
self._cached_all_token_ids.append(token_id)
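`append_token_id` above stages every generated token in `last_appended_tokens`, the buffer that `delta_from_seq_group` drains once per scheduler step. A toy producer/consumer view of that relationship (names are illustrative, not vLLM's):

```python
# Toy illustration of the per-step staging buffer.
staged = []                       # plays the role of last_appended_tokens

def on_token(token_id: int) -> None:
    staged.append(token_id)       # the sequence appends as tokens are decoded

def emit_delta() -> list:
    delta = list(staged)          # snapshot this step's tokens for the delta
    staged.clear()                # empty the buffer for the next step
    return delta
```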
