From 65f97160133c1264ca85bea5e940199ca778d811 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Wed, 8 May 2024 17:25:39 +0530 Subject: [PATCH] [LLM-CHAT] Enable gpu softmax for penality softmax (#2288) 1. Avoid the cpu softmax for different penality config by having copy sync to gpu and use gpu softmax. 2. Disable decode token time counter for first token. --- cpp/llm_chat.cc | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 9485ccad02..93de185eb2 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -710,7 +710,7 @@ class LLMChat { /*! \brief reset the runtime stats. */ void ResetRuntimeStats() { this->prefill_total_tokens = 0; - this->decode_total_tokens = 0; + this->decode_total_tokens = -1; this->embed_total_time = 0; this->prefill_total_time = 0; this->decode_total_time = 0; @@ -1031,8 +1031,8 @@ class LLMChat { int32_t next_token = this->SampleTokenFromLogits(logits_on_device, generation_config); auto tend = std::chrono::high_resolution_clock::now(); - - this->decode_total_time += static_cast((tend - tstart).count()) / 1e9; + if (this->decode_total_tokens >= 0) + this->decode_total_time += static_cast((tend - tstart).count()) / 1e9; this->decode_total_tokens += 1; this->ProcessNextToken(next_token, generation_config); } @@ -1223,14 +1223,16 @@ class LLMChat { if (gen_presence_penalty != 0.0f || gen_frequency_penalty != 0.0f) { this->UpdateLogitsOrProbOnCPUSync(logits_on_device); this->ApplyPresenceAndFrequencyPenaltyOnCPU(gen_presence_penalty, gen_frequency_penalty); + this->UpdateLogitsOrProbOnGPUSync(logits_on_device); if (gen_temperature >= 1e-6f) { - this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature); + this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, this->temperature_arr_)); } } else if (gen_repetition_penalty != 1.0f) { this->UpdateLogitsOrProbOnCPUSync(logits_on_device); this->ApplyRepetitionPenaltyOnCPU(gen_repetition_penalty); + this->UpdateLogitsOrProbOnGPUSync(logits_on_device); if (gen_temperature >= 1e-6f) { - this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature); + this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, this->temperature_arr_)); } } else { if (gen_temperature < 1e-6f) { @@ -1505,6 +1507,12 @@ class LLMChat { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } + void UpdateLogitsOrProbOnGPUSync(NDArray logits_or_prob) { + logits_or_prob.CopyFrom(logits_on_cpu_); + + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + } + // Clear kv cache void ResetKVCache() { ft_.reset_kv_cache_func_(kv_cache_); @@ -1547,7 +1555,7 @@ class LLMChat { double decode_total_time = 0; double sample_total_time = 0; double prefill_total_time = 0; - int64_t decode_total_tokens = 0; + int64_t decode_total_tokens = -1; int64_t prefill_total_tokens = 0; //---------------------------- // Conversation