Skip to content

Commit

Permalink
Fix T5 beam search using parallelize (#11717)
Browse files Browse the repository at this point in the history
  • Loading branch information
OyvindTafjord authored May 14, 2021
1 parent 218d552 commit bd3b599
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/models/t5/modeling_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1682,7 +1682,7 @@ def _reorder_cache(self, past, beam_idx):
for layer_past_state in layer_past_states:
# need to set correct `past` for each of the four key / value states
reordered_layer_past_states = reordered_layer_past_states + (
layer_past_state.index_select(0, beam_idx),
layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
)

assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
Expand Down

0 comments on commit bd3b599

Please sign in to comment.