Skip to content

Commit

Permalink
FEAT: Added the model dtype parameter for embedding (currently only supported for models gte-Qwen2). (#2120)
Browse files Browse the repository at this point in the history

Co-authored-by: 胡子俊 <huzijun@cvte.com>
  • Loading branch information
Zzzz1111 and zzzz199605 authored Aug 23, 2024
1 parent 6a372d6 commit c6a58ba
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion xinference/model/embedding/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,32 @@ def to(self, *args, **kwargs):
"gte" in self._model_spec.model_name.lower()
and "qwen2" in self._model_spec.model_name.lower()
):
import torch

torch_dtype_str = self._kwargs.get("torch_dtype")
if torch_dtype_str is not None:
try:
torch_dtype = getattr(torch, torch_dtype_str)
if torch_dtype not in [
torch.float16,
torch.float32,
torch.bfloat16,
]:
logger.warning(
f"Load embedding model with unsupported torch dtype : {torch_dtype_str}. Using default torch dtype: fp32."
)
torch_dtype = torch.float32
except AttributeError:
logger.warning(
f"Load embedding model with unknown torch dtype '{torch_dtype_str}'. Using default torch dtype: fp32."
)
torch_dtype = torch.float32
else:
torch_dtype = "auto"
self._model = XSentenceTransformer(
self._model_path,
device=self._device,
model_kwargs={"device_map": "auto"},
model_kwargs={"device_map": "auto", "torch_dtype": torch_dtype},
)
else:
self._model = SentenceTransformer(self._model_path, device=self._device)
Expand Down

0 comments on commit c6a58ba

Please sign in to comment.