diff --git a/src/fastertransformer/models/llama/Llama.cc b/src/fastertransformer/models/llama/Llama.cc
index c889f2db4..01ebc0e48 100644
--- a/src/fastertransformer/models/llama/Llama.cc
+++ b/src/fastertransformer/models/llama/Llama.cc
@@ -140,8 +140,6 @@ void Llama<T>::allocateBuffer(
     // prompt_learning weight batch ptrs
     prompt_learning_weight_batch_ =
         (const T**)(allocator_->reMalloc(prompt_learning_weight_batch_, sizeof(T*) * batchxbeam, false));
-    tiled_prompt_lengths_buf_ =
-        (int*)(allocator_->reMalloc(tiled_prompt_lengths_buf_, sizeof(int) * batchxbeam, true));
 
     tiled_input_ids_buf_ =
         (int*)(allocator_->reMalloc(tiled_input_ids_buf_, sizeof(int) * batchxbeam * max_input_len, true));
@@ -204,7 +202,6 @@ void Llama<T>::freeBuffer()
         }
 
         allocator_->free((void**)(&prompt_learning_weight_batch_));
-        allocator_->free((void**)(&tiled_prompt_lengths_buf_));
 
         allocator_->free((void**)(&tiled_input_ids_buf_));
         allocator_->free((void**)(&tiled_input_lengths_buf_));
@@ -639,22 +636,6 @@ void Llama<T>::forward(std::unordered_map<std::string, Tensor>* output_ten
             sync_check_cuda_error();
         }
 
-        // Prefix prompts
-        if (has_prefix_prompt_) {
-            cudaMemcpyAsync(prompt_learning_weight_batch_,
-                            prefix_prompt_weight_batch_ptrs.data(),
-                            sizeof(T*) * batch_size * beam_width,
-                            cudaMemcpyDefault,
-                            stream_);
-            cudaMemcpyAsync(tiled_prompt_lengths_buf_,
-                            prefix_prompt_lengths.data(),
-                            sizeof(int) * batch_size * beam_width,
-                            cudaMemcpyDefault,
-                            stream_);
-        }
-
-        sync_check_cuda_error();
-
         // handle first step
         if (has_prefix_prompt_ || has_prefix_soft_prompt_ || max_input_length > 1) {
             invokeTileGptInputs(tiled_input_ids_buf_,
@@ -707,7 +688,7 @@ void Llama<T>::forward(std::unordered_map<std::string, Tensor>* output_ten
             invokeBuildDecoderAttentionMask(input_attention_mask_,
                                             tiled_input_lengths_buf_,
-                                            tiled_prompt_lengths_buf_,
+                                            (const int*)nullptr,  // prefix_prompt_lengths
                                             batch_size * beam_width,
                                             max_input_length,
                                             max_prefix_prompt_length,
@@ -838,7 +819,6 @@ void Llama<T>::forward(std::unordered_map<std::string, Tensor>* output_ten
             invokeMaskPaddingTokens(masked_tokens_,
                                     input_tensors->at("input_lengths").getPtr<const int>(),  // not_tiled
-                                    tiled_prompt_lengths_buf_,
                                     max_cache_seq_len,
                                     max_input_length + max_prefix_prompt_length,
                                     0,