Commit 518fdf9: ENH: handle legacy cache (#266)

UranusSeven authored Jul 28, 2023
1 parent 3351361 commit 518fdf9

Showing 6 changed files with 85 additions and 30 deletions.
5 changes: 2 additions & 3 deletions xinference/core/gradio.py
@@ -19,7 +19,7 @@
 
 from ..locale.utils import Locale
 from ..model.llm import LLM_FAMILIES, LLMFamilyV1, match_llm
-from ..model.llm.llm_family import cache_from_huggingface
+from ..model.llm.llm_family import cache
 from .api import SyncSupervisorAPI
 
 if TYPE_CHECKING:
@@ -312,7 +312,6 @@ def select_model(
             _model_format: str,
             _model_size_in_billions: str,
             _quantization: str,
-            progress=gr.Progress(),
         ):
             match_result = match_llm(
                 _model_name,
@@ -328,7 +327,7 @@
             )
 
             llm_family, llm_spec, _quantization = match_result
-            cache_from_huggingface(llm_family, llm_spec, _quantization)
+            cache(llm_family, llm_spec, _quantization)
 
             model_uid = self._create_model(
                 _model_name, int(_model_size_in_billions), _model_format, _quantization
6 changes: 2 additions & 4 deletions xinference/core/service.py
@@ -313,11 +313,9 @@ async def launch_builtin_model(
         llm_family, llm_spec, quantization = match_result
         assert quantization is not None
 
-        from ..model.llm.llm_family import cache_from_huggingface
+        from ..model.llm.llm_family import cache
 
-        save_path = await asyncio.to_thread(
-            cache_from_huggingface, llm_family, llm_spec, quantization
-        )
+        save_path = await asyncio.to_thread(cache, llm_family, llm_spec, quantization)
 
         llm_cls = match_llm_cls(llm_family, llm_spec)
         if not llm_cls:
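The service still wraps the (potentially long-running) cache call in asyncio.to_thread so the actor's event loop is not blocked while files are downloaded or resolved. A minimal sketch of that pattern, using a hypothetical slow_cache stand-in rather than xinference's real cache():

import asyncio
import time


def slow_cache(model_name: str) -> str:
    # Hypothetical blocking stand-in for cache(llm_family, llm_spec, quantization).
    time.sleep(2)
    return f"/tmp/{model_name}"


async def launch() -> str:
    # Offload the blocking call to a worker thread; the event loop stays free
    # to serve other requests while the cache is being prepared.
    save_path = await asyncio.to_thread(slow_cache, "orca-ggmlv3-3b")
    return save_path


print(asyncio.run(launch()))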
5 changes: 5 additions & 0 deletions xinference/model/llm/ggml/chatglm.py
@@ -89,6 +89,11 @@ def load(self):
             ),
         )
 
+        # handle legacy cache.
+        legacy_model_file_path = os.path.join(self.model_path, "model.bin")
+        if os.path.exists(legacy_model_file_path):
+            model_file_path = legacy_model_file_path
+
         self._llm = chatglm_cpp.Pipeline(Path(model_file_path))
 
     @classmethod
18 changes: 12 additions & 6 deletions xinference/model/llm/ggml/llamacpp.py
@@ -166,13 +166,19 @@ def load(self):
 
             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
 
-        self._llm = Llama(
-            model_path=os.path.join(
-                self.model_path,
-                self.model_spec.model_file_name_template.format(
-                    quantization=self.quantization
-                ),
+        # handle legacy cache.
+        model_path = os.path.join(
+            self.model_path,
+            self.model_spec.model_file_name_template.format(
+                quantization=self.quantization
             ),
+        )
+        legacy_model_file_path = os.path.join(self.model_path, "model.bin")
+        if os.path.exists(legacy_model_file_path):
+            model_path = legacy_model_file_path
+
+        self._llm = Llama(
+            model_path=model_path,
             verbose=False,
             **self._llamacpp_model_config,
         )
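Both ggml loaders now apply the same rule: build the file name from the quantization template, but prefer a legacy model.bin left behind by older xinference releases if it is present. A small sketch of that fallback, factored into a hypothetical helper that is not part of this commit:

import os


def resolve_model_file(model_dir: str, templated_name: str) -> str:
    # Prefer the legacy flat file written by earlier releases; otherwise use
    # the quantization-specific file name produced by model_file_name_template.
    legacy_path = os.path.join(model_dir, "model.bin")
    if os.path.exists(legacy_path):
        return legacy_path
    return os.path.join(model_dir, templated_name)

Either loader could then call resolve_model_file(self.model_path, self.model_spec.model_file_name_template.format(quantization=self.quantization)) instead of repeating the check inline.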
33 changes: 19 additions & 14 deletions xinference/model/llm/llm_family.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import os
 from typing import List, Optional, Union
 
@@ -20,6 +21,8 @@
 
 from xinference.constants import XINFERENCE_CACHE_DIR
 
+logger = logging.getLogger(__name__)
+
 
 class GgmlLLMSpecV1(BaseModel):
     model_format: Literal["ggmlv3"]
@@ -68,12 +71,13 @@ class LLMFamilyV1(BaseModel):
 LLM_FAMILIES: List[LLMFamilyV1] = []
 
 
-def _generate_cache_path_ggml(
-    self,
+def get_legacy_cache_path(
+    model_name: str,
+    model_format: str,
     model_size_in_billions: Optional[int] = None,
     quantization: Optional[str] = None,
-):
-    full_name = f"{str(self)}-{model_size_in_billions}b-{quantization}"
+) -> str:
+    full_name = f"{model_name}-{model_format}-{model_size_in_billions}b-{quantization}"
     save_dir = os.path.join(XINFERENCE_CACHE_DIR, full_name)
     if not os.path.exists(save_dir):
         os.makedirs(save_dir, exist_ok=True)
@@ -86,16 +90,17 @@ def cache(
     llm_spec: "LLMSpecV1",
     quantization: Optional[str] = None,
 ) -> str:
-    return cache_from_huggingface(llm_family, llm_spec, quantization)
-
-
-def cache_legacy(
-    llm_family: LLMFamilyV1,
-    llm_spec: "LLMSpecV1",
-    quantization: Optional[str] = None,
-) -> str:
-    # TODO: handle legacy
-    return ""
+    legacy_cache_path = get_legacy_cache_path(
+        llm_family.model_name,
+        llm_spec.model_format,
+        llm_spec.model_size_in_billions,
+        quantization,
+    )
+    if os.path.exists(legacy_cache_path):
+        logger.debug("Legacy cache path exists: %s", legacy_cache_path)
+        return os.path.dirname(legacy_cache_path)
+    else:
+        return cache_from_huggingface(llm_family, llm_spec, quantization)
 
 
 def cache_from_huggingface(
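With this change, callers only need cache(): it probes the flat legacy directory first and falls back to cache_from_huggingface when nothing is there. A rough usage sketch mirroring the spec used in the tests below (the call downloads from the Hugging Face Hub when no legacy cache exists, so it assumes network access):

import os

from xinference.model.llm.llm_family import (
    GgmlLLMSpecV1,
    LLMFamilyV1,
    cache,
    get_legacy_cache_path,
)

spec = GgmlLLMSpecV1(
    model_format="ggmlv3",
    model_size_in_billions=3,
    model_id="TheBloke/orca_mini_3B-GGML",
    quantizations=["q4_0"],
    model_file_name_template="README.md",
)
family = LLMFamilyV1(
    version=1,
    model_type="LLM",
    model_name="orca",
    model_lang=["en"],
    model_ability=["embed", "chat"],
    model_specs=[spec],
    prompt_style=None,
)

legacy_path = get_legacy_cache_path("orca", "ggmlv3", 3, "q4_0")
save_path = cache(family, spec, "q4_0")
# save_path equals os.path.dirname(legacy_path) only when the legacy file
# already exists; otherwise it points at a freshly populated cache directory.
print(save_path, os.path.exists(legacy_path))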
48 changes: 45 additions & 3 deletions xinference/model/llm/tests/test_llm_family.py
@@ -171,15 +171,15 @@ def test_cache_from_huggingface_ggml():
         model_format="ggmlv3",
         model_size_in_billions=3,
         model_id="TheBloke/orca_mini_3B-GGML",
-        quantizations=[""],
+        quantizations=["q4_0"],
         model_file_name_template="README.md",
     )
     family = LLMFamilyV1(
         version=1,
         model_type="LLM",
-        model_name="opt",
+        model_name="orca",
         model_lang=["en"],
-        model_ability=["embed", "generate"],
+        model_ability=["embed", "chat"],
         model_specs=[spec],
         prompt_style=None,
     )
@@ -191,3 +191,45 @@
     assert os.path.exists(cache_dir)
     assert os.path.exists(os.path.join(cache_dir, "README.md"))
     assert os.path.islink(os.path.join(cache_dir, "README.md"))
+
+
+def test_legacy_cache():
+    import os
+
+    from ..llm_family import cache, get_legacy_cache_path
+
+    spec = GgmlLLMSpecV1(
+        model_format="ggmlv3",
+        model_size_in_billions=3,
+        model_id="TheBloke/orca_mini_3B-GGML",
+        quantizations=["q8_0"],
+        model_file_name_template="README.md",
+    )
+    family = LLMFamilyV1(
+        version=1,
+        model_type="LLM",
+        model_name="orca",
+        model_lang=["en"],
+        model_ability=["embed", "chat"],
+        model_specs=[spec],
+        prompt_style=None,
+    )
+
+    cache_path = get_legacy_cache_path(
+        family.model_name,
+        spec.model_format,
+        spec.model_size_in_billions,
+        quantization="q8_0",
+    )
+
+    assert cache(
+        llm_family=family, llm_spec=spec, quantization="q8_0"
+    ) != os.path.dirname(cache_path)
+
+    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
+    with open(cache_path, "w") as fd:
+        fd.write("foo")
+
+    assert cache(
+        llm_family=family, llm_spec=spec, quantization="q8_0"
+    ) == os.path.dirname(cache_path)

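The new behavior is covered by test_legacy_cache above; it can be run on its own with, for example, pytest xinference/model/llm/tests/test_llm_family.py::test_legacy_cache. Note that both tests appear to fetch a small file from the Hugging Face Hub and write under XINFERENCE_CACHE_DIR, so network access is likely required.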