diff --git a/src/seb/registered_models/e5_mistral.py b/src/seb/registered_models/e5_mistral.py index e66b3c75..c7413809 100644 --- a/src/seb/registered_models/e5_mistral.py +++ b/src/seb/registered_models/e5_mistral.py @@ -183,7 +183,7 @@ def encode( ) batched_embeddings.append(embeddings) - return torch.cat(batched_embeddings) + return torch.cat(batched_embeddings).to("cpu") def encode_corpus(self, corpus: list[dict[str, str]], **kwargs: Any): if isinstance(corpus, dict):