
Parallel sampling eviction #157

Merged 42 commits on Feb 2, 2024
Commits
e0ef4c6
add new model for evaluating logits over multiple queries using KV cache
masahi Jan 10, 2024
4ccbb27
add test
masahi Jan 10, 2024
f1314a5
clean
masahi Jan 10, 2024
2bee022
Only the number of past tokens is needed
masahi Jan 10, 2024
756b09f
fix build
masahi Jan 10, 2024
09ef5b3
fix
masahi Jan 11, 2024
7b67ba4
correctly handle num_past_tokens > sliding_window case
masahi Jan 11, 2024
e0517fd
wip
masahi Jan 10, 2024
cf89a5b
blac
masahi Jan 10, 2024
9ca4806
wip
masahi Jan 10, 2024
4541b4d
wip
masahi Jan 10, 2024
5d376d2
remove cancel call back in eviction
masahi Jan 10, 2024
59c36cc
Create MultiQueryDecodeRequest
masahi Jan 10, 2024
f58acf7
only the number of past tokens is needed
masahi Jan 10, 2024
d9dd2ca
wip
masahi Jan 10, 2024
cb11761
wip
masahi Jan 10, 2024
24f7bfa
wip
masahi Jan 10, 2024
34da221
fix
masahi Jan 10, 2024
d94e9d8
wip
masahi Jan 10, 2024
4a3bb77
wip
masahi Jan 11, 2024
0c6875e
wip
masahi Jan 11, 2024
a46abe1
wip
masahi Jan 11, 2024
c80bea2
working?
masahi Jan 11, 2024
18239a4
remove dbg print
masahi Jan 11, 2024
fd2b2bd
multi gpu works
masahi Jan 11, 2024
6ac292b
fixed sliding window logic
masahi Jan 11, 2024
2f9d1f7
remove dbug print
masahi Jan 11, 2024
3a9f6d6
clean and fix
masahi Jan 11, 2024
9fb9261
mypy
masahi Jan 11, 2024
906b23b
generate signature update
masahi Jan 13, 2024
2c1aa04
Merge branch 'batch-serving' into parallel-sampling-eviction
masahi Jan 13, 2024
b197e71
more
masahi Jan 13, 2024
2dfa28d
fix mypy
masahi Jan 13, 2024
e287c5f
fix
masahi Jan 13, 2024
417750c
Merge branch 'batch-serving' into parallel-sampling-eviction
masahi Jan 31, 2024
c925c52
fix
masahi Jan 31, 2024
a4d6e01
mypy fix
masahi Jan 31, 2024
7360392
Merge branch 'batch-serving' into parallel-sampling-eviction
masahi Feb 1, 2024
5dbf73e
refactor
masahi Feb 1, 2024
78a6f77
fix
masahi Feb 1, 2024
9189697
rename
masahi Feb 1, 2024
d4fe2d7
Disallow preempting when a request has generated more than max_num_ba…
masahi Feb 1, 2024
4 changes: 2 additions & 2 deletions serve/mlc_serve/engine/base.py
@@ -19,8 +19,8 @@
class RawLogprobsInfo:
current_token_id: int
current_logprob: float
top_token_ids: Optional[np.array]
top_logprobs: Optional[np.array]
top_token_ids: Optional[np.ndarray]
top_logprobs: Optional[np.ndarray]

RawLogprobsInfos = List[Optional[RawLogprobsInfo]]

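For context on this hunk: np.array is a constructor function rather than a type, so it is not a valid static annotation; np.ndarray is the array type that mypy expects. A minimal standalone sketch of the corrected dataclass, assuming only the fields visible above:

```python
# Minimal sketch of the corrected annotations; only the fields shown in the
# hunk above are included here.
from dataclasses import dataclass
from typing import List, Optional

import numpy as np


@dataclass
class RawLogprobsInfo:
    current_token_id: int
    current_logprob: float
    top_token_ids: Optional[np.ndarray]  # np.ndarray is a type; np.array is a function
    top_logprobs: Optional[np.ndarray]


RawLogprobsInfos = List[Optional[RawLogprobsInfo]]
```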
119 changes: 94 additions & 25 deletions serve/mlc_serve/engine/engine_common.py
@@ -22,6 +22,8 @@
from .model_module import (
DecodeRequest,
PrefillRequest,
EvalMultiQueryRequest,
EvictedTokens,
ConversationTemplate,
KVCacheManager,
ModelModule,
@@ -226,26 +228,70 @@ def update_sequence(

def get_requests_to_process(
current_states: list[RequestState], cache_manager: KVCacheManager
) -> Tuple[list[Union[PrefillRequest, DecodeRequest]], bool, int]:
requests: list[Union[PrefillRequest, DecodeRequest]] = []
) -> Tuple[
list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]], bool, int
]:
requests: list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]] = []
# TODO: consider having hybrid batch if the underlying attention kernel supports
# mixing prefill and decode.
is_prompt_batch = any(not state.is_prefilled for state in current_states)

token_counts = 0

is_evicted_parallel_sampling_request = (
lambda state: not state.is_prefilled
and state.num_sequences > 1
and any(
len(gen_seq.generated_token_ids) > 0
for gen_seq in state.generation_sequences
)
)

if is_prompt_batch:
for state in current_states:
if not state.is_prefilled:
if is_evicted_parallel_sampling_request(state):
requests.append(
PrefillRequest(
request_id=state.request_id,
token_ids=state.prompt_token_ids,
num_sequence=state.num_sequences,
sampling_params=state.sampling_params,
)
)

token_counts += len(state.prompt_token_ids)

for gen_seq in state.generation_sequences:
requests.append(
EvalMultiQueryRequest(
sequence_id=gen_seq.seq_id,
num_past_tokens=state.prompt_len,
queries=EvictedTokens(gen_seq.generated_token_ids),
sampling_params=state.sampling_params,
)
)
cache_manager.extend(
gen_seq.seq_id,
len(gen_seq.generated_token_ids) + 1,
)

# TODO(masahi): How to account for token counts in EvalMultiQueryRequest in
# Prometheus metric?
elif not state.is_prefilled:
token_ids = state.prompt_token_ids
# generated_token_ids is added for the case where the request is
# recovering from cache eviction.

if (
state.num_sequences == 1
and state.generation_sequences[0].generated_token_ids
):
token_ids += state.generation_sequences[0].generated_token_ids

requests.append(
# generated_token_ids is added for the case where the request is
# recovering from cache eviction.
# TODO(masahi): This needs an update when we support evicting
# a parallel-sampling request.
PrefillRequest(
request_id=state.request_id,
token_ids=state.prompt_token_ids
+ state.generation_sequences[0].generated_token_ids,
token_ids=token_ids,
num_sequence=state.num_sequences,
sampling_params=state.sampling_params,
)
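To summarize the restore path added above: when an evicted parallel-sampling request re-enters the prompt batch, its shared prompt is prefilled once via a PrefillRequest, and each generation sequence's previously generated tokens are re-evaluated through an EvalMultiQueryRequest wrapping EvictedTokens. The sketch below is a hypothetical condensation of that branch, not code from the PR; the standalone-function shape and the absolute import path are assumptions.

```python
from typing import Any, List

# Assumes the serve/ directory is on PYTHONPATH; the helper itself is
# illustrative and does not exist in the PR.
from mlc_serve.engine.model_module import (
    EvalMultiQueryRequest,
    EvictedTokens,
    PrefillRequest,
)


def rebuild_evicted_parallel_request(state: Any, cache_manager: Any) -> List[Any]:
    """Rebuild the requests for one evicted parallel-sampling request state."""
    requests: List[Any] = [
        # The shared prompt is prefilled once for all sequences.
        PrefillRequest(
            request_id=state.request_id,
            token_ids=state.prompt_token_ids,
            num_sequence=state.num_sequences,
            sampling_params=state.sampling_params,
        )
    ]
    for gen_seq in state.generation_sequences:
        # Each sequence re-evaluates its already-generated tokens against the
        # prompt's KV cache entries via a multi-query evaluation.
        requests.append(
            EvalMultiQueryRequest(
                sequence_id=gen_seq.seq_id,
                num_past_tokens=state.prompt_len,
                queries=EvictedTokens(gen_seq.generated_token_ids),
                sampling_params=state.sampling_params,
            )
        )
        # Reserve KV slots for the re-evaluated tokens plus one for the next token.
        cache_manager.extend(gen_seq.seq_id, len(gen_seq.generated_token_ids) + 1)
    return requests
```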
@@ -392,16 +438,28 @@ def evict_request(self, cancell_callback: Callable[[RequestId], None]) -> int:
candidate_victims = parallel_sample_requests

request_to_remove = min(candidate_victims, key=lambda s: s.num_total_tokens)

# TODO(masahi): Properly support evicting a multi-sequence request
if self.current_batch[request_to_remove.request_id].num_sequences != 1:
cancell_callback(request_to_remove.request_id)
self.remove_request_from_batch(request_to_remove.request_id)
LOG.warn(
"Preempting a multi-sequence request is currently not supported,"
f" cancelling request '{request_to_remove.request_id}'",
victim_state = self.current_batch[request_to_remove.request_id]

if victim_state.num_sequences != 1:
prev_generated_token_counts = sum(
[
len(gen_seq.generated_token_ids)
for gen_seq in victim_state.generation_sequences
]
)
continue
# We could allow evicting and restoring a parallel-sampling request whose prev_generated_token_counts
# is > max_num_batched_tokens, by making the model split a list of EvalMultiQuery requests into parts,
# so that an inference on each part can be done with the max_num_batched_tokens budget.
# But this introduces an undesirable coupling between the engine and the model.
if prev_generated_token_counts >= self.max_num_batched_tokens:
cancell_callback(request_to_remove.request_id)
self.remove_request_from_batch(request_to_remove.request_id)
LOG.warn(
f"Cancelling a parallel-sampling request '{request_to_remove.request_id}'"
f"since it has generated more than {self.max_num_batched_tokens} tokens in total"
"and currently we do not support preempting such request.",
)
continue
Comment from @masahi (Member, Author), Feb 1, 2024:
@sunggg @elvin-n Please be aware of this limitation. Because of it, there is still a case where a parallel-sampling request is cancelled rather than preempted.

In general, we don't have a good solution for preempting a request that has generated more than max_num_batched_tokens tokens. See also #163. The easiest solution would be to stop generation at max_num_batched_tokens, but then we cannot support "unlimited" generation.
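For readers skimming the thread, the decision above comes down to a single comparison. The helper below is a hypothetical restatement with made-up numbers, not code from this PR:

```python
def can_preempt_parallel_request(
    prev_generated_token_counts: int, max_num_batched_tokens: int
) -> bool:
    # Restoring the request later would re-evaluate all previously generated
    # tokens as EvalMultiQuery work, which must fit within a single
    # max_num_batched_tokens budget; otherwise the request is cancelled
    # instead of preempted.
    return prev_generated_token_counts < max_num_batched_tokens


assert can_preempt_parallel_request(3000, 4096)      # evicted, restored later
assert not can_preempt_parallel_request(5000, 4096)  # cancelled
```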


self.remove_request_from_batch(request_to_remove.request_id)
request_to_remove.is_prefilled = False
@@ -446,14 +504,27 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
gen_seq.next_start_position = (
num_new_batched_tokens
) = num_tokens = self.max_num_batched_tokens

num_kv_slots_needed = min(num_tokens, self.model_context_window_size)
else:
# Evicting and recovering multi-sequence requests is not supported for now.
assert all(
gen_seq.next_start_position == state.prompt_len
for gen_seq in state.generation_sequences
prev_generated_token_counts = sum(
[
len(gen_seq.generated_token_ids)
for gen_seq in state.generation_sequences
]
)

# Restoring an evicted parallel-sampling request with sliding-window attention is
# difficult to reason about, so we use crude upper bounds below for now.
num_tokens = state.prompt_len
num_new_batched_tokens += num_tokens
num_kv_slots_needed = state.prompt_len + prev_generated_token_counts
# Restoring an evicted parallel-sampling request is done by separate
# Prefill and MultiQuery requests. The maximum below is an upper bound on the
# batch size increase due to this request.
# TODO(masahi): Prefill and EvalMultiQuery requests are handled separately by the model.
# So comparing the sum of their batched token counts against max_num_batched_tokens
# is not optimal.
num_new_batched_tokens += max(state.prompt_len, prev_generated_token_counts)

if num_new_batched_tokens > self.max_num_batched_tokens:
LOG.debug(
@@ -465,7 +536,6 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
# We make sure that the KV cache will have enough free space for this request to proceed
# decoding for at least self.max_decode_steps steps.
# See the comment in check_prompt_too_long for the optimization involving the window size.
num_kv_slots_needed = min(num_tokens, self.model_context_window_size)
if (self.cache_manager.get_free_space() - num_kv_slots_needed) / (
len(self.current_batch) + 1
) < self.max_decode_steps * state.num_sequences:
@@ -477,7 +547,6 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
return None

self.queue.popleft()
# TODO parallel sampling: Need update here when evicting multi-sequence requests is supported.
self.cache_manager.allocate(state.request_id, num_tokens, state.num_sequences)
self.current_batch[state.request_id] = state

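The try_grow_batch hunk above uses deliberately crude upper bounds when re-admitting an evicted parallel-sampling request. A worked example with hypothetical sizes (not real defaults):

```python
# Hypothetical sizes for one restored parallel-sampling request.
prompt_len = 512
generated_counts = [120, 80, 40]  # tokens already generated by each sequence

prev_generated_token_counts = sum(generated_counts)             # 240

# KV slots that must be free before the request can re-enter the batch.
num_kv_slots_needed = prompt_len + prev_generated_token_counts  # 752

# The Prefill part covers the prompt and the EvalMultiQuery parts cover the
# generated tokens; since the model runs them separately, the larger of the
# two is taken as the upper bound on the batched-token increase.
batched_token_increase = max(prompt_len, prev_generated_token_counts)  # 512
```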
30 changes: 28 additions & 2 deletions serve/mlc_serve/engine/model_module.py
@@ -35,11 +35,37 @@ class PrefillRequest:
class DecodeRequest:
sequence_id: SequenceId
prompt_token_counts: int
# All tokens for this request, including prompt
# Decoded tokens for this sequence
token_ids: List[int]
sampling_params: SamplingParams


@dataclass
class DraftTokens:
token_ids: List[int]

@property
def num_tokens(self):
return len(self.token_ids)


@dataclass
class EvictedTokens:
token_ids: List[int]

@property
def num_tokens(self):
return len(self.token_ids)


@dataclass
class EvalMultiQueryRequest:
sequence_id: SequenceId
num_past_tokens: int
queries: Union[DraftTokens, EvictedTokens]
sampling_params: SamplingParams


@dataclass
class TextGenerationResult:
"""
@@ -125,7 +151,7 @@ class TextGenerator(Protocol):

def generate(
self,
requests: Sequence[Union[PrefillRequest, DecodeRequest]],
requests: Sequence[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]],
kv_cache,
) -> List[TextGenerationResult]:
"""
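A small usage sketch of the new dataclasses (the absolute import path is an assumption, and the token values are made up). Both query variants expose the same num_tokens property, which is presumably what lets EvalMultiQueryRequest.queries accept either draft tokens or eviction-recovered tokens uniformly:

```python
from mlc_serve.engine.model_module import DraftTokens, EvictedTokens

draft = DraftTokens(token_ids=[5, 6])            # e.g. speculative draft tokens
evicted = EvictedTokens(token_ids=[11, 12, 13])  # tokens lost to cache eviction

# Shared interface: both report how many query tokens they carry.
print(draft.num_tokens)    # 2
print(evicted.num_tokens)  # 3
```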