From 0c3ccc6f841bc53ff9ad5ed8312a51e9b8b582b9 Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Fri, 10 Feb 2023 11:26:57 -0500 Subject: [PATCH] Clamping beam_idxs, still fails trying to cat past_states to hidden_states --- torchtext/prototype/generate.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchtext/prototype/generate.py b/torchtext/prototype/generate.py index 10df9f870d..e08b180d47 100644 --- a/torchtext/prototype/generate.py +++ b/torchtext/prototype/generate.py @@ -298,11 +298,13 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_ # We could store this in model_kwargs num_hyps_in_prev_step = model_kwargs["past"][0][0].shape[0] - + num_finished_hyps_in_step = num_hyps_in_prev_step - len(prev_step_hyp_idxs) if num_finished_hyps_in_step > 0: beam_idxs = F.pad(beam_idxs, (0, num_finished_hyps_in_step), "constant", 0) - + + beam_idxs = torch.clamp(beam_idxs, max=len(prev_step_hyp_idxs) - 1) + reordered_cached = self.model._reorder_cache(model_kwargs["past"], beam_idxs) if num_finished_hyps_in_step > 0: @@ -310,13 +312,12 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_ for states in reordered_cached: sliced_state = () for state in states: - sliced_state = sliced_state + (state[:len(prev_step_hyp_idxs)],) + sliced_state = sliced_state + (state[: len(prev_step_hyp_idxs)],) sliced_cache = sliced_cache + (sliced_state,) reordered_cached = sliced_cache model_inputs["past_key_values"] = reordered_cached - # Forward pass outputs = self.model(**model_inputs)