diff --git a/src/fastertransformer/kernels/stop_criteria_kernels.cu b/src/fastertransformer/kernels/stop_criteria_kernels.cu
index 5d6611153..a8d4b98fa 100644
--- a/src/fastertransformer/kernels/stop_criteria_kernels.cu
+++ b/src/fastertransformer/kernels/stop_criteria_kernels.cu
@@ -150,7 +150,7 @@ void invokeLengthCriterion(bool* finished,
     length_criterion<<<grid, block, 0, stream>>>(
         finished, should_stop, h_pinned_finished_sum_, sequence_limit_length, batch_size, beam_width, step);
-    while (((volatile size_t*)h_pinned_finished_sum_)[0] == -1) {};
+    while (((volatile int*)h_pinned_finished_sum_)[0] == -1) {};
     sync_check_cuda_error();
 
     *should_stop = h_pinned_finished_sum_[0] == batch_size * beam_width;
 }
diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc
index ad9c3527b..6850b3083 100644
--- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc
+++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc
@@ -1534,6 +1534,10 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
             POP_RANGE;
         }
 
+        if (*generation_should_stop_) {
+            break;
+        }
+
         if (token_generated_cb_ && step_ + 1 < (int)gen_len) {
             setOutputTensors(
                 output_tensors, input_tensors, gen_len, session_len, max_context_len, max_input_without_prompt_length);