[Serving] Hybrid prefill #2604

Merged 2 commits on Jun 25, 2024
62 changes: 56 additions & 6 deletions cpp/serve/engine_actions/batch_prefill_base.cc
@@ -36,6 +36,21 @@ BatchPrefillBaseActionObj::BatchPrefillBaseActionObj(Array<Model> models,
*/
std::vector<BatchPrefillBaseActionObj::PrefillInput>
BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) {
// Preempt request state entries when decode cannot apply.
std::vector<RequestStateEntry> running_rsentries;
{
NVTXScopedRange nvtx_scope("BatchDecode getting requests");
running_rsentries = GetRunningRequestStateEntries(estate);
while (!(running_rsentries.size() <= models_[0]->GetNumAvailablePages())) {
if (estate->prefix_cache->TryFreeMemory()) continue;
RequestStateEntry preempted =
PreemptLastRunningRequestStateEntry(estate, models_, NullOpt, trace_recorder_);
if (preempted.same_as(running_rsentries.back())) {
running_rsentries.pop_back();
}
}
}

if (estate->waiting_queue.empty()) {
// No request to prefill.
return {};
@@ -44,13 +59,20 @@ BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) {
std::vector<std::vector<PrefillInput>> prefill_inputs_for_all_models;
prefill_inputs_for_all_models.reserve(models_.size());

int num_decode_inputs = static_cast<int>(running_rsentries.size());

// We first collect the inputs that can be prefilled for each model.
// Then we make a reduction to return the maximum common inputs.
for (int i = 0; i < static_cast<int>(models_.size()); ++i) {
std::vector<PrefillInput> prefill_inputs;
// - Try to prefill pending requests.
// - Try to prefill pending requests, in addition to reserved decode requests.
int total_input_length = 0;
int total_required_pages = 0;
int total_required_pages = num_decode_inputs;
// Reserve decode requests first.
for (const RequestStateEntry& rsentry : running_rsentries) {
prefill_inputs.push_back({rsentry, rsentry->mstates[i]->num_tokens_for_next_decode, 0});
total_input_length += rsentry->mstates[i]->num_tokens_for_next_decode;
}
int num_available_pages = models_[i]->GetNumAvailablePages();
int num_running_rsentries = GetRunningRequestStateEntries(estate).size();
int current_total_seq_len = models_[i]->GetCurrentTotalSequenceLength();
@@ -177,7 +199,8 @@ BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) {
std::min(num_prefill_inputs, static_cast<int>(prefill_inputs_for_all_models[i].size()));
}

if (num_prefill_inputs == 0) {
// If all inputs are decode inputs (no new prefill inputs could be added), skip the prefill action.
if (num_prefill_inputs == num_decode_inputs) {
return {};
}

@@ -259,6 +282,17 @@ bool BatchPrefillBaseActionObj::CanPrefill(EngineState estate, int num_prefill_r
std::pair<Array<Data>, int> BatchPrefillBaseActionObj::ChunkPrefillInputData(
const RequestModelState& mstate, int max_prefill_length) {
if (mstate->inputs.empty()) {
// If the request is a hybrid decode request
ICHECK(mstate->num_tokens_for_next_decode > 0);
int num_tokens = mstate->num_tokens_for_next_decode;
mstate->num_tokens_for_next_decode = 0;
std::vector<int32_t> decode_tokens;
decode_tokens.reserve(num_tokens);
for (auto begin = mstate->committed_tokens.end() - num_tokens;
begin != mstate->committed_tokens.end(); ++begin) {
decode_tokens.push_back(begin->GetTokenId());
}
return {{TokenData(decode_tokens)}, num_tokens};
}
ICHECK(!mstate->inputs.empty());
std::vector<Data> inputs;
@@ -378,11 +412,14 @@ std::vector<Request> BatchPrefillBaseActionObj::RemoveProcessedRequests(
break;
}
}
if (!pending_state_exists) {
if (!pending_state_exists &&
std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), rsentry->request) !=
estate->waiting_queue.end()) {
auto it =
std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), rsentry->request);
ICHECK(it != estate->waiting_queue.end());
estate->waiting_queue.erase(it);
if (it != estate->waiting_queue.end()) {
estate->waiting_queue.erase(it);
}
}
}
return processed_requests;
@@ -393,6 +430,19 @@ void BatchPrefillBaseActionObj::UpdateRequestStateEntriesWithSampleResults(
const std::vector<bool>& rsentry_activated, const std::vector<SampleResult>& sample_results) {
auto tnow = std::chrono::high_resolution_clock::now();
for (int i = 0; i < static_cast<int>(rsentries_for_sample.size()); ++i) {
// If the request is a hybrid decode request
if (rsentries_for_sample[i]->status == RequestStateStatus::kAlive &&
rsentries_for_sample[i]->child_indices.empty() &&
rsentries_for_sample[i]->mstates[0]->inputs.empty()) {
for (const RequestModelState& mstate : rsentries_for_sample[i]->mstates) {
CHECK(!mstate->require_retokenization_in_next_decode);
mstate->CommitToken(sample_results[i]);
// live update the output metrics
rsentries_for_sample[i]->rstate->metrics.completion_tokens += 1;
}
continue;
}

// Update all model states of the request state entry.
for (const RequestModelState& mstate : rsentries_for_sample[i]->mstates) {
mstate->CommitToken(sample_results[i]);
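To make the selection logic above easier to follow, here is a minimal, self-contained Python sketch of the hybrid-prefill batching idea: running decode entries are reserved in the batch first, waiting prompts are packed into the remaining token budget, and the prefill action is skipped when nothing beyond the decode reservations fits. The names Entry and build_hybrid_prefill_batch, and the single max_batch_tokens budget, are illustrative simplifications, not the engine's actual API.

from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Entry:
    """Illustrative stand-in for a request state entry (not the engine's type)."""
    request_id: str
    num_tokens_for_next_decode: int = 0  # running entries: tokens needed for the next decode
    pending_prefill_tokens: int = 0      # waiting entries: prompt tokens still to prefill


def build_hybrid_prefill_batch(
    running: List[Entry],
    waiting: List[Entry],
    max_batch_tokens: int,
) -> Optional[List[Tuple[str, int]]]:
    """Sketch of the selection: reserve decode work first, then pack waiting prompts."""
    batch: List[Tuple[str, int]] = []
    total_tokens = 0

    # 1. Reserve every running entry's next-decode tokens so decode is never starved.
    for entry in running:
        batch.append((entry.request_id, entry.num_tokens_for_next_decode))
        total_tokens += entry.num_tokens_for_next_decode
    num_decode_inputs = len(batch)

    # 2. Pack waiting prompts (chunked if needed) into the leftover budget.
    for entry in waiting:
        chunk = min(entry.pending_prefill_tokens, max_batch_tokens - total_tokens)
        if chunk <= 0:
            break
        batch.append((entry.request_id, chunk))
        total_tokens += chunk

    # 3. If only decode reservations remain, skip the prefill action (pure decode runs instead).
    if len(batch) == num_decode_inputs:
        return None
    return batch


if __name__ == "__main__":
    running = [Entry("r0", num_tokens_for_next_decode=1),
               Entry("r1", num_tokens_for_next_decode=1)]
    waiting = [Entry("r2", pending_prefill_tokens=37)]
    # Expected: [("r0", 1), ("r1", 1), ("r2", 37)] with a 64-token budget.
    print(build_hybrid_prefill_batch(running, waiting, max_batch_tokens=64))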
82 changes: 82 additions & 0 deletions tests/python/serve/test_serve_sync_engine.py
@@ -385,9 +385,91 @@ def test_engine_generate(model: str):
print(f"Output {req_id}({i}):{output}\n")


@require_test_model("Llama-2-7b-chat-hf-q0f16-MLC")
def test_engine_hybrid_prefill(model: str):
"""Test engine **with hybrid prefill**.

- Add requests to the engine one at a time, one per engine step.
- All requests have the same generation length. Due to hybrid prefill, an
earlier request decodes in the same step in which a later request is
prefilled, so every request runs for the same number of steps and the
requests finish one step apart, in order.
- The engine keeps running `step` for the generation length to finish the
last request, then the output of each request is checked.
"""

# Hyperparameters for tests (you can try different combinations)
num_requests = 10 # [4, 8, 10]
temperature = 0.9 # [0.8, 0.9, 1.0, 1.1]
repetition_penalty = 1.00 # [1.0, 1.01]
max_tokens = 15
np.random.seed(0)

# Output list
outputs: List[List[int]] = [[] for _ in range(num_requests)]
finish_time: List[Optional[int]] = [None] * num_requests

# Define the callback class for request generation results
class CallbackTimer:
timer: int = -1

def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]:
def fcallback(delta_outputs: List[RequestStreamOutput]):
for delta_output in delta_outputs:
request_id, stream_outputs = delta_output.unpack()
assert len(stream_outputs) == 1
if stream_outputs[0].finish_reason is not None:
print(f"Request {request_id} finished at step {self.timer}.")
outputs[int(request_id)] += stream_outputs[0].delta_token_ids
finish_time[int(request_id)] = self.timer

return fcallback

def step(self) -> None:
self.timer += 1

# Create engine
timer = CallbackTimer()
engine = SyncMLCEngine(
model=model,
mode="server",
request_stream_callback=timer.callback_getter(),
)

# Create requests
requests = create_requests(
engine,
num_requests,
temperature=temperature,
repetition_penalty=repetition_penalty,
max_tokens_low=max_tokens,
max_tokens_high=max_tokens + 1,
)

# Add all requests to engine step by step
for step, request in enumerate(requests):
engine.add_request(request)
timer.step()
assert timer.timer == step
engine.step()

# Run steps
for step in range(max_tokens):
timer.step()
assert timer.timer == step + num_requests
engine.step()

for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)):
print(f"Prompt {req_id}: {request.inputs[0]}")
print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n")
assert (
fin_time == req_id + request.generation_config.max_tokens - 1
), f"finish time = {fin_time}, max tokens = {req_id + request.generation_config.max_tokens - 1}"


if __name__ == "__main__":
test_engine_basic()
test_engine_continuous_batching_1()
test_engine_continuous_batching_2()
test_engine_continuous_batching_3()
test_engine_generate()
test_engine_hybrid_prefill()