From b982d8e831233ab5dcf1e660d86f3e135df9c062 Mon Sep 17 00:00:00 2001 From: klshuster Date: Fri, 18 Nov 2022 12:49:04 -0500 Subject: [PATCH] fix device --- parlai/core/torch_generator_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index f7bf59475f5..caaf4946a7a 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -1596,7 +1596,7 @@ def advance(self, logprobs, step): hyp_device = self.partial_hyps.get_device() self.partial_hyps = torch.cat( ( - self.partial_hyps[path_selection.hypothesis_ids.long()], + self.partial_hyps[path_selection.hypothesis_ids.long().to(hyp_device)], path_selection.token_ids.view(path_selection.token_ids.shape[0], -1).to( hyp_device ),