Refactor to Nick's suggestion to use _cached_all_token_ids
alexm-redhat committed Sep 21, 2024
1 parent ec3feb2 commit 4876a2e
Showing 1 changed file with 16 additions and 16 deletions.
vllm/sequence.py (16 additions, 16 deletions)

@@ -169,10 +169,6 @@ class SequenceData(msgspec.Struct,
     # It is used to compute mrope_position_ids.
     _mrope_position_delta: Optional[int] = None
 
-    # Used to quickly access the last appended tokens between scheduler
-    # iterations
-    last_appended_tokens: List[int] = []
-
     def __post_init__(self) -> None:
         assert self._prompt_token_ids.typecode == "l"
         assert self._output_token_ids.typecode == "l"
@@ -236,8 +232,6 @@ def mrope_position_delta(self, new_mrope_position_delta):
         self._mrope_position_delta = new_mrope_position_delta
 
     def append_token_id(self, token_id: int, logprob: float) -> None:
-        self.last_appended_tokens.append(token_id)
-
         self._output_token_ids.append(token_id)
         self._new_appended_tokens.append(token_id)
         self._cached_all_token_ids.append(token_id)
@@ -415,7 +409,7 @@ def __init__(
         self.stop_reason: Union[int, str, None] = None
 
         # These are used to keep track of delta outputs
-        self._last_token_ids_offset: int = 0
+        self._last_output_token_ids_offset: int = 0
         self._last_output_text_offset: int = 0
 
         # Used for incremental detokenization
@@ -499,16 +493,22 @@ def get_output_token_ids_to_return(
         if not delta:
             return self.get_output_token_ids()
 
-        # Optimization for single decode token case
-        # (which is what we have most of the time)
-        if len(self.data.last_appended_tokens) == 1:
-            new_token = self.data.last_appended_tokens[0]
-            self.data.last_appended_tokens.clear()
-            return new_token
+        prompt_len = self.get_prompt_len()
+        output_len = self.get_output_len()
+
+        # Get the number of new tokens
+        output_last_offset = self._last_output_token_ids_offset
+        num_new_tokens = output_len - self._last_output_token_ids_offset
+        self._last_output_token_ids_offset = output_len
+
+        # Return new tokens
+        if num_new_tokens == 1:
+            # Optimization for single decode token case
+            # (which is what we have most of the time)
+            return self.data._cached_all_token_ids[-1]
         else:
-            new_tokens = self.data.last_appended_tokens
-            self.data.last_appended_tokens = []
-            return new_tokens
+            return self.data._cached_all_token_ids[prompt_len +
+                                                   output_last_offset:]
 
     def hash_of_block(self, logical_idx: int) -> int:
         # TODO This can produce incorrect hash when block size > prompt size
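
For readers skimming the diff: the refactor replaces the mutable last_appended_tokens list with an integer offset into _cached_all_token_ids. Below is a minimal, self-contained sketch of that offset mechanism, assuming a simplified stand-in class (Tracker is hypothetical; only the field names and the slicing logic mirror the diff):

from typing import List, Union


class Tracker:
    """Hypothetical stand-in for vLLM's Sequence/SequenceData pair."""

    def __init__(self, prompt_token_ids: List[int]) -> None:
        self._prompt_len = len(prompt_token_ids)
        # Mirrors _cached_all_token_ids: prompt tokens plus every
        # output token appended so far, in order.
        self._cached_all_token_ids: List[int] = list(prompt_token_ids)
        self._output_len = 0
        # Mirrors _last_output_token_ids_offset from the diff.
        self._last_output_token_ids_offset = 0

    def append_token_id(self, token_id: int) -> None:
        # No separate last_appended_tokens list to maintain or clear;
        # the shared cache is the single source of truth.
        self._cached_all_token_ids.append(token_id)
        self._output_len += 1

    def get_output_token_ids_to_return(self) -> Union[List[int], int]:
        offset = self._last_output_token_ids_offset
        num_new_tokens = self._output_len - offset
        self._last_output_token_ids_offset = self._output_len
        if num_new_tokens == 1:
            # Common decode path: exactly one new token since the last
            # call, so return it directly instead of a one-element list.
            return self._cached_all_token_ids[-1]
        return self._cached_all_token_ids[self._prompt_len + offset:]


tracker = Tracker([1, 2, 3])
tracker.append_token_id(10)
assert tracker.get_output_token_ids_to_return() == 10
tracker.append_token_id(11)
tracker.append_token_id(12)
assert tracker.get_output_token_ids_to_return() == [11, 12]

The appeal of this design is that _cached_all_token_ids already holds every prompt and output token, so tracking a single integer offset avoids maintaining (and clearing) a second per-step list, while the common one-new-token decode path still returns a bare int without allocating a slice.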
