diff --git a/parlai/agents/rag/rag.py b/parlai/agents/rag/rag.py index 76c13c67e2d..704221f0325 100644 --- a/parlai/agents/rag/rag.py +++ b/parlai/agents/rag/rag.py @@ -756,7 +756,7 @@ def _regret_rebatchify( for i in range(batch.batchsize): vec_i = pred_vecs[i] txt_i = self._v2t(vec_i) - query_i = torch.LongTensor(self.model.tokenize_query(txt_i)) + query_i = torch.LongTensor(self.model.tokenize_query(txt_i)).to(query_vec) if self.retriever_query == 'one_turn': new_queries.append(query_i) else: