diff --git a/model2vec/model.py b/model2vec/model.py index 017207c..2d88e94 100644 --- a/model2vec/model.py +++ b/model2vec/model.py @@ -111,7 +111,7 @@ def save_pretrained(self, path: PathLike, model_name: str | None = None) -> None """ save_pretrained( folder_path=Path(path), - embeddings=self.embedding.weight.numpy(), + embeddings=self.embedding.weight.cpu().numpy(), tokenizer=self.tokenizer, config=self.config, base_model_name=self.base_model_name,