diff --git a/xinference/model/embedding/core.py b/xinference/model/embedding/core.py
index 7ad985ff09..ae8b63eb4d 100644
--- a/xinference/model/embedding/core.py
+++ b/xinference/model/embedding/core.py
@@ -19,6 +19,7 @@
 from typing import Dict, List, Literal, Optional, Tuple, Union, no_type_check
 
 import numpy as np
+import torch
 
 from ...device_utils import empty_cache
 from ...types import Embedding, EmbeddingData, EmbeddingUsage
@@ -34,7 +35,11 @@
 EMBEDDING_EMPTY_CACHE_COUNT = int(
     os.getenv("XINFERENCE_EMBEDDING_EMPTY_CACHE_COUNT", "10")
 )
+EMBEDDING_EMPTY_CACHE_TOKENS = int(
+    os.getenv("XINFERENCE_EMBEDDING_EMPTY_CACHE_TOKENS", "8192")
+)
 assert EMBEDDING_EMPTY_CACHE_COUNT > 0
+assert EMBEDDING_EMPTY_CACHE_TOKENS > 0
 
 
 def get_embedding_model_descriptions():
@@ -149,6 +154,25 @@ class XSentenceTransformer(SentenceTransformer):
             def to(self, *args, **kwargs):
                 pass
 
+        torch_dtype = None
+        if torch_dtype_str := self._kwargs.get("torch_dtype"):
+            try:
+                torch_dtype = getattr(torch, torch_dtype_str)
+                if torch_dtype not in [
+                    torch.float16,
+                    torch.float32,
+                    torch.bfloat16,
+                ]:
+                    logger.warning(
+                        f"Load embedding model with unsupported torch dtype : {torch_dtype_str}. Using default torch dtype: fp32."
+                    )
+                    torch_dtype = torch.float32
+            except AttributeError:
+                logger.warning(
+                    f"Load embedding model with unknown torch dtype '{torch_dtype_str}'. Using default torch dtype: fp32."
+                )
+                torch_dtype = torch.float32
+
         from ..utils import patch_trust_remote_code
 
         patch_trust_remote_code()
@@ -156,42 +180,21 @@ def to(self, *args, **kwargs):
             "gte" in self._model_spec.model_name.lower()
             and "qwen2" in self._model_spec.model_name.lower()
         ):
-            import torch
-
-            torch_dtype_str = self._kwargs.get("torch_dtype")
-            if torch_dtype_str is not None:
-                try:
-                    torch_dtype = getattr(torch, torch_dtype_str)
-                    if torch_dtype not in [
-                        torch.float16,
-                        torch.float32,
-                        torch.bfloat16,
-                    ]:
-                        logger.warning(
-                            f"Load embedding model with unsupported torch dtype : {torch_dtype_str}. Using default torch dtype: fp32."
-                        )
-                        torch_dtype = torch.float32
-                except AttributeError:
-                    logger.warning(
-                        f"Load embedding model with unknown torch dtype '{torch_dtype_str}'. Using default torch dtype: fp32."
-                    )
-                    torch_dtype = torch.float32
-            else:
-                torch_dtype = "auto"
+            model_kwargs = {"device_map": "auto"}
+            if torch_dtype:
+                model_kwargs["torch_dtype"] = torch_dtype
             self._model = XSentenceTransformer(
                 self._model_path,
                 device=self._device,
-                model_kwargs={"device_map": "auto", "torch_dtype": torch_dtype},
+                model_kwargs=model_kwargs,
             )
         else:
-            self._model = SentenceTransformer(self._model_path, device=self._device)
+            model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
+            self._model = SentenceTransformer(
+                self._model_path, device=self._device, model_kwargs=model_kwargs
+            )
 
     def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
-        self._counter += 1
-        if self._counter % EMBEDDING_EMPTY_CACHE_COUNT == 0:
-            logger.debug("Empty embedding cache.")
-            gc.collect()
-            empty_cache()
         from sentence_transformers import SentenceTransformer
 
         kwargs.setdefault("normalize_embeddings", True)
@@ -309,7 +312,9 @@ def encode(
                 features = model.tokenize(sentences_batch)
                 features = batch_to_device(features, device)
                 features.update(extra_features)
-                all_token_nums += sum([len(f) for f in features])
+                # when batching, a 1 in the attention mask means there is a real token,
+                # so we just sum it up to get the total number of tokens
+                all_token_nums += features["attention_mask"].sum().item()
 
                 with torch.no_grad():
                     out_features = model.forward(features)
@@ -393,13 +398,29 @@ def encode(
         usage = EmbeddingUsage(
             prompt_tokens=all_token_nums, total_tokens=all_token_nums
         )
-        return Embedding(
+        result = Embedding(
             object="list",
             model=self._model_uid,
             data=embedding_list,
             usage=usage,
         )
 
+        # clean cache if possible
+        self._counter += 1
+        if (
+            self._counter % EMBEDDING_EMPTY_CACHE_COUNT == 0
+            or all_token_nums >= EMBEDDING_EMPTY_CACHE_TOKENS
+        ):
+            logger.debug(
+                "Empty embedding cache, calling count %s, all_token_nums %s",
+                self._counter,
+                all_token_nums,
+            )
+            gc.collect()
+            empty_cache()
+
+        return result
+
 
 def match_embedding(
     model_name: str,
diff --git a/xinference/model/embedding/tests/test_embedding_models.py b/xinference/model/embedding/tests/test_embedding_models.py
index eda7731ec3..7ff47f7c83 100644
--- a/xinference/model/embedding/tests/test_embedding_models.py
+++ b/xinference/model/embedding/tests/test_embedding_models.py
@@ -76,6 +76,11 @@ def test_model():
         assert len(r["data"]) == 4
         for d in r["data"]:
             assert len(d["embedding"]) == 384
+        n_token = 0
+        for inp in input_texts:
+            input_ids = model._model.tokenize([inp])["input_ids"]
+            n_token += input_ids.shape[-1]
+        assert r["usage"]["total_tokens"] == n_token
     finally:
         if model_path is not None:
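
Note on the token-counting hunk above: iterating a Python dict yields its keys, so the old `sum([len(f) for f in features])` added up the lengths of the key strings rather than counting tokens, while summing `features["attention_mask"]` counts exactly the non-padding tokens in the batch. A minimal standalone sketch with a made-up tokenized batch (the tensor values are illustrative only, not real tokenizer output):

import torch

# Hypothetical output of model.tokenize() for two sentences padded to length 5:
# a 1 in the attention mask marks a real token, a 0 marks padding.
features = {
    "input_ids": torch.tensor([[101, 2023, 2003, 102, 0],
                               [101, 7592, 102, 0, 0]]),
    "attention_mask": torch.tensor([[1, 1, 1, 1, 0],
                                    [1, 1, 1, 0, 0]]),
}

# Old counting: iterating the dict yields the key strings, so len(f) measures
# len("input_ids") and len("attention_mask"), not the number of tokens.
old_count = sum([len(f) for f in features])            # 9 + 14 = 23

# New counting: summing the attention mask counts real tokens only.
new_count = features["attention_mask"].sum().item()    # 4 + 3 = 7

print(old_count, new_count)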