Fix LLaMa beam search when using parallelize (#24224)
* Fix LLaMa beam search when using parallelize

Same issue as T5 (#11717).

* Fix code format in modeling_llama.py

* Fix format of _reorder_cache in modeling_llama.py
FeiWang96 authored Jun 15, 2023
1 parent 7504be3 commit 33196b4
Showing 1 changed file with 3 additions and 1 deletion.
src/transformers/models/llama/modeling_llama.py (3 additions, 1 deletion)
@@ -762,7 +762,9 @@ def prepare_inputs_for_generation(
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
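
For context, a minimal sketch of the scenario this commit fixes. It assumes a multi-GPU machine and a model sharded across devices with device_map="auto" (a stand-in for the "parallelize" setup named in the title, as in the T5 issue); the checkpoint name is a placeholder. With the decoder layers, and hence their key/value caches, spread over several GPUs, beam search reorders each layer's cache with beam_idx every step, and before this fix index_select could raise a cross-device RuntimeError because beam_idx lived on a single device.

    # Sketch only: assumes >=2 GPUs and an accelerate-style sharded model.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "huggyllama/llama-7b"  # placeholder LLaMA checkpoint

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",  # shards layers (and their KV caches) across GPUs
    )

    inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)

    # Beam search calls _reorder_cache each step to realign the per-layer
    # key/value cache with the surviving beams. With the cache split across
    # devices, beam_idx must follow each past_state to its device, which is
    # exactly what the one-line fix above does.
    outputs = model.generate(**inputs, num_beams=4, max_new_tokens=20)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))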


