From c6a58ba2a1e3b77cab8cc727e40f6bfcd3bb025e Mon Sep 17 00:00:00 2001 From: Zzzz1111 <34296741+Zzzz1111@users.noreply.github.com> Date: Fri, 23 Aug 2024 12:08:27 +0800 Subject: [PATCH] FEAT: Added the model dtype parameter for embedding (currently only supported for models gte-Qwen2). (#2120) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 胡子俊 --- xinference/model/embedding/core.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/xinference/model/embedding/core.py b/xinference/model/embedding/core.py index ed8b6c6eb0..23a9766c48 100644 --- a/xinference/model/embedding/core.py +++ b/xinference/model/embedding/core.py @@ -154,10 +154,32 @@ def to(self, *args, **kwargs): "gte" in self._model_spec.model_name.lower() and "qwen2" in self._model_spec.model_name.lower() ): + import torch + + torch_dtype_str = self._kwargs.get("torch_dtype") + if torch_dtype_str is not None: + try: + torch_dtype = getattr(torch, torch_dtype_str) + if torch_dtype not in [ + torch.float16, + torch.float32, + torch.bfloat16, + ]: + logger.warning( + f"Load embedding model with unsupported torch dtype : {torch_dtype_str}. Using default torch dtype: fp32." + ) + torch_dtype = torch.float32 + except AttributeError: + logger.warning( + f"Load embedding model with unknown torch dtype '{torch_dtype_str}'. Using default torch dtype: fp32." + ) + torch_dtype = torch.float32 + else: + torch_dtype = "auto" self._model = XSentenceTransformer( self._model_path, device=self._device, - model_kwargs={"device_map": "auto"}, + model_kwargs={"device_map": "auto", "torch_dtype": torch_dtype}, ) else: self._model = SentenceTransformer(self._model_path, device=self._device)