Commit c58b18a

BUG: Fix cached tag on UI (#748)

ChengjieLi28 authored Dec 12, 2023
1 parent cda4f22 commit c58b18a

Showing 23 changed files with 180 additions and 101 deletions.
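At a glance, the fix works like this: every builtin registry (Hugging Face and ModelScope alike) records the revisions it knows for each model name, and a model is reported as cached if any of those revisions validates against the local download metadata. A minimal, self-contained sketch of that idea (simplified names, not the project's actual code):

from collections import defaultdict
from typing import Dict, List, Optional

# model name -> every revision that any builtin registry knows about
MODEL_NAME_TO_REVISION: Dict[str, List[Optional[str]]] = defaultdict(list)


def register(model_name: str, revision: Optional[str]) -> None:
    # called once per builtin spec when the registries are loaded
    MODEL_NAME_TO_REVISION[model_name].append(revision)


def is_cached(model_name: str, downloaded_revision: Optional[str]) -> bool:
    # downloaded_revision stands in for what the on-disk meta file records;
    # the model counts as cached if it matches any known revision
    return downloaded_revision in MODEL_NAME_TO_REVISION[model_name]


register("some-model", "rev-from-huggingface")  # hypothetical revisions
register("some-model", "rev-from-modelscope")
print(is_cached("some-model", "rev-from-modelscope"))  # True: either hub's copy counts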
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,7 +1,7 @@
 files: xinference
 repos:
   - repo: https://github.com/psf/black
-    rev: 23.1.0
+    rev: 24.1a1
     hooks:
       - id: black
   - repo: https://github.com/pre-commit/pre-commit-hooks
9 changes: 2 additions & 7 deletions xinference/api/restful_api.py
@@ -258,9 +258,6 @@ def serve(self, logging_conf: Optional[dict] = None):
                 f"{pprint.pformat(invalid_routes)}"
             )
 
-        for tp in [CreateChatCompletion, CreateCompletion]:
-            logger.debug("Dump request model fields:\n%s", tp.__fields__)
-
         class SPAStaticFiles(StaticFiles):
             async def get_response(self, path: str, scope):
                 response = await super().get_response(path, scope)
@@ -288,13 +285,11 @@ def read_main():
                 SPAStaticFiles(directory=ui_location, html=True),
             )
         else:
-            warnings.warn(
-                f"""
+            warnings.warn(f"""
             Xinference ui is not built at expected directory: {ui_location}
             To resolve this warning, navigate to {os.path.join(lib_location, "web/ui/")}
             And build the Xinference ui by running "npm run build"
-            """
-            )
+            """)
 
         config = Config(
             app=self._app, host=self._host, port=self._port, log_config=logging_conf
12 changes: 4 additions & 8 deletions xinference/core/chat_interface.py
@@ -282,13 +282,10 @@ def retry(text, hist, max_tokens, temperature) -> Generator:
         ) as generate_interface:
             history = gr.State([])
 
-            Markdown(
-                f"""
+            Markdown(f"""
             <h1 style='text-align: center; margin-bottom: 1rem'>🚀 Xinference Generate Bot : {self.model_name} 🚀</h1>
-            """
-            )
-            Markdown(
-                f"""
+            """)
+            Markdown(f"""
             <div class="center">
             Model ID: {self.model_uid}
             </div>
@@ -301,8 +298,7 @@ def retry(text, hist, max_tokens, temperature) -> Generator:
             <div class="center">
             Model Quantization: {self.quantization}
             </div>
-            """
-            )
+            """)
 
             with Column(variant="panel"):
                 textbox = Textbox(
6 changes: 3 additions & 3 deletions xinference/core/supervisor.py
@@ -63,9 +63,9 @@ def __init__(self):
         super().__init__()
         self._worker_address_to_worker: Dict[str, xo.ActorRefType["WorkerActor"]] = {}
         self._worker_status: Dict[str, WorkerStatus] = {}
-        self._replica_model_uid_to_worker: Dict[
-            str, xo.ActorRefType["WorkerActor"]
-        ] = {}
+        self._replica_model_uid_to_worker: Dict[str, xo.ActorRefType["WorkerActor"]] = (
+            {}
+        )
         self._model_uid_to_replica_info: Dict[str, ReplicaInfo] = {}
         self._uptime = None
         self._lock = asyncio.Lock()
4 changes: 3 additions & 1 deletion xinference/deploy/utils.py
@@ -60,7 +60,9 @@ def get_config_dict(
         "disable_existing_loggers": False,
         "formatters": {
             "formatter": {
-                "format": "%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s"
+                "format": (
+                    "%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s"
+                )
             },
         },
         "filters": {
6 changes: 5 additions & 1 deletion xinference/model/embedding/__init__.py
@@ -16,7 +16,7 @@
 import json
 import os
 
-from .core import EmbeddingModelSpec, get_cache_status
+from .core import MODEL_NAME_TO_REVISION, EmbeddingModelSpec, get_cache_status
 from .custom import CustomEmbeddingModelSpec, register_embedding, unregister_embedding
 
 _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
@@ -27,12 +27,16 @@
     (spec["model_name"], EmbeddingModelSpec(**spec))
     for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
 )
+for model_name, model_spec in BUILTIN_EMBEDDING_MODELS.items():
+    MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
 MODELSCOPE_EMBEDDING_MODELS = dict(
     (spec["model_name"], EmbeddingModelSpec(**spec))
     for spec in json.load(
         codecs.open(_model_spec_modelscope_json, "r", encoding="utf-8")
     )
 )
+for model_name, model_spec in MODELSCOPE_EMBEDDING_MODELS.items():
+    MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
 
 from ...constants import XINFERENCE_MODEL_DIR
 
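Because MODEL_NAME_TO_REVISION is a defaultdict(list), a model that appears in both the Hugging Face and ModelScope registries ends up with both revisions recorded. A tiny sketch of the registration loop above, using stand-in spec objects rather than the real EmbeddingModelSpec:

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class FakeSpec:  # stand-in for EmbeddingModelSpec, for illustration only
    model_name: str
    model_revision: str


MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
BUILTIN = {"gte-base": FakeSpec("gte-base", "hf-revision")}
MODELSCOPE = {"gte-base": FakeSpec("gte-base", "ms-revision")}

for registry in (BUILTIN, MODELSCOPE):
    for name, spec in registry.items():
        MODEL_NAME_TO_REVISION[name].append(spec.model_revision)

print(MODEL_NAME_TO_REVISION["gte-base"])  # ['hf-revision', 'ms-revision']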
14 changes: 7 additions & 7 deletions xinference/model/embedding/core.py
@@ -15,19 +15,23 @@
 import logging
 import os
 import shutil
-from typing import List, Optional, Tuple, Union, no_type_check
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Union, no_type_check
 
 import numpy as np
 from pydantic import BaseModel
 
 from ...constants import XINFERENCE_CACHE_DIR
 from ...types import Embedding, EmbeddingData, EmbeddingUsage
 from ..core import ModelDescription
-from ..utils import valid_model_revision
+from ..utils import is_model_cached, valid_model_revision
 
 logger = logging.getLogger(__name__)
 
 SUPPORTED_SCHEMES = ["s3"]
+# Used for check whether the model is cached.
+# Init when registering all the builtin models.
+MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
 
 
 class EmbeddingModelSpec(BaseModel):
@@ -195,11 +199,7 @@ def cache(model_spec: EmbeddingModelSpec):
 def get_cache_status(
     model_spec: EmbeddingModelSpec,
 ) -> bool:
-    cache_dir = os.path.realpath(
-        os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name)
-    )
-    meta_path = os.path.join(cache_dir, "__valid_download")
-    return valid_model_revision(meta_path, model_spec.model_revision)
+    return is_model_cached(model_spec, MODEL_NAME_TO_REVISION)
 
 
 class EmbeddingModel:
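In miniature, the behavioral change behind the cached tag (revisions are made up): the old check accepted only the single revision on the spec being inspected, so a copy downloaded from the other hub could show as not cached; the new check accepts any revision registered for that model name, which is what the UI fix relies on.

# before: one spec, one acceptable revision (simplified)
def cached_before(downloaded_revision: str, spec_revision: str) -> bool:
    return downloaded_revision == spec_revision


# after: any revision registered for the model name counts (simplified)
def cached_after(downloaded_revision: str, known_revisions: list) -> bool:
    return downloaded_revision in known_revisions


print(cached_before("ms-revision", "hf-revision"))                  # False, tag missing
print(cached_after("ms-revision", ["hf-revision", "ms-revision"]))  # True, tag shown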
12 changes: 7 additions & 5 deletions xinference/model/llm/ggml/chatglm.py
@@ -134,9 +134,9 @@ def _convert_raw_text_chunks_to_chat(
                     {
                         "index": 0,
                         "delta": {
-                            "content": token
-                            if isinstance(token, str)
-                            else token.content,
+                            "content": (
+                                token if isinstance(token, str) else token.content
+                            ),
                         },
                         "finish_reason": None,
                     }
@@ -223,8 +223,10 @@ def _handle_tools(generate_config) -> Optional[ChatCompletionMessage]:
             chatglm_tools.append(elem["function"])
         return {
             "role": "system",
-            "content": f"Answer the following questions as best as you can. You have access to the following tools:\n"
-            f"{json.dumps(chatglm_tools, indent=4, ensure_ascii=False)}",
+            "content": (
+                f"Answer the following questions as best as you can. You have access to the following tools:\n"
+                f"{json.dumps(chatglm_tools, indent=4, ensure_ascii=False)}"
+            ),
         }
 
     def chat(
56 changes: 41 additions & 15 deletions xinference/model/llm/llm_family.py
@@ -588,31 +588,57 @@ def cache_from_huggingface(
     return cache_dir
 
 
+def _check_revision(
+    llm_family: LLMFamilyV1,
+    llm_spec: "LLMSpecV1",
+    builtin: list,
+    meta_path: str,
+) -> bool:
+    for family in builtin:
+        if llm_family.model_name == family.model_name:
+            specs = family.model_specs
+            for spec in specs:
+                if (
+                    spec.model_format == "pytorch"
+                    and spec.model_size_in_billions == llm_spec.model_size_in_billions
+                ):
+                    return valid_model_revision(meta_path, spec.model_revision)
+    return False
+
+
 def get_cache_status(
     llm_family: LLMFamilyV1,
     llm_spec: "LLMSpecV1",
 ) -> Union[bool, List[bool]]:
+    """
+    When calling this function from above, `llm_family` is constructed only from BUILTIN_LLM_FAMILIES,
+    so we should check both huggingface and modelscope cache files.
+    """
     cache_dir = _get_cache_dir(llm_family, llm_spec, create_if_not_exist=False)
     # check revision for pytorch model
     if llm_spec.model_format == "pytorch":
-        return _skip_download(
-            cache_dir,
-            llm_spec.model_format,
-            llm_spec.model_hub,
-            llm_spec.model_revision,
-            "none",
-        )
+        hf_meta_path = _get_meta_path(cache_dir, "pytorch", "huggingface", "none")
+        ms_meta_path = _get_meta_path(cache_dir, "pytorch", "modelscope", "none")
+        revisions = [
+            _check_revision(llm_family, llm_spec, BUILTIN_LLM_FAMILIES, hf_meta_path),
+            _check_revision(
+                llm_family, llm_spec, BUILTIN_MODELSCOPE_LLM_FAMILIES, ms_meta_path
+            ),
+        ]
+        return any(revisions)
     # just check meta file for ggml and gptq model
     elif llm_spec.model_format in ["ggmlv3", "ggufv2", "gptq"]:
         ret = []
         for q in llm_spec.quantizations:
-            ret.append(
-                _skip_download(
-                    cache_dir,
-                    llm_spec.model_format,
-                    llm_spec.model_hub,
-                    llm_spec.model_revision,
-                    q,
-                )
-            )
+            assert q is not None
+            hf_meta_path = _get_meta_path(
+                cache_dir, llm_spec.model_format, "huggingface", q
+            )
+            ms_meta_path = _get_meta_path(
+                cache_dir, llm_spec.model_format, "modelscope", q
+            )
+            results = [os.path.exists(hf_meta_path), os.path.exists(ms_meta_path)]
+            ret.append(any(results))
         return ret
     else:
         raise ValueError(f"Unsupported model format: {llm_spec.model_format}")
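The pytorch branch above pairs each hub's meta file with the revision that hub's builtin family list expects, and the model counts as cached if either pair matches. A condensed sketch of that control flow, with a hypothetical plain-text stand-in for the __valid_download meta file:

import os
from typing import Optional


def read_meta_revision(meta_path: str) -> Optional[str]:
    # stand-in for parsing the downloaded-model meta file
    return open(meta_path).read().strip() if os.path.exists(meta_path) else None


def pytorch_is_cached(
    hf_expected: Optional[str],
    ms_expected: Optional[str],
    hf_meta: str,
    ms_meta: str,
) -> bool:
    # cached if either hub's meta file matches the revision that hub's registry expects
    checks = [
        hf_expected is not None and read_meta_revision(hf_meta) == hf_expected,
        ms_expected is not None and read_meta_revision(ms_meta) == ms_expected,
    ]
    return any(checks)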
6 changes: 3 additions & 3 deletions xinference/model/llm/pytorch/core.py
@@ -442,9 +442,9 @@ def _sanitize_generate_config(
             and self.model_family.prompt_style
             and self.model_family.prompt_style.stop_token_ids
         ):
-            generate_config[
-                "stop_token_ids"
-            ] = self.model_family.prompt_style.stop_token_ids.copy()
+            generate_config["stop_token_ids"] = (
+                self.model_family.prompt_style.stop_token_ids.copy()
+            )
 
         return generate_config
 
2 changes: 2 additions & 0 deletions xinference/model/llm/tests/test_llm_family.py
@@ -289,6 +289,7 @@ def test_meta_file():
     cache_dir = cache_from_huggingface(family, spec, quantization=None)
     meta_path = _get_meta_path(cache_dir, spec.model_format, spec.model_hub, None)
     assert valid_model_revision(meta_path, "3d2b5f275bdf882b8775f902e1bfdb790e2cfc32")
+    shutil.rmtree(cache_dir)
 
 
 def test_parse_uri():
@@ -878,6 +879,7 @@ def test_get_cache_status_pytorch():
         model_size_in_billions=1,
         quantizations=["4-bit", "8-bit", "none"],
         model_id="facebook/opt-125m",
+        model_revision="3d2b5f275bdf882b8775f902e1bfdb790e2cfc32",
     )
     family = LLMFamilyV1(
         version=1,
6 changes: 5 additions & 1 deletion xinference/model/rerank/__init__.py
@@ -16,7 +16,7 @@
 import json
 import os
 
-from .core import RerankModelSpec, get_cache_status
+from .core import MODEL_NAME_TO_REVISION, RerankModelSpec, get_cache_status
 
 _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
 _model_spec_modelscope_json = os.path.join(
@@ -26,11 +26,15 @@
     (spec["model_name"], RerankModelSpec(**spec))
     for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
 )
+for model_name, model_spec in BUILTIN_RERANK_MODELS.items():
+    MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
 MODELSCOPE_RERANK_MODELS = dict(
     (spec["model_name"], RerankModelSpec(**spec))
     for spec in json.load(
         codecs.open(_model_spec_modelscope_json, "r", encoding="utf-8")
     )
 )
+for model_name, model_spec in MODELSCOPE_RERANK_MODELS.items():
+    MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
 del _model_spec_json
 del _model_spec_modelscope_json
13 changes: 7 additions & 6 deletions xinference/model/rerank/core.py
@@ -15,6 +15,7 @@
 import logging
 import os
 import uuid
+from collections import defaultdict
 from typing import Dict, List, Optional, Tuple
 
 import numpy as np
@@ -23,10 +24,14 @@
 from ...constants import XINFERENCE_CACHE_DIR
 from ...types import Document, DocumentObj, Rerank
 from ..core import ModelDescription
-from ..utils import valid_model_revision
+from ..utils import is_model_cached, valid_model_revision
 
 logger = logging.getLogger(__name__)
 
+# Used for check whether the model is cached.
+# Init when registering all the builtin models.
+MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
+
 
 class RerankModelSpec(BaseModel):
     model_name: str
@@ -126,11 +131,7 @@ def rerank(
 def get_cache_status(
     model_spec: RerankModelSpec,
 ) -> bool:
-    cache_dir = os.path.realpath(
-        os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name)
-    )
-    meta_path = os.path.join(cache_dir, "__valid_download")
-    return valid_model_revision(meta_path, model_spec.model_revision)
+    return is_model_cached(model_spec, MODEL_NAME_TO_REVISION)
 
 
 def cache(model_spec: RerankModelSpec):
15 changes: 13 additions & 2 deletions xinference/model/utils.py
@@ -16,11 +16,11 @@
 import os
 from json import JSONDecodeError
 from pathlib import Path
-from typing import Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple
 
 from fsspec import AbstractFileSystem
 
-from ..constants import XINFERENCE_ENV_MODEL_SRC
+from ..constants import XINFERENCE_CACHE_DIR, XINFERENCE_ENV_MODEL_SRC
 
 logger = logging.getLogger(__name__)
 MAX_ATTEMPTS = 3
@@ -132,6 +132,17 @@ def valid_model_revision(
     return real_revision == expected_model_revision
 
 
+def is_model_cached(model_spec: Any, name_to_revisions_mapping: Dict):
+    cache_dir = os.path.realpath(
+        os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name)
+    )
+    meta_path = os.path.join(cache_dir, "__valid_download")
+    revisions = name_to_revisions_mapping[model_spec.model_name]
+    if model_spec.model_revision not in revisions:  # Usually for UT
+        revisions.append(model_spec.model_revision)
+    return any([valid_model_revision(meta_path, revision) for revision in revisions])
+
+
 def is_valid_model_name(model_name: str) -> bool:
     import re
 
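A hedged usage sketch of the new helper: the spec only needs model_name and model_revision attributes, so a stand-in object works, and the "Usually for UT" branch means a revision no builtin registry declared (as in unit tests) is still considered. Only the decision logic is reproduced here; the real helper then validates the __valid_download meta file under XINFERENCE_CACHE_DIR.

from collections import defaultdict
from types import SimpleNamespace
from typing import Dict, List

# stand-in spec: is_model_cached only reads model_name and model_revision
spec = SimpleNamespace(model_name="my-model", model_revision="test-revision")

name_to_revisions: Dict[str, List[str]] = defaultdict(list)  # no builtin declared "my-model"

# what the helper does, in outline:
revisions = name_to_revisions[spec.model_name]
if spec.model_revision not in revisions:  # "Usually for UT": accept ad-hoc revisions too
    revisions.append(spec.model_revision)
# it then returns True only if the meta file validates against any revision in `revisions`
print(revisions)  # ['test-revision']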
12 changes: 12 additions & 0 deletions xinference/web/ui/src/scenes/launch_model/embeddingCard.js
@@ -266,6 +266,18 @@ const EmbeddingCard = ({
           return <Chip label="ZH" variant="outlined" size="small" />
         }
       })()}
+      {(() => {
+        if (modelData.is_cached) {
+          return (
+            <Chip
+              label="Cached"
+              variant="outlined"
+              size="small"
+              sx={{ marginLeft: '10px' }}
+            />
+          )
+        }
+      })()}
       {(() => {
         if (is_custom && customDeleted) {
           return (