From e7d3e8f9d67ffdfe2eadedcd6959c6fd0d7bac24 Mon Sep 17 00:00:00 2001 From: AkiyamaYummy <842720660@qq.com> Date: Fri, 31 Mar 2023 11:32:40 +0000 Subject: [PATCH] fix early stopping invalid Signed-off-by: AkiyamaYummy <842720660@qq.com> --- src/fastertransformer/kernels/stop_criteria_kernels.cu | 2 +- src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/stop_criteria_kernels.cu b/src/fastertransformer/kernels/stop_criteria_kernels.cu index 5d6611153..a8d4b98fa 100644 --- a/src/fastertransformer/kernels/stop_criteria_kernels.cu +++ b/src/fastertransformer/kernels/stop_criteria_kernels.cu @@ -150,7 +150,7 @@ void invokeLengthCriterion(bool* finished, length_criterion<<>>( finished, should_stop, h_pinned_finished_sum_, sequence_limit_length, batch_size, beam_width, step); - while (((volatile size_t*)h_pinned_finished_sum_)[0] == -1) {}; + while (((volatile int*)h_pinned_finished_sum_)[0] == -1) {}; sync_check_cuda_error(); *should_stop = h_pinned_finished_sum_[0] == batch_size * beam_width; diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc index ad9c3527b..6850b3083 100644 --- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc +++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc @@ -1534,6 +1534,10 @@ void ParallelGpt::forward(std::unordered_map* outp POP_RANGE; } + if (*generation_should_stop_) { + break; + } + if (token_generated_cb_ && step_ + 1 < (int)gen_len) { setOutputTensors( output_tensors, input_tensors, gen_len, session_len, max_context_len, max_input_without_prompt_length);