From 8f589b0ab4fd3ff7ff38358030c7e0f29da5b737 Mon Sep 17 00:00:00 2001 From: klshuster Date: Thu, 16 Sep 2021 15:10:09 -0400 Subject: [PATCH] fix cuda issue --- parlai/agents/rag/rag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parlai/agents/rag/rag.py b/parlai/agents/rag/rag.py index 76c13c67e2d..704221f0325 100644 --- a/parlai/agents/rag/rag.py +++ b/parlai/agents/rag/rag.py @@ -756,7 +756,7 @@ def _regret_rebatchify( for i in range(batch.batchsize): vec_i = pred_vecs[i] txt_i = self._v2t(vec_i) - query_i = torch.LongTensor(self.model.tokenize_query(txt_i)) + query_i = torch.LongTensor(self.model.tokenize_query(txt_i)).to(query_vec) if self.retriever_query == 'one_turn': new_queries.append(query_i) else: