BUG: fix embedding token calculation & optimize memory (#2221)
qinxuye authored Sep 6, 2024
1 parent def9d4a commit 2198965
Showing 2 changed files with 57 additions and 31 deletions.
83 changes: 52 additions & 31 deletions xinference/model/embedding/core.py
@@ -19,6 +19,7 @@
 from typing import Dict, List, Literal, Optional, Tuple, Union, no_type_check
 
 import numpy as np
+import torch
 
 from ...device_utils import empty_cache
 from ...types import Embedding, EmbeddingData, EmbeddingUsage
@@ -34,7 +35,11 @@
 EMBEDDING_EMPTY_CACHE_COUNT = int(
     os.getenv("XINFERENCE_EMBEDDING_EMPTY_CACHE_COUNT", "10")
 )
+EMBEDDING_EMPTY_CACHE_TOKENS = int(
+    os.getenv("XINFERENCE_EMBEDDING_EMPTY_CACHE_TOKENS", "8192")
+)
 assert EMBEDDING_EMPTY_CACHE_COUNT > 0
+assert EMBEDDING_EMPTY_CACHE_TOKENS > 0
 
 
 def get_embedding_model_descriptions():
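
Both thresholds are read once via os.getenv when the module is imported, so they must be set in the process environment before xinference starts. A minimal sketch of overriding them, assuming the module path matches the file shown in this diff (the values are illustrative, not recommendations):

# Sketch: overriding the cache-emptying thresholds before xinference loads.
# The constants are bound at import time, so set the environment first.
import os

os.environ["XINFERENCE_EMBEDDING_EMPTY_CACHE_COUNT"] = "5"      # clean every 5 calls
os.environ["XINFERENCE_EMBEDDING_EMPTY_CACHE_TOKENS"] = "4096"  # or after a >=4096-token call

from xinference.model.embedding import core  # constants are read here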
@@ -149,49 +154,47 @@ class XSentenceTransformer(SentenceTransformer):
             def to(self, *args, **kwargs):
                 pass
 
+        torch_dtype = None
+        if torch_dtype_str := self._kwargs.get("torch_dtype"):
+            try:
+                torch_dtype = getattr(torch, torch_dtype_str)
+                if torch_dtype not in [
+                    torch.float16,
+                    torch.float32,
+                    torch.bfloat16,
+                ]:
+                    logger.warning(
+                        f"Load embedding model with unsupported torch dtype : {torch_dtype_str}. Using default torch dtype: fp32."
+                    )
+                    torch_dtype = torch.float32
+            except AttributeError:
+                logger.warning(
+                    f"Load embedding model with unknown torch dtype '{torch_dtype_str}'. Using default torch dtype: fp32."
+                )
+                torch_dtype = torch.float32
+
         from ..utils import patch_trust_remote_code
 
         patch_trust_remote_code()
         if (
             "gte" in self._model_spec.model_name.lower()
             and "qwen2" in self._model_spec.model_name.lower()
         ):
-            import torch
-
-            torch_dtype_str = self._kwargs.get("torch_dtype")
-            if torch_dtype_str is not None:
-                try:
-                    torch_dtype = getattr(torch, torch_dtype_str)
-                    if torch_dtype not in [
-                        torch.float16,
-                        torch.float32,
-                        torch.bfloat16,
-                    ]:
-                        logger.warning(
-                            f"Load embedding model with unsupported torch dtype : {torch_dtype_str}. Using default torch dtype: fp32."
-                        )
-                        torch_dtype = torch.float32
-                except AttributeError:
-                    logger.warning(
-                        f"Load embedding model with unknown torch dtype '{torch_dtype_str}'. Using default torch dtype: fp32."
-                    )
-                    torch_dtype = torch.float32
-            else:
-                torch_dtype = "auto"
+            model_kwargs = {"device_map": "auto"}
+            if torch_dtype:
+                model_kwargs["torch_dtype"] = torch_dtype
             self._model = XSentenceTransformer(
                 self._model_path,
                 device=self._device,
-                model_kwargs={"device_map": "auto", "torch_dtype": torch_dtype},
+                model_kwargs=model_kwargs,
             )
         else:
-            self._model = SentenceTransformer(self._model_path, device=self._device)
+            model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
+            self._model = SentenceTransformer(
+                self._model_path, device=self._device, model_kwargs=model_kwargs
+            )
 
     def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
-        self._counter += 1
-        if self._counter % EMBEDDING_EMPTY_CACHE_COUNT == 0:
-            logger.debug("Empty embedding cache.")
-            gc.collect()
-            empty_cache()
        from sentence_transformers import SentenceTransformer
 
         kwargs.setdefault("normalize_embeddings", True)
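
The dtype parsing is hoisted out of the gte-Qwen2-only branch, so a user-supplied torch_dtype now applies to the plain SentenceTransformer path as well. A self-contained sketch of the resolution rule (resolve_torch_dtype is a hypothetical helper written for illustration, not a function in the diff):

# Hypothetical helper mirroring the dtype resolution in load() above:
# a string such as "float16" is mapped to a torch.dtype via getattr;
# unknown names and unsupported dtypes fall back to float32.
from typing import Optional

import torch

def resolve_torch_dtype(name: Optional[str]) -> Optional[torch.dtype]:
    if not name:
        return None  # no override requested; the model default is used
    try:
        dtype = getattr(torch, name)
    except AttributeError:
        return torch.float32  # unknown dtype string -> default fp32
    if dtype not in (torch.float16, torch.float32, torch.bfloat16):
        return torch.float32  # e.g. "int8" resolves but is unsupported
    return dtype

assert resolve_torch_dtype("bfloat16") is torch.bfloat16
assert resolve_torch_dtype("int8") is torch.float32
assert resolve_torch_dtype(None) is None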
@@ -309,7 +312,9 @@ def encode(
                features = model.tokenize(sentences_batch)
                features = batch_to_device(features, device)
                features.update(extra_features)
-                all_token_nums += sum([len(f) for f in features])
+                # when batching, a 1 in the attention mask marks a real token,
+                # so summing the mask gives the total number of tokens
+                all_token_nums += features["attention_mask"].sum().item()
 
                with torch.no_grad():
                    out_features = model.forward(features)
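
This is the bug named in the commit title: iterating over the features dict yields its string keys, so the old line summed the lengths of the key names rather than counting tokens. A toy reproduction (the tensors are made up; real shapes depend on the tokenizer):

# Toy features dict shaped like SentenceTransformer.tokenize output:
# a batch of 2 sentences padded to length 4.
import torch

features = {
    "input_ids": torch.tensor([[101, 7592, 102, 0], [101, 2088, 2003, 102]]),
    "attention_mask": torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]]),
}

# Old (buggy): iterating a dict yields keys, so this sums key-name lengths:
# len("input_ids") + len("attention_mask") == 9 + 14 == 23, regardless of input.
wrong = sum(len(f) for f in features)
assert wrong == 23

# New: every 1 in the attention mask marks a real (non-padding) token.
right = features["attention_mask"].sum().item()
assert right == 7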
@@ -393,13 +398,29 @@ def encode(
        usage = EmbeddingUsage(
            prompt_tokens=all_token_nums, total_tokens=all_token_nums
        )
-        return Embedding(
+        result = Embedding(
            object="list",
            model=self._model_uid,
            data=embedding_list,
            usage=usage,
        )
+
+        # clean cache if possible
+        self._counter += 1
+        if (
+            self._counter % EMBEDDING_EMPTY_CACHE_COUNT == 0
+            or all_token_nums >= EMBEDDING_EMPTY_CACHE_TOKENS
+        ):
+            logger.debug(
+                "Empty embedding cache, calling count %s, all_token_nums %s",
+                self._counter,
+                all_token_nums,
+            )
+            gc.collect()
+            empty_cache()
+
+        return result
 
 
 def match_embedding(
     model_name: str,
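
The cleanup also moves from the top of create_embedding to after the result is built, and gains a second trigger: besides the every-N-calls cadence, any single call that processes at least EMBEDDING_EMPTY_CACHE_TOKENS tokens empties the cache immediately. A sketch of the trigger in isolation (should_empty_cache is a hypothetical helper, not part of the diff; defaults mirror the module constants above):

# Hypothetical helper isolating the new trigger condition: periodic by call
# count, or immediate after an unusually token-heavy call.
def should_empty_cache(
    counter: int, all_token_nums: int, count_every: int = 10, token_limit: int = 8192
) -> bool:
    return counter % count_every == 0 or all_token_nums >= token_limit

assert should_empty_cache(counter=10, all_token_nums=50)     # periodic cleanup
assert should_empty_cache(counter=3, all_token_nums=20000)   # large call -> cleanup
assert not should_empty_cache(counter=3, all_token_nums=50)  # otherwise keep cache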
5 changes: 5 additions & 0 deletions xinference/model/embedding/tests/test_embedding_models.py
@@ -76,6 +76,11 @@ def test_model():
        assert len(r["data"]) == 4
        for d in r["data"]:
            assert len(d["embedding"]) == 384
+        n_token = 0
+        for inp in input_texts:
+            input_ids = model._model.tokenize([inp])["input_ids"]
+            n_token += input_ids.shape[-1]
+        assert r["usage"]["total_tokens"] == n_token
 
    finally:
        if model_path is not None:
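
The test recomputes the expected count by tokenizing each sentence separately; with a batch of one there is no padding, so input_ids.shape[-1] equals the attention-mask sum used inside encode. A toy check of that equivalence (made-up tensors, not real tokenizer output):

# With a single sentence there is no padding, so every input_ids position
# is a real token and the two counting methods agree.
import torch

single = {
    "input_ids": torch.tensor([[101, 7592, 102]]),
    "attention_mask": torch.tensor([[1, 1, 1]]),
}
assert single["input_ids"].shape[-1] == single["attention_mask"].sum().item() == 3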