diff --git a/cpp/serve/engine_actions/batch_prefill_base.cc b/cpp/serve/engine_actions/batch_prefill_base.cc
index 50cdb1b8bf..827efb95c9 100644
--- a/cpp/serve/engine_actions/batch_prefill_base.cc
+++ b/cpp/serve/engine_actions/batch_prefill_base.cc
@@ -36,6 +36,21 @@ BatchPrefillBaseActionObj::BatchPrefillBaseActionObj(Array<Model> models,
  */
 std::vector<BatchPrefillBaseActionObj::PrefillInput>
 BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) {
+  // Preempt request state entries when decode cannot apply.
+  std::vector<RequestStateEntry> running_rsentries;
+  {
+    NVTXScopedRange nvtx_scope("BatchDecode getting requests");
+    running_rsentries = GetRunningRequestStateEntries(estate);
+    while (!(running_rsentries.size() <= models_[0]->GetNumAvailablePages())) {
+      if (estate->prefix_cache->TryFreeMemory()) continue;
+      RequestStateEntry preempted =
+          PreemptLastRunningRequestStateEntry(estate, models_, NullOpt, trace_recorder_);
+      if (preempted.same_as(running_rsentries.back())) {
+        running_rsentries.pop_back();
+      }
+    }
+  }
+
   if (estate->waiting_queue.empty()) {
     // No request to prefill.
     return {};
@@ -44,13 +59,20 @@ BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) {
   std::vector<std::vector<PrefillInput>> prefill_inputs_for_all_models;
   prefill_inputs_for_all_models.reserve(models_.size());
 
+  int num_decode_inputs = static_cast<int>(running_rsentries.size());
+
   // We first collect the inputs that can be prefilled for each model.
   // Then we make a reduction to return the maximum common inputs.
   for (int i = 0; i < static_cast<int>(models_.size()); ++i) {
     std::vector<PrefillInput> prefill_inputs;
-    // - Try to prefill pending requests.
+    // - Try to prefill pending requests, in addition to reserved decode requests.
     int total_input_length = 0;
-    int total_required_pages = 0;
+    int total_required_pages = num_decode_inputs;
+    // Reserve decode requests first.
+    for (const RequestStateEntry& rsentry : running_rsentries) {
+      prefill_inputs.push_back({rsentry, rsentry->mstates[i]->num_tokens_for_next_decode, 0});
+      total_input_length += rsentry->mstates[i]->num_tokens_for_next_decode;
+    }
     int num_available_pages = models_[i]->GetNumAvailablePages();
     int num_running_rsentries = GetRunningRequestStateEntries(estate).size();
     int current_total_seq_len = models_[i]->GetCurrentTotalSequenceLength();
@@ -177,7 +199,8 @@ BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) {
         std::min(num_prefill_inputs, static_cast<int>(prefill_inputs_for_all_models[i].size()));
   }
 
-  if (num_prefill_inputs == 0) {
+  // If all inputs are decode inputs, no prefill input can be added, so skip the prefill action.
+  if (num_prefill_inputs == num_decode_inputs) {
     return {};
   }
 
@@ -259,6 +282,17 @@ bool BatchPrefillBaseActionObj::CanPrefill(EngineState estate, int num_prefill_r
 std::pair<std::vector<Data>, int> BatchPrefillBaseActionObj::ChunkPrefillInputData(
     const RequestModelState& mstate, int max_prefill_length) {
   if (mstate->inputs.empty()) {
+    // The request is a hybrid decode request; use its pending decode tokens as the input.
+    ICHECK(mstate->num_tokens_for_next_decode > 0);
+    int num_tokens = mstate->num_tokens_for_next_decode;
+    mstate->num_tokens_for_next_decode = 0;
+    std::vector<int32_t> decode_tokens;
+    decode_tokens.reserve(num_tokens);
+    for (auto begin = mstate->committed_tokens.end() - num_tokens;
+         begin != mstate->committed_tokens.end(); ++begin) {
+      decode_tokens.push_back(begin->GetTokenId());
+    }
+    return {{TokenData(decode_tokens)}, num_tokens};
   }
   ICHECK(!mstate->inputs.empty());
   std::vector<Data> inputs;
@@ -378,11 +412,14 @@ std::vector<Request> BatchPrefillBaseActionObj::RemoveProcessedRequests(
         break;
       }
     }
-    if (!pending_state_exists) {
+    if (!pending_state_exists &&
+        std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), rsentry->request) !=
+            estate->waiting_queue.end()) {
       auto it =
           std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), rsentry->request);
-      ICHECK(it != estate->waiting_queue.end());
-      estate->waiting_queue.erase(it);
+      if (it != estate->waiting_queue.end()) {
+        estate->waiting_queue.erase(it);
+      }
     }
   }
   return processed_requests;
@@ -393,6 +430,19 @@ void BatchPrefillBaseActionObj::UpdateRequestStateEntriesWithSampleResults(
     const std::vector<bool>& rsentry_activated, const std::vector<SampleResult>& sample_results) {
   auto tnow = std::chrono::high_resolution_clock::now();
   for (int i = 0; i < static_cast<int>(rsentries_for_sample.size()); ++i) {
+    // If the request is a hybrid decode request, just commit the sampled token.
+    if (rsentries_for_sample[i]->status == RequestStateStatus::kAlive &&
+        rsentries_for_sample[i]->child_indices.empty() &&
+        rsentries_for_sample[i]->mstates[0]->inputs.empty()) {
+      for (const RequestModelState& mstate : rsentries_for_sample[i]->mstates) {
+        CHECK(!mstate->require_retokenization_in_next_decode);
+        mstate->CommitToken(sample_results[i]);
+        // Live-update the output metrics.
+        rsentries_for_sample[i]->rstate->metrics.completion_tokens += 1;
+      }
+      continue;
+    }
+
     // Update all model states of the request state entry.
     for (const RequestModelState& mstate : rsentries_for_sample[i]->mstates) {
       mstate->CommitToken(sample_results[i]);
diff --git a/tests/python/serve/test_serve_sync_engine.py b/tests/python/serve/test_serve_sync_engine.py
index b889628592..f8b9849fce 100644
--- a/tests/python/serve/test_serve_sync_engine.py
+++ b/tests/python/serve/test_serve_sync_engine.py
@@ -385,9 +385,91 @@ def test_engine_generate(model: str):
                 print(f"Output {req_id}({i}):{output}\n")
 
 
+@require_test_model("Llama-2-7b-chat-hf-q0f16-MLC")
+def test_engine_hybrid_prefill(model: str):
+    """Test engine **with hybrid prefill**.
+
+    - Add the requests to the engine one by one, stepping the engine after each.
+    - All requests have the same generation length. With hybrid prefill, an earlier
+      request's decode runs in the same step as a later request's prefill, so every
+      request lasts the same number of steps and the requests finish one step apart.
+    - The engine keeps running `step` for the generation length to finish the last
+      request, then the output of each request is checked.
+    """
+
+    # Hyperparameters for tests (you can try different combinations)
+    num_requests = 10  # [4, 8, 10]
+    temperature = 0.9  # [0.8, 0.9, 1.0, 1.1]
+    repetition_penalty = 1.00  # [1.0, 1.01]
+    max_tokens = 15
+    np.random.seed(0)
+
+    # Output list
+    outputs: List[List[int]] = [[] for _ in range(num_requests)]
+    finish_time: List[Optional[int]] = [None] * num_requests
+
+    # Define the callback class for request generation results
+    class CallbackTimer:
+        timer: int = -1
+
+        def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]:
+            def fcallback(delta_outputs: List[RequestStreamOutput]):
+                for delta_output in delta_outputs:
+                    request_id, stream_outputs = delta_output.unpack()
+                    assert len(stream_outputs) == 1
+                    if stream_outputs[0].finish_reason is not None:
+                        print(f"Request {request_id} finished at step {self.timer}.")
+                    outputs[int(request_id)] += stream_outputs[0].delta_token_ids
+                    finish_time[int(request_id)] = self.timer
+
+            return fcallback
+
+        def step(self) -> None:
+            self.timer += 1
+
+    # Create engine
+    timer = CallbackTimer()
+    engine = SyncMLCEngine(
+        model=model,
+        mode="server",
+        request_stream_callback=timer.callback_getter(),
+    )
+
+    # Create requests
+    requests = create_requests(
+        engine,
+        num_requests,
+        temperature=temperature,
+        repetition_penalty=repetition_penalty,
+        max_tokens_low=max_tokens,
+        max_tokens_high=max_tokens + 1,
+    )
+
+    # Add all requests to the engine step by step
+    for step, request in enumerate(requests):
+        engine.add_request(request)
+        timer.step()
+        assert timer.timer == step
+        engine.step()
+
+    # Run steps until the last request finishes
+    for step in range(max_tokens):
+        timer.step()
+        assert timer.timer == step + num_requests
+        engine.step()
+
+    for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)):
+        print(f"Prompt {req_id}: {request.inputs[0]}")
+        print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n")
+        assert (
+            fin_time == req_id + request.generation_config.max_tokens - 1
+        ), f"finish time = {fin_time}, expected = {req_id + request.generation_config.max_tokens - 1}"
+
+
 if __name__ == "__main__":
     test_engine_basic()
     test_engine_continuous_batching_1()
     test_engine_continuous_batching_2()
     test_engine_continuous_batching_3()
     test_engine_generate()
+    test_engine_hybrid_prefill()
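Note (not part of the patch): the scheduling change above reserves the running entries' decode work first (each running request becomes a prefill input of length `num_tokens_for_next_decode`, counted against `total_input_length` and `total_required_pages`), and only then admits waiting requests with whatever budget remains. The Python sketch below illustrates that reservation order only; the budget rule, page size, and function name are invented for the example and are not the engine's API.

# sketch_hybrid_prefill_budget.py -- illustrative only, not MLC-LLM code.
# Running decode entries are reserved first (one token and, here, one page each);
# waiting prompts are admitted only with whatever token/page budget remains.


def plan_step(running_decode_lens, waiting_prompt_lens, max_total_tokens, max_pages,
              page_size=16):
    """Return (decode_inputs, prefill_inputs) chosen for one hybrid step."""
    total_tokens = sum(running_decode_lens)  # decode work is reserved up front
    total_pages = len(running_decode_lens)   # simplification: one page per decode entry
    prefill_inputs = []
    for prompt_len in waiting_prompt_lens:
        pages_needed = -(-prompt_len // page_size)  # ceiling division
        if (total_tokens + prompt_len > max_total_tokens
                or total_pages + pages_needed > max_pages):
            break  # keep FCFS order: stop at the first prompt that does not fit
        prefill_inputs.append(prompt_len)
        total_tokens += prompt_len
        total_pages += pages_needed
    return list(running_decode_lens), prefill_inputs


if __name__ == "__main__":
    decode, prefill = plan_step(
        running_decode_lens=[1, 1, 1],      # three running requests, one decode token each
        waiting_prompt_lens=[48, 200, 32],  # prompts waiting to be prefilled
        max_total_tokens=128,
        max_pages=16,
    )
    print(decode, prefill)  # [1, 1, 1] [48] -- the 200-token prompt must wait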
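Note (not part of the patch): the finish-time assertion in `test_engine_hybrid_prefill` (`fin_time == req_id + max_tokens - 1`) follows from a simple timing model in which request r is added at step r, produces its first token in that same hybrid prefill step, and then gains one token per step while later requests prefill. A dependency-free simulation of that model reproduces the expected finish steps.

# sketch_hybrid_prefill_timing.py -- illustrative only, not MLC-LLM code.
# Request r joins at step r and, per the test's asserted timing, gains one token
# in every step from then on, so it finishes at step r + max_tokens - 1.

num_requests = 10
max_tokens = 15

generated = [0] * num_requests       # tokens produced so far per request
finish_step = [None] * num_requests  # step at which each request finishes

for step in range(num_requests + max_tokens):
    # Request req_id exists from step req_id onwards.
    for req_id in range(min(step + 1, num_requests)):
        if generated[req_id] < max_tokens:
            generated[req_id] += 1
            if generated[req_id] == max_tokens:
                finish_step[req_id] = step

assert finish_step == [r + max_tokens - 1 for r in range(num_requests)]
print(finish_step)  # [14, 15, ..., 23]: requests finish one step apart, in order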