Skip to content

Commit

Permalink
Clamping beam_idxs, still fails trying to cat past_states to hidden_states
Browse files Browse the repository at this point in the history
  • Loading branch information
joecummings committed Feb 10, 2023
1 parent 0899ad8 commit 0c3ccc6
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions torchtext/prototype/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,25 +298,26 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_

# We could store this in model_kwargs
num_hyps_in_prev_step = model_kwargs["past"][0][0].shape[0]

num_finished_hyps_in_step = num_hyps_in_prev_step - len(prev_step_hyp_idxs)
if num_finished_hyps_in_step > 0:
beam_idxs = F.pad(beam_idxs, (0, num_finished_hyps_in_step), "constant", 0)


beam_idxs = torch.clamp(beam_idxs, max=len(prev_step_hyp_idxs) - 1)

reordered_cached = self.model._reorder_cache(model_kwargs["past"], beam_idxs)

if num_finished_hyps_in_step > 0:
sliced_cache = ()
for states in reordered_cached:
sliced_state = ()
for state in states:
sliced_state = sliced_state + (state[:len(prev_step_hyp_idxs)],)
sliced_state = sliced_state + (state[: len(prev_step_hyp_idxs)],)
sliced_cache = sliced_cache + (sliced_state,)
reordered_cached = sliced_cache

model_inputs["past_key_values"] = reordered_cached


# Forward pass
outputs = self.model(**model_inputs)

Expand Down

0 comments on commit 0c3ccc6

Please sign in to comment.