
Commit

[LLM-CHAT] Enable GPU softmax for penalty softmax (#2288)
1. Avoid the CPU softmax when a presence/frequency or repetition penalty is
   configured: sync the penalized logits back to the GPU and run the softmax
   there (see the condensed sketch below).
2. Do not count the first token in the decode token time counter.
krishnaraj36 authored May 8, 2024
1 parent 8a31986 commit 65f9716
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions cpp/llm_chat.cc
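
A condensed view of the new penalty path, pieced together from the hunks below; Softmax, temperature_arr_, and the penalty helpers are pre-existing LLMChat members that the new code relies on.

Before, both the penalty and the softmax ran on the CPU copy of the logits:

    this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
    this->ApplyPresenceAndFrequencyPenaltyOnCPU(gen_presence_penalty, gen_frequency_penalty);
    if (gen_temperature >= 1e-6f) {
      this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature);
    }

After, the penalty still runs on the CPU copy, but the adjusted logits are synced back to the device so the GPU softmax (Softmax with temperature_arr_) is reused:

    this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
    this->ApplyPresenceAndFrequencyPenaltyOnCPU(gen_presence_penalty, gen_frequency_penalty);
    this->UpdateLogitsOrProbOnGPUSync(logits_on_device);
    if (gen_temperature >= 1e-6f) {
      this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, this->temperature_arr_));
    }
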
@@ -710,7 +710,7 @@ class LLMChat {
   /*! \brief reset the runtime stats. */
   void ResetRuntimeStats() {
     this->prefill_total_tokens = 0;
-    this->decode_total_tokens = 0;
+    this->decode_total_tokens = -1;
     this->embed_total_time = 0;
     this->prefill_total_time = 0;
     this->decode_total_time = 0;
@@ -1031,8 +1031,8 @@ class LLMChat {
     int32_t next_token = this->SampleTokenFromLogits(logits_on_device, generation_config);
 
     auto tend = std::chrono::high_resolution_clock::now();
-
-    this->decode_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
+    if (this->decode_total_tokens >= 0)
+      this->decode_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
     this->decode_total_tokens += 1;
     this->ProcessNextToken(next_token, generation_config);
   }
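
This hunk, together with the ResetRuntimeStats() and member-initializer hunks, implements change 2: decode_total_tokens now starts at -1, so the first decode step after a stats reset only advances the counter and its elapsed time is never added to decode_total_time. A small self-contained sketch of the pattern; DecodeStats, OnDecodeStep, and TokensPerSecond are illustrative names, not identifiers from llm_chat.cc:

    #include <chrono>
    #include <cstdint>

    struct DecodeStats {
      double decode_total_time = 0;      // seconds; the first decode step is excluded
      int64_t decode_total_tokens = -1;  // -1 marks "no step counted yet"

      void OnDecodeStep(std::chrono::nanoseconds elapsed) {
        if (decode_total_tokens >= 0) {
          decode_total_time += static_cast<double>(elapsed.count()) / 1e9;
        }
        decode_total_tokens += 1;  // first call: -1 -> 0, time dropped; later calls are timed
      }

      // Hypothetical tokens-per-second report built from these counters.
      double TokensPerSecond() const {
        return decode_total_time > 0 ? decode_total_tokens / decode_total_time : 0.0;
      }
    };
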
@@ -1223,14 +1223,16 @@ class LLMChat {
     if (gen_presence_penalty != 0.0f || gen_frequency_penalty != 0.0f) {
       this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
       this->ApplyPresenceAndFrequencyPenaltyOnCPU(gen_presence_penalty, gen_frequency_penalty);
+      this->UpdateLogitsOrProbOnGPUSync(logits_on_device);
       if (gen_temperature >= 1e-6f) {
-        this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature);
+        this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, this->temperature_arr_));
       }
     } else if (gen_repetition_penalty != 1.0f) {
       this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
       this->ApplyRepetitionPenaltyOnCPU(gen_repetition_penalty);
+      this->UpdateLogitsOrProbOnGPUSync(logits_on_device);
       if (gen_temperature >= 1e-6f) {
-        this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature);
+        this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, this->temperature_arr_));
       }
     } else {
       if (gen_temperature < 1e-6f) {
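
This hunk is the core of change 1: after a presence/frequency or repetition penalty has been applied to the CPU copy of the logits, the result is pushed back to the device and the softmax-with-temperature runs there instead of through ApplySoftmaxWithTemperatureOnCPU. A self-contained sketch of that round trip against the plain TVM NDArray API; PenaltyThenGpuSoftmax and the gpu_softmax parameter are illustrative stand-ins for whatever this->Softmax(...) invokes, not code from the commit:

    #include <functional>
    #include <tvm/runtime/c_runtime_api.h>
    #include <tvm/runtime/ndarray.h>

    using tvm::runtime::NDArray;

    void PenaltyThenGpuSoftmax(NDArray logits_on_device, NDArray logits_on_cpu,
                               NDArray temperature_arr, DLDevice device,
                               const std::function<NDArray(NDArray, NDArray)>& gpu_softmax) {
      // 1. Device -> CPU: the penalty is applied on the host copy.
      logits_on_cpu.CopyFrom(logits_on_device);
      TVMSynchronize(device.device_type, device.device_id, nullptr);
      // ... presence/frequency or repetition penalty applied to logits_on_cpu here ...

      // 2. CPU -> device: push the penalized logits back (the new GPUSync step).
      logits_on_device.CopyFrom(logits_on_cpu);
      TVMSynchronize(device.device_type, device.device_id, nullptr);

      // 3. Softmax with temperature on the device; copy the probabilities back for sampling.
      NDArray prob_on_device = gpu_softmax(logits_on_device, temperature_arr);
      logits_on_cpu.CopyFrom(prob_on_device);
      TVMSynchronize(device.device_type, device.device_id, nullptr);
    }
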
@@ -1505,6 +1507,12 @@ class LLMChat {
     TVMSynchronize(device_.device_type, device_.device_id, nullptr);
   }
 
+  void UpdateLogitsOrProbOnGPUSync(NDArray logits_or_prob) {
+    logits_or_prob.CopyFrom(logits_on_cpu_);
+
+    TVMSynchronize(device_.device_type, device_.device_id, nullptr);
+  }
+
   // Clear kv cache
   void ResetKVCache() {
     ft_.reset_kv_cache_func_(kv_cache_);
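
The new UpdateLogitsOrProbOnGPUSync is the mirror image of the pre-existing UpdateLogitsOrProbOnCPUSync used earlier on the penalty path: one copies logits_on_cpu_ back onto the device buffer, the other pulls the device logits down into logits_on_cpu_, and both synchronize so the copy has finished before the caller continues. For comparison, a hedged sketch of the CPU-side counterpart (reconstructed from its usage; details such as shape checks may differ from the real member in llm_chat.cc):

    // Member of LLMChat; logits_on_cpu_ and device_ are existing fields.
    void UpdateLogitsOrProbOnCPUSync(NDArray logits_or_prob) {
      if (!logits_on_cpu_.defined()) {
        logits_on_cpu_ = logits_or_prob.CopyTo(DLDevice{kDLCPU, 0});
      } else {
        logits_on_cpu_.CopyFrom(logits_or_prob);
      }
      TVMSynchronize(device_.device_type, device_.device_id, nullptr);
    }
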
@@ -1547,7 +1555,7 @@ class LLMChat {
   double decode_total_time = 0;
   double sample_total_time = 0;
   double prefill_total_time = 0;
-  int64_t decode_total_tokens = 0;
+  int64_t decode_total_tokens = -1;
   int64_t prefill_total_tokens = 0;
   //----------------------------
   // Conversation
