Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions vllm_ascend/patch_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.
#

import time
from typing import Dict, Optional

from vllm.outputs import CompletionOutput, RequestOutput
Expand Down Expand Up @@ -84,7 +85,7 @@ def from_seq_group(
output_token_ids = seq.get_output_token_ids_to_return(delta)
num_output_tokens = 1 if isinstance(output_token_ids,
int) else len(output_token_ids)
num_cached_tokens = seq.data.get_num_cached_tokens() # noqa
num_cached_tokens = seq.data.get_num_cached_tokens()

output_logprobs = seq.output_logprobs if include_logprobs else None

Expand Down Expand Up @@ -139,7 +140,44 @@ def from_seq_group(

outputs.append(output)

return None
# Every sequence in the sequence group should have the same prompt.
if include_prompt:
prompt = seq_group.prompt
prompt_token_ids = seq_group.prompt_token_ids
encoder_prompt = seq_group.encoder_prompt
encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
prompt_logprobs = seq_group.prompt_logprobs
else:
prompt = None
prompt_token_ids = None
encoder_prompt = None
encoder_prompt_token_ids = None
prompt_logprobs = None
finished_time = time.time() if finished else None
seq_group.set_finished_time(finished_time)

init_kwargs = {
"request_id": seq_group.request_id,
"prompt": prompt,
"prompt_token_ids": prompt_token_ids,
"prompt_logprobs": prompt_logprobs,
"outputs": outputs,
"finished": finished,
"metrics": seq_group.metrics,
"lora_request": seq_group.lora_request,
"encoder_prompt": encoder_prompt,
"encoder_prompt_token_ids": encoder_prompt_token_ids,
"num_cached_tokens": num_cached_tokens,
"multi_modal_placeholders": seq_group.multi_modal_placeholders
}

if use_cache:
request_output = seq_group.cached_request_output
request_output.__init__(**init_kwargs) # type: ignore
else:
request_output = cls(**init_kwargs) # type: ignore

return request_output


# Add code to clear finished seq in seq_id_to_seq_group
Expand Down