diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index 921ff6e40a..5332fd3f08 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -82,6 +82,7 @@ jobs: - { os: windows-latest, python-version: 3.10 } include: - { os: self-hosted, module: gpu, python-version: 3.9} + - { os: macos-latest, module: metal, python-version: "3.10" } steps: - name: Check out code @@ -109,6 +110,9 @@ jobs: sudo rm -rf "/usr/local/share/boost" sudo rm -rf "$AGENT_TOOLSDIRECTORY" fi + if [ "$MODULE" == "metal" ]; then + pip install mlx-lm + fi pip install "llama-cpp-python==0.2.77" pip install transformers pip install attrdict @@ -162,6 +166,10 @@ jobs: ${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \ -W ignore::PendingDeprecationWarning \ --cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/audio/tests/test_chattts.py + elif [ "$MODULE" == "metal" ]; then + pytest --timeout=1500 \ + -W ignore::PendingDeprecationWarning \ + --cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/llm/mlx/tests/test_mlx.py else pytest --timeout=1500 \ -W ignore::PendingDeprecationWarning \ diff --git a/doc/source/models/builtin/llm/qwen2-instruct.rst b/doc/source/models/builtin/llm/qwen2-instruct.rst index 16d3baf516..629ce9aefa 100644 --- a/doc/source/models/builtin/llm/qwen2-instruct.rst +++ b/doc/source/models/builtin/llm/qwen2-instruct.rst @@ -206,7 +206,71 @@ chosen quantization method from the options listed above:: xinference launch --model-engine ${engine} --model-name qwen2-instruct --size-in-billions 72 --model-format awq --quantization ${quantization} -Model Spec 13 (ggufv2, 0_5 Billion) +Model Spec 13 (mlx, 0_5 Billion) +++++++++++++++++++++++++++++++++++++++++ + +- **Model Format:** mlx +- **Model Size (in billions):** 0_5 +- **Quantizations:** 4-bit +- **Engines**: MLX +- **Model ID:** Qwen/Qwen2-0.5B-Instruct-MLX +- **Model Hubs**: `Hugging Face `__, `ModelScope `__ + +Execute the following command to launch the model, remember to replace ``${quantization}`` with your +chosen quantization method from the options listed above:: + + xinference launch --model-engine ${engine} --model-name qwen2-instruct --size-in-billions 0_5 --model-format mlx --quantization ${quantization} + + +Model Spec 14 (mlx, 1_5 Billion) +++++++++++++++++++++++++++++++++++++++++ + +- **Model Format:** mlx +- **Model Size (in billions):** 1_5 +- **Quantizations:** 4-bit +- **Engines**: MLX +- **Model ID:** Qwen/Qwen2-1.5B-Instruct-MLX +- **Model Hubs**: `Hugging Face `__, `ModelScope `__ + +Execute the following command to launch the model, remember to replace ``${quantization}`` with your +chosen quantization method from the options listed above:: + + xinference launch --model-engine ${engine} --model-name qwen2-instruct --size-in-billions 1_5 --model-format mlx --quantization ${quantization} + + +Model Spec 15 (mlx, 7 Billion) +++++++++++++++++++++++++++++++++++++++++ + +- **Model Format:** mlx +- **Model Size (in billions):** 7 +- **Quantizations:** 4-bit +- **Engines**: MLX +- **Model ID:** Qwen/Qwen2-7B-Instruct-MLX +- **Model Hubs**: `Hugging Face `__, `ModelScope `__ + +Execute the following command to launch the model, remember to replace ``${quantization}`` with your +chosen quantization method from the options listed above:: + + xinference launch --model-engine ${engine} --model-name qwen2-instruct --size-in-billions 7 --model-format mlx --quantization ${quantization} + + +Model Spec 16 (mlx, 72 Billion) 
+++++++++++++++++++++++++++++++++++++++++ + +- **Model Format:** mlx +- **Model Size (in billions):** 72 +- **Quantizations:** 4-bit +- **Engines**: MLX +- **Model ID:** mlx-community/Qwen2-72B-4bit +- **Model Hubs**: `Hugging Face `__ + +Execute the following command to launch the model, remember to replace ``${quantization}`` with your +chosen quantization method from the options listed above:: + + xinference launch --model-engine ${engine} --model-name qwen2-instruct --size-in-billions 72 --model-format mlx --quantization ${quantization} + + +Model Spec 17 (ggufv2, 0_5 Billion) ++++++++++++++++++++++++++++++++++++++++ - **Model Format:** ggufv2 @@ -222,7 +286,7 @@ chosen quantization method from the options listed above:: xinference launch --model-engine ${engine} --model-name qwen2-instruct --size-in-billions 0_5 --model-format ggufv2 --quantization ${quantization} -Model Spec 14 (ggufv2, 1_5 Billion) +Model Spec 18 (ggufv2, 1_5 Billion) ++++++++++++++++++++++++++++++++++++++++ - **Model Format:** ggufv2 @@ -238,7 +302,7 @@ chosen quantization method from the options listed above:: xinference launch --model-engine ${engine} --model-name qwen2-instruct --size-in-billions 1_5 --model-format ggufv2 --quantization ${quantization} -Model Spec 15 (ggufv2, 7 Billion) +Model Spec 19 (ggufv2, 7 Billion) ++++++++++++++++++++++++++++++++++++++++ - **Model Format:** ggufv2 @@ -254,7 +318,7 @@ chosen quantization method from the options listed above:: xinference launch --model-engine ${engine} --model-name qwen2-instruct --size-in-billions 7 --model-format ggufv2 --quantization ${quantization} -Model Spec 16 (ggufv2, 72 Billion) +Model Spec 20 (ggufv2, 72 Billion) ++++++++++++++++++++++++++++++++++++++++ - **Model Format:** ggufv2 diff --git a/setup.cfg b/setup.cfg index a2ef1422ad..b99dd6492e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -103,6 +103,7 @@ all = optimum outlines==0.0.34 # sglang errored for outlines > 0.0.34 sglang[all] ; sys_platform=='linux' + mlx-lm ; sys_platform=='darwin' and platform_machine=='arm64' attrdict # For deepseek VL timm>=0.9.16 # For deepseek VL torchvision # For deepseek VL @@ -143,6 +144,8 @@ vllm = vllm>=0.2.6 sglang = sglang[all] +mlx = + mlx-lm embedding = sentence-transformers>=2.7.0 rerank = diff --git a/xinference/model/llm/__init__.py b/xinference/model/llm/__init__.py index fb56d82488..1313d73196 100644 --- a/xinference/model/llm/__init__.py +++ b/xinference/model/llm/__init__.py @@ -34,6 +34,7 @@ BUILTIN_MODELSCOPE_LLM_FAMILIES, LLAMA_CLASSES, LLM_ENGINES, + MLX_CLASSES, SGLANG_CLASSES, SUPPORTED_ENGINES, TRANSFORMERS_CLASSES, @@ -42,6 +43,7 @@ GgmlLLMSpecV1, LLMFamilyV1, LLMSpecV1, + MLXLLMSpecV1, PromptStyleV1, PytorchLLMSpecV1, get_cache_status, @@ -112,6 +114,7 @@ def generate_engine_config_by_model_family(model_family): def _install(): from .ggml.chatglm import ChatglmCppChatModel from .ggml.llamacpp import LlamaCppChatModel, LlamaCppModel + from .mlx.core import MLXChatModel, MLXModel from .pytorch.baichuan import BaichuanPytorchChatModel from .pytorch.chatglm import ChatglmPytorchChatModel from .pytorch.cogvlm2 import CogVLM2Model @@ -147,6 +150,7 @@ def _install(): ) SGLANG_CLASSES.extend([SGLANGModel, SGLANGChatModel]) VLLM_CLASSES.extend([VLLMModel, VLLMChatModel]) + MLX_CLASSES.extend([MLXModel, MLXChatModel]) TRANSFORMERS_CLASSES.extend( [ BaichuanPytorchChatModel, @@ -176,6 +180,7 @@ def _install(): SUPPORTED_ENGINES["SGLang"] = SGLANG_CLASSES SUPPORTED_ENGINES["Transformers"] = TRANSFORMERS_CLASSES SUPPORTED_ENGINES["llama.cpp"] = LLAMA_CLASSES 
+ SUPPORTED_ENGINES["MLX"] = MLX_CLASSES json_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "llm_family.json" diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json index 6d51c56c81..b1ab45ea2a 100644 --- a/xinference/model/llm/llm_family.json +++ b/xinference/model/llm/llm_family.json @@ -2549,6 +2549,38 @@ ], "model_id": "Qwen/Qwen2-72B-Instruct-AWQ" }, + { + "model_format": "mlx", + "model_size_in_billions": "0_5", + "quantizations": [ + "4-bit" + ], + "model_id": "Qwen/Qwen2-0.5B-Instruct-MLX" + }, + { + "model_format": "mlx", + "model_size_in_billions": "1_5", + "quantizations": [ + "4-bit" + ], + "model_id": "Qwen/Qwen2-1.5B-Instruct-MLX" + }, + { + "model_format": "mlx", + "model_size_in_billions": 7, + "quantizations": [ + "4-bit" + ], + "model_id": "Qwen/Qwen2-7B-Instruct-MLX" + }, + { + "model_format": "mlx", + "model_size_in_billions": 72, + "quantizations": [ + "4-bit" + ], + "model_id": "mlx-community/Qwen2-72B-Instruct-4bit" + }, { "model_format": "ggufv2", "model_size_in_billions": "0_5", diff --git a/xinference/model/llm/llm_family.py b/xinference/model/llm/llm_family.py index a405d8f532..7ea575fb65 100644 --- a/xinference/model/llm/llm_family.py +++ b/xinference/model/llm/llm_family.py @@ -107,6 +107,28 @@ def validate_model_size_with_radix(cls, v: object) -> object: return v +class MLXLLMSpecV1(BaseModel): + model_format: Literal["mlx"] + # Must in order that `str` first, then `int` + model_size_in_billions: Union[str, int] + quantizations: List[str] + model_id: Optional[str] + model_hub: str = "huggingface" + model_uri: Optional[str] + model_revision: Optional[str] + + @validator("model_size_in_billions", pre=False) + def validate_model_size_with_radix(cls, v: object) -> object: + if isinstance(v, str): + if ( + "_" in v + ): # for example, "1_8" just returns "1_8", otherwise int("1_8") returns 18 + return v + else: + return int(v) + return v + + class PromptStyleV1(BaseModel): style_name: str system_prompt: str = "" @@ -226,7 +248,7 @@ def parse_raw( LLMSpecV1 = Annotated[ - Union[GgmlLLMSpecV1, PytorchLLMSpecV1], + Union[GgmlLLMSpecV1, PytorchLLMSpecV1, MLXLLMSpecV1], Field(discriminator="model_format"), ] @@ -249,6 +271,8 @@ def parse_raw( VLLM_CLASSES: List[Type[LLM]] = [] +MLX_CLASSES: List[Type[LLM]] = [] + LLM_ENGINES: Dict[str, Dict[str, List[Dict[str, Any]]]] = {} SUPPORTED_ENGINES: Dict[str, List[Type[LLM]]] = {} @@ -549,7 +573,7 @@ def _get_meta_path( return os.path.join(cache_dir, "__valid_download") else: return os.path.join(cache_dir, f"__valid_download_{model_hub}") - elif model_format in ["ggmlv3", "ggufv2", "gptq", "awq"]: + elif model_format in ["ggmlv3", "ggufv2", "gptq", "awq", "mlx"]: assert quantization is not None if model_hub == "huggingface": return os.path.join(cache_dir, f"__valid_download_{quantization}") @@ -588,7 +612,7 @@ def _skip_download( logger.warning(f"Cache {cache_dir} exists, but it was from {hub}") return True return False - elif model_format in ["ggmlv3", "ggufv2", "gptq", "awq"]: + elif model_format in ["ggmlv3", "ggufv2", "gptq", "awq", "mlx"]: assert quantization is not None return os.path.exists( _get_meta_path(cache_dir, model_format, model_hub, quantization) @@ -683,7 +707,7 @@ def cache_from_csghub( ): return cache_dir - if llm_spec.model_format in ["pytorch", "gptq", "awq"]: + if llm_spec.model_format in ["pytorch", "gptq", "awq", "mlx"]: download_dir = retry_download( snapshot_download, llm_family.model_name, @@ -751,7 +775,7 @@ def cache_from_modelscope( ): return 
cache_dir - if llm_spec.model_format in ["pytorch", "gptq", "awq"]: + if llm_spec.model_format in ["pytorch", "gptq", "awq", "mlx"]: download_dir = retry_download( snapshot_download, llm_family.model_name, @@ -820,8 +844,8 @@ def cache_from_huggingface( if not IS_NEW_HUGGINGFACE_HUB: use_symlinks = {"local_dir_use_symlinks": True, "local_dir": cache_dir} - if llm_spec.model_format in ["pytorch", "gptq", "awq"]: - assert isinstance(llm_spec, PytorchLLMSpecV1) + if llm_spec.model_format in ["pytorch", "gptq", "awq", "mlx"]: + assert isinstance(llm_spec, (PytorchLLMSpecV1, MLXLLMSpecV1)) download_dir = retry_download( huggingface_hub.snapshot_download, llm_family.model_name, @@ -910,7 +934,7 @@ def get_cache_status( ] return any(revisions) # just check meta file for ggml and gptq model - elif llm_spec.model_format in ["ggmlv3", "ggufv2", "gptq", "awq"]: + elif llm_spec.model_format in ["ggmlv3", "ggufv2", "gptq", "awq", "mlx"]: ret = [] for q in llm_spec.quantizations: assert q is not None diff --git a/xinference/model/llm/llm_family_modelscope.json b/xinference/model/llm/llm_family_modelscope.json index 436ddc68c5..c3b2969bdf 100644 --- a/xinference/model/llm/llm_family_modelscope.json +++ b/xinference/model/llm/llm_family_modelscope.json @@ -2921,6 +2921,33 @@ "model_id": "qwen/Qwen2-72B-Instruct-AWQ", "model_hub": "modelscope" }, + { + "model_format": "mlx", + "model_size_in_billions": "0_5", + "quantizations": [ + "4-bit" + ], + "model_id": "qwen/Qwen2-0.5B-Instruct-MLX", + "model_hub": "modelscope" + }, + { + "model_format": "mlx", + "model_size_in_billions": "1_5", + "quantizations": [ + "4-bit" + ], + "model_id": "qwen/Qwen2-1.5B-Instruct-MLX", + "model_hub": "modelscope" + }, + { + "model_format": "mlx", + "model_size_in_billions": 7, + "quantizations": [ + "4-bit" + ], + "model_id": "qwen/Qwen2-7B-Instruct-MLX", + "model_hub": "modelscope" + }, { "model_format": "ggufv2", "model_size_in_billions": "0_5", diff --git a/xinference/model/llm/mlx/__init__.py b/xinference/model/llm/mlx/__init__.py new file mode 100644 index 0000000000..37f6558d95 --- /dev/null +++ b/xinference/model/llm/mlx/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/xinference/model/llm/mlx/core.py b/xinference/model/llm/mlx/core.py new file mode 100644 index 0000000000..c344c2f594 --- /dev/null +++ b/xinference/model/llm/mlx/core.py @@ -0,0 +1,408 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import platform +import sys +import time +import uuid +from typing import Dict, Iterable, Iterator, List, Optional, TypedDict, Union + +from ....fields import max_tokens_field +from ....types import ( + ChatCompletion, + ChatCompletionChunk, + ChatCompletionMessage, + Completion, + CompletionChoice, + CompletionChunk, + CompletionUsage, + LoRA, +) +from ..core import LLM +from ..llm_family import LLMFamilyV1, LLMSpecV1 +from ..utils import ChatModelMixin + +logger = logging.getLogger(__name__) + + +class MLXModelConfig(TypedDict, total=False): + revision: Optional[str] + max_gpu_memory: str + trust_remote_code: bool + + +class MLXGenerateConfig(TypedDict, total=False): + max_tokens: int + temperature: float + repetition_penalty: Optional[float] + repetition_context_size: Optional[float] + top_p: float + logit_bias: Optional[Dict[int, float]] + stop: Optional[Union[str, List[str]]] + stop_token_ids: Optional[Union[int, List[int]]] + stream: bool + stream_options: Optional[Union[dict, None]] + + +class MLXModel(LLM): + def __init__( + self, + model_uid: str, + model_family: "LLMFamilyV1", + model_spec: "LLMSpecV1", + quantization: str, + model_path: str, + model_config: Optional[MLXModelConfig] = None, + peft_model: Optional[List[LoRA]] = None, + ): + super().__init__(model_uid, model_family, model_spec, quantization, model_path) + self._use_fast_tokenizer = True + self._model_config: MLXModelConfig = self._sanitize_model_config(model_config) + if peft_model is not None: + raise ValueError("MLX engine has not supported lora yet") + + def _sanitize_model_config( + self, model_config: Optional[MLXModelConfig] + ) -> MLXModelConfig: + if model_config is None: + model_config = MLXModelConfig() + model_config.setdefault("revision", self.model_spec.model_revision) + model_config.setdefault("trust_remote_code", True) + return model_config + + def _sanitize_generate_config( + self, + generate_config: Optional[MLXGenerateConfig], + ) -> MLXGenerateConfig: + if generate_config is None: + generate_config = MLXGenerateConfig() + + generate_config.setdefault("max_tokens", max_tokens_field.default) + # default config is adapted from + # https://github.com/ml-explore/mlx-examples/blob/f212b770d8b5143e23102eda20400ae43340f844/llms/mlx_lm/utils.py#L129 + generate_config.setdefault("temperature", 0.0) + generate_config.setdefault("repetition_penalty", None) + generate_config.setdefault("repetition_context_size", 20) + generate_config.setdefault("top_p", 1.0) + generate_config.setdefault("logit_bias", None) + return generate_config + + def _load_model(self, **kwargs): + try: + from mlx_lm import load + except ImportError: + error_message = "Failed to import module 'mlx_lm'" + installation_guide = [ + "Please make sure 'mlx_lm' is installed. 
", + "You can install it by `pip install mlx_lm`\n", + ] + + raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}") + + tokenizer_config = dict( + use_fast=self._use_fast_tokenizer, + trust_remote_code=kwargs["trust_remote_code"], + revision=kwargs["revision"], + ) + logger.debug( + "loading model with tokenizer config: %s, model config: %s", + tokenizer_config, + self._model_config, + ) + + return load( + self.model_path, + tokenizer_config=tokenizer_config, + model_config=self._model_config, + ) + + def load(self): + kwargs = {} + kwargs["revision"] = self._model_config.get( + "revision", self.model_spec.model_revision + ) + kwargs["trust_remote_code"] = self._model_config.get("trust_remote_code") + + self._model, self._tokenizer = self._load_model(**kwargs) + + @classmethod + def match( + cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str + ) -> bool: + if llm_spec.model_format not in ["mlx"]: + return False + if sys.platform != "darwin" or platform.processor() != "arm": + # only work for Mac M chips + return False + if "generate" not in llm_family.model_ability: + return False + return True + + def _generate_stream(self, prompt: str, kwargs: MLXGenerateConfig): + import mlx.core as mx + from mlx_lm.utils import generate_step + + model = self._model + model_uid = self.model_uid + tokenizer = self._tokenizer + max_tokens = kwargs["max_tokens"] + chunk_id = str(uuid.uuid4()) + stop_token_ids = kwargs.get("stop_token_ids", []) + stream = kwargs.get("stream", False) + stream_options = kwargs.pop("stream_options", None) + include_usage = ( + stream_options["include_usage"] + if isinstance(stream_options, dict) + else False + ) + + prompt_tokens = mx.array(tokenizer.encode(prompt)) + input_echo_len = len(prompt_tokens) + + i = 0 + start = time.time() + output = "" + for (token, _), i in zip( + generate_step( + prompt_tokens, + model, + temp=kwargs["temperature"], + repetition_penalty=kwargs["repetition_penalty"], + repetition_context_size=kwargs["repetition_context_size"], + top_p=kwargs["top_p"], + logit_bias=kwargs["logit_bias"], + ), + range(max_tokens), + ): + if token == tokenizer.eos_token_id or token in stop_token_ids: # type: ignore + break + + # Yield the last segment if streaming + out = tokenizer.decode( + token, + skip_special_tokens=True, + spaces_between_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + + if stream: + # this special character is mainly for qwen + out = out.strip("�") + output = out + else: + output += out + + completion_choice = CompletionChoice( + text=output, index=0, logprobs=None, finish_reason=None + ) + completion_chunk = CompletionChunk( + id=chunk_id, + object="text_completion", + created=int(time.time()), + model=model_uid, + choices=[completion_choice], + ) + completion_usage = CompletionUsage( + prompt_tokens=input_echo_len, + completion_tokens=i, + total_tokens=(input_echo_len + i), + ) + + yield completion_chunk, completion_usage + + logger.info( + f"Average generation speed: {i / (time.time() - start):.2f} tokens/s." 
+ ) + + if i == max_tokens - 1: + finish_reason = "length" + else: + finish_reason = "stop" + + if stream: + completion_choice = CompletionChoice( + text="", index=0, logprobs=None, finish_reason=finish_reason + ) + else: + completion_choice = CompletionChoice( + text=output, index=0, logprobs=None, finish_reason=finish_reason + ) + + completion_chunk = CompletionChunk( + id=chunk_id, + object="text_completion", + created=int(time.time()), + model=model_uid, + choices=[completion_choice], + ) + completion_usage = CompletionUsage( + prompt_tokens=input_echo_len, + completion_tokens=i, + total_tokens=(input_echo_len + i), + ) + + yield completion_chunk, completion_usage + + if include_usage: + completion_chunk = CompletionChunk( + id=chunk_id, + object="text_completion", + created=int(time.time()), + model=model_uid, + choices=[], + ) + completion_usage = CompletionUsage( + prompt_tokens=input_echo_len, + completion_tokens=i, + total_tokens=(input_echo_len + i), + ) + yield completion_chunk, completion_usage + + def generate( + self, prompt: str, generate_config: Optional[MLXGenerateConfig] = None + ) -> Union[Completion, Iterator[CompletionChunk]]: + def generator_wrapper( + prompt: str, generate_config: MLXGenerateConfig + ) -> Iterator[CompletionChunk]: + for completion_chunk, completion_usage in self._generate_stream( + prompt, + generate_config, + ): + completion_chunk["usage"] = completion_usage + yield completion_chunk + + logger.debug( + "Enter generate, prompt: %s, generate config: %s", prompt, generate_config + ) + + generate_config = self._sanitize_generate_config(generate_config) + + assert self._model is not None + assert self._tokenizer is not None + + stream = generate_config.get("stream", False) + if not stream: + for completion_chunk, completion_usage in self._generate_stream( + prompt, + generate_config, + ): + pass + completion = Completion( + id=completion_chunk["id"], + object=completion_chunk["object"], + created=completion_chunk["created"], + model=completion_chunk["model"], + choices=completion_chunk["choices"], + usage=completion_usage, + ) + return completion + else: + return generator_wrapper(prompt, generate_config) + + +class MLXChatModel(MLXModel, ChatModelMixin): + def __init__( + self, + model_uid: str, + model_family: "LLMFamilyV1", + model_spec: "LLMSpecV1", + quantization: str, + model_path: str, + model_config: Optional[MLXModelConfig] = None, + peft_model: Optional[List[LoRA]] = None, + ): + super().__init__( + model_uid, + model_family, + model_spec, + quantization, + model_path, + model_config, + peft_model, + ) + + def _sanitize_generate_config( + self, + generate_config: Optional[MLXGenerateConfig], + ) -> MLXGenerateConfig: + generate_config = super()._sanitize_generate_config(generate_config) + if ( + (not generate_config.get("stop")) + and self.model_family.prompt_style + and self.model_family.prompt_style.stop + ): + generate_config["stop"] = self.model_family.prompt_style.stop.copy() + if ( + generate_config.get("stop_token_ids", None) is None + and self.model_family.prompt_style + and self.model_family.prompt_style.stop_token_ids + ): + generate_config[ + "stop_token_ids" + ] = self.model_family.prompt_style.stop_token_ids.copy() + + return generate_config + + @classmethod + def match( + cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str + ) -> bool: + if llm_spec.model_format not in ["mlx"]: + return False + if sys.platform != "darwin" or platform.processor() != "arm": + # only work for Mac M chips + return False + if 
"chat" not in llm_family.model_ability: + return False + return True + + def chat( + self, + prompt: str, + system_prompt: Optional[str] = None, + chat_history: Optional[List[ChatCompletionMessage]] = None, + generate_config: Optional[MLXGenerateConfig] = None, + ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: + tools = generate_config.pop("tools", []) if generate_config else None # type: ignore + full_prompt = self.get_full_prompt( + self.model_family, prompt, system_prompt, chat_history, tools + ) + + generate_config = self._sanitize_generate_config(generate_config) + # TODO(codingl2k1): qwen hacky to set stop for function call. + model_family = self.model_family.model_family or self.model_family.model_name + if tools and model_family in ["qwen-chat", "qwen1.5-chat"]: + stop = generate_config.get("stop") + if isinstance(stop, str): + generate_config["stop"] = [stop, "Observation:"] + elif isinstance(stop, Iterable): + assert not isinstance(stop, str) + generate_config["stop"] = list(stop) + ["Observation:"] + else: + generate_config["stop"] = "Observation:" + + stream = generate_config.get("stream", False) + if stream: + it = self.generate(full_prompt, generate_config) + assert isinstance(it, Iterator) + return self._to_chat_completion_chunks(it) + else: + c = self.generate(full_prompt, generate_config) + assert not isinstance(c, Iterator) + if tools: + return self._tool_calls_completion( + self.model_family, self.model_uid, c, tools + ) + return self._to_chat_completion(c) diff --git a/xinference/model/llm/mlx/tests/__init__.py b/xinference/model/llm/mlx/tests/__init__.py new file mode 100644 index 0000000000..37f6558d95 --- /dev/null +++ b/xinference/model/llm/mlx/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/xinference/model/llm/mlx/tests/test_mlx.py b/xinference/model/llm/mlx/tests/test_mlx.py new file mode 100644 index 0000000000..4fe69fd34f --- /dev/null +++ b/xinference/model/llm/mlx/tests/test_mlx.py @@ -0,0 +1,41 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import platform +import sys + +import pytest + +from .....client import Client + + +@pytest.mark.skipif( + sys.platform != "darwin" or platform.processor() != "arm", + reason="MLX only works for Apple silicon chip", +) +def test_load_mlx(setup): + endpoint, _ = setup + client = Client(endpoint) + + model_uid = client.launch_model( + model_name="qwen2-instruct", + model_engine="MLX", + model_size_in_billions="0_5", + model_format="mlx", + quantization="4-bit", + ) + assert len(client.list_models()) == 1 + model = client.get_model(model_uid) + completion = model.chat("write a poem.") + assert "content" in completion["choices"][0]["message"] + assert len(completion["choices"][0]["message"]["content"]) != 0 diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py index 34a7a21720..0ab967a344 100644 --- a/xinference/model/llm/utils.py +++ b/xinference/model/llm/utils.py @@ -765,6 +765,16 @@ def _tool_calls_completion(cls, model_family, model_uid, c, tools): "usage": usage, } + @classmethod + def get_full_prompt(cls, model_family, prompt, system_prompt, chat_history, tools): + assert model_family.prompt_style is not None + prompt_style = model_family.prompt_style.copy() + if system_prompt: + prompt_style.system_prompt = system_prompt + chat_history = chat_history or [] + full_prompt = cls.get_prompt(prompt, chat_history, prompt_style, tools=tools) + return full_prompt + def get_file_location( llm_family: LLMFamilyV1, spec: LLMSpecV1, quantization: str @@ -781,7 +791,7 @@ def get_file_location( is_cached = cache_status assert isinstance(is_cached, bool) - if spec.model_format in ["pytorch", "gptq", "awq"]: + if spec.model_format in ["pytorch", "gptq", "awq", "mlx"]: return cache_dir, is_cached elif spec.model_format in ["ggmlv3", "ggufv2"]: assert isinstance(spec, GgmlLLMSpecV1)
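Usage note (illustrative, not part of the patch): the test above drives the new MLX engine through the Python client. A minimal sketch of the same flow against an already-running Xinference server follows; the endpoint address is an assumption (adjust to your deployment), and it presumes an Apple silicon Mac with ``mlx-lm`` installed::

    from xinference.client import Client

    # Connect to a running Xinference endpoint (assumed address; adjust as needed).
    client = Client("http://127.0.0.1:9997")

    # Launch the 0.5B Qwen2 instruct model with the new MLX engine, mirroring the
    # spec added in llm_family.json (model_format "mlx", 4-bit quantization).
    model_uid = client.launch_model(
        model_name="qwen2-instruct",
        model_engine="MLX",
        model_size_in_billions="0_5",
        model_format="mlx",
        quantization="4-bit",
    )

    model = client.get_model(model_uid)

    # chat() returns an OpenAI-style response; generate_config keys follow
    # MLXGenerateConfig (max_tokens, temperature, top_p, repetition_penalty, ...).
    completion = model.chat(
        "Write a short poem about Apple silicon.",
        generate_config={"max_tokens": 128, "temperature": 0.7},
    )
    print(completion["choices"][0]["message"]["content"])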