Showing 4 changed files with 357 additions and 1 deletion.
@@ -0,0 +1,184 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Iterator, Optional, Sequence, TypedDict, Union

from ctransformers import AutoConfig

from xinference.model.llm.ggml.ctransformers_util import generate_stream
from xinference.types import Completion, CompletionChunk

from ..core import LLM
from ..llm_family import LLMFamilyV1, LLMSpecV1
from .llamacpp import SIZE_TO_GPU_LAYERS

logger = logging.getLogger(__name__)


# class AutoConfig(TypedDict, total=False):
#     top_k: int
#     top_p: float
#     temperature: float
#     repetition_penalty: float
#     last_n_tokens: float
#     seed: int
#     max_new_tokens: int
#     stop: List[str]
#     stream: bool
#     reset: bool
#     batch_size: int
#     threads: int
#     context_length: int
#     gpu_layers: int


class CtransformerGenerateConfig(TypedDict, total=False):
    max_new_tokens: Optional[int]
    top_k: Optional[int]
    top_p: Optional[float]
    temperature: Optional[float]
    repetition_penalty: Optional[float]
    last_n_tokens: Optional[int]
    seed: Optional[int]
    batch_size: Optional[int]
    threads: Optional[int]
    stop: Optional[Sequence[str]]
    stream: Optional[bool]
    reset: Optional[bool]


class CtransformerModel(LLM):
    def __init__(
        self,
        model_uid: str,
        model_family: "LLMFamilyV1",
        model_spec: "LLMSpecV1",
        quantization: str,
        model_path: str,
        ctransformerModelConfig: Optional[AutoConfig] = None,
    ):
        super().__init__(model_uid, model_family, model_spec, quantization, model_path)

        # Pick the GPU layer count tuned for the closest known model size.
        closest_size = min(
            SIZE_TO_GPU_LAYERS.keys(),
            key=lambda x: abs(x - model_spec.model_size_in_billions),
        )
        self._gpu_layers = SIZE_TO_GPU_LAYERS[closest_size]
        self._ctransformer_model_config: AutoConfig = self._sanitize_model_config(
            model_path, ctransformerModelConfig
        )
        self._llm = None

    def _sanitize_model_config(
        self, model_path, ctransformerModelConfig: Optional[AutoConfig]
    ) -> AutoConfig:
        if ctransformerModelConfig is None:
            ctransformerModelConfig = AutoConfig.from_pretrained(
                model_path,
                local_files_only=False,
            )

        return ctransformerModelConfig

    def _sanitize_generate_config(
        self,
        ctransformerGenerateConfig: Optional[CtransformerGenerateConfig],
    ) -> CtransformerGenerateConfig:
        if ctransformerGenerateConfig is None:
            ctransformerGenerateConfig = CtransformerGenerateConfig()
        # Fall back to ctransformers' documented defaults for any unset field.
        ctransformerGenerateConfig.setdefault("top_k", 40)
        ctransformerGenerateConfig.setdefault("top_p", 0.95)
        ctransformerGenerateConfig.setdefault("temperature", 0.8)
        ctransformerGenerateConfig.setdefault("repetition_penalty", 1.1)
        ctransformerGenerateConfig.setdefault("last_n_tokens", 64)
        ctransformerGenerateConfig.setdefault("seed", -1)
        ctransformerGenerateConfig.setdefault("batch_size", 8)
        ctransformerGenerateConfig.setdefault("threads", -1)
        ctransformerGenerateConfig.setdefault("stop", None)
        ctransformerGenerateConfig.setdefault("stream", None)
        ctransformerGenerateConfig.setdefault("reset", True)

        return ctransformerGenerateConfig

    def load(self):
        try:
            from ctransformers import AutoModelForCausalLM
        except ImportError:
            error_message = "Failed to import module 'ctransformers'"
            if self._is_darwin_and_apple_silicon():
                system = "Metal"
            else:
                system = "CUDA"

            installation_guide = [
                f"Please make sure 'ctransformers' is installed and the {system} accelerator is available.",
                f"You can install it with the {system}-specific command from the repository: "
                f"https://github.com/marella/ctransformers",
            ]

            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")

        # NOTE: `_model_path`, `_model_type` and `_model_file` are assumed to
        # be provided by the base `LLM` class / model spec.
        self._llm = AutoModelForCausalLM.from_pretrained(
            model_path_or_repo_id=self._model_path,
            model_type=self._model_type,
            model_file=self._model_file,
            config=self._ctransformer_model_config,
        )

    @classmethod
    def match(cls, llm_family: LLMFamilyV1, llm_spec: LLMSpecV1) -> bool:
        if llm_spec.model_format != "ggmlv3":
            return False
        if llm_spec.model_id not in ["TheBloke/starcoder-GGML"]:
            return False
        if "chatglm" in llm_family.model_name:
            return False
        if "generate" not in llm_family.model_ability:
            return False
        return True

    def generate(
        self, prompt: str, generate_config: CtransformerGenerateConfig
    ) -> Union[Completion, Iterator[CompletionChunk]]:
        def generator_wrapper(
            _prompt: str,
            _generate_config: CtransformerGenerateConfig,
        ) -> Iterator[CompletionChunk]:
            assert self._llm is not None
            for _completion_chunk, _ in generate_stream(
                model=self._llm, prompt=_prompt, **_generate_config
            ):
                yield _completion_chunk

        generate_config = self._sanitize_generate_config(generate_config)

        stream_or_not = generate_config.get("stream", False)
        if stream_or_not:
            return generator_wrapper(_prompt=prompt, _generate_config=generate_config)
        else:
            assert self._llm is not None
            text = ""
            # Drain the stream: each chunk carries only the newly generated
            # piece of text, so concatenate them to recover the full output.
            for completion_chunk, completion_usage in generate_stream(
                model=self._llm, prompt=prompt, **generate_config
            ):
                text += completion_chunk["choices"][0]["text"]

            # The final chunk only holds the tail of the output; patch in the
            # full concatenation before building the Completion.
            completion_chunk["choices"][0]["text"] = text
            completion = Completion(
                id=completion_chunk["id"],
                object=completion_chunk["object"],
                created=completion_chunk["created"],
                model=completion_chunk["model"],
                choices=completion_chunk["choices"],
                usage=completion_usage,
            )
            return completion
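
A minimal usage sketch for the new model class (illustrative, not part of the commit). It assumes `family` and `spec` are real LLMFamilyV1 / LLMSpecV1 objects for TheBloke/starcoder-GGML already registered elsewhere; the uid, quantization, and weight path below are made up.

# Illustrative only: `family`, `spec`, and the paths/uids are placeholders.
model = CtransformerModel(
    model_uid="starcoder-demo",
    model_family=family,
    model_spec=spec,
    quantization="q4_0",
    model_path="/path/to/starcoder.ggmlv3.q4_0.bin",
)
model.load()

# Non-streaming: a single Completion carrying the full generated text.
completion = model.generate("def fib(n):", {"max_new_tokens": 64, "stream": False})
print(completion["choices"][0]["text"])

# Streaming: an iterator of CompletionChunk objects.
for chunk in model.generate("def fib(n):", {"max_new_tokens": 64, "stream": True}):
    print(chunk["choices"][0]["text"], end="", flush=True)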
@@ -0,0 +1,143 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import time
import uuid
from typing import Iterator, Optional, Sequence, Tuple

from ctransformers.utils import utf8_split_incomplete

from xinference.types import CompletionChoice, CompletionChunk, CompletionUsage


def _get(*values):
    # Return the first argument that is not None (implicitly None otherwise).
    for value in values:
        if value is not None:
            return value


def generate_stream(
    model,
    prompt: str,
    *,
    max_new_tokens: Optional[int] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    last_n_tokens: Optional[int] = None,
    seed: Optional[int] = None,
    batch_size: Optional[int] = None,
    stream: Optional[bool] = True,  # accepted for API symmetry; unused below
    threads: Optional[int] = None,
    stop: Optional[Sequence[str]] = None,
    reset: Optional[bool] = None,
) -> Iterator[Tuple[CompletionChunk, CompletionUsage]]:
    max_new_tokens = _get(max_new_tokens)
    stop = _get(stop) or []
    if isinstance(stop, str):
        stop = [stop]

    tokens = model.tokenize(prompt)

    stop_regex = re.compile("|".join(map(re.escape, stop)))
    count = 0
    text = ""
    incomplete = b""

    # Bookkeeping needed for Xinference's Completion types.
    finish_reason = None

    for token in model.generate(
        tokens,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        last_n_tokens=last_n_tokens,
        seed=seed,
        batch_size=batch_size,
        threads=threads,
        reset=reset,
    ):
        # Handle incomplete UTF-8 multi-byte characters.
        incomplete += model.detokenize([token], decode=False)
        complete, incomplete = utf8_split_incomplete(incomplete)
        output = complete.decode(errors="ignore")
        text += output

        # https://github.com/abetlen/llama-cpp-python/blob/1a13d76c487df1c8560132d10bda62d6e2f4fa93/llama_cpp/llama.py#L686-L706
        # Check if one of the stop sequences is part of the text.
        # Note that the stop sequence may not always be at the end of text.
        if stop:
            match = stop_regex.search(text)
            if match:
                text = text[: match.start()]
                finish_reason = "stop"
                break

        # Avoid sending the longest suffix of text which is also a prefix
        # of a stop sequence, as it can form a stop sequence with the text
        # generated later (e.g. with stop=["###"] and text ending in "##",
        # hold back the trailing "##").
        longest = 0
        for s in stop:
            for i in range(len(s), 0, -1):
                if text.endswith(s[:i]):
                    longest = max(i, longest)
                    break

        end = len(text) - longest
        if end > 0:
            output = text[:end]
            completion_choice = CompletionChoice(
                text=output, index=0, logprobs=None, finish_reason=None
            )
            completion_chunk = CompletionChunk(
                id=str(uuid.uuid1()),
                object="text_completion",
                created=int(time.time()),
                model=model,
                choices=[completion_choice],
            )
            # `count` is only incremented below, so the number of tokens
            # generated so far is count + 1.
            completion_usage = CompletionUsage(
                prompt_tokens=len(tokens),
                completion_tokens=count + 1,
                total_tokens=count + 1 + len(tokens),
            )

            yield completion_chunk, completion_usage
            text = text[end:]

        count += 1
        if max_new_tokens is not None and count >= max_new_tokens:
            finish_reason = "length"
            break

    # Flush whatever text is still held back, reporting the finish reason.
    completion_choice = CompletionChoice(
        text=text, index=0, logprobs=None, finish_reason=finish_reason
    )
    completion_chunk = CompletionChunk(
        id=str(uuid.uuid1()),
        object="text_completion",
        created=int(time.time()),
        model=model,
        choices=[completion_choice],
    )
    completion_usage = CompletionUsage(
        prompt_tokens=len(tokens),
        completion_tokens=count,
        total_tokens=count + len(tokens),
    )

    yield completion_chunk, completion_usage
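
For reference, a minimal sketch of driving generate_stream directly, outside the CtransformerModel wrapper (illustrative, not part of the commit). It assumes `ctransformers` is installed and a ggmlv3 starcoder weight file exists locally; the path and model_type argument are assumptions.

from ctransformers import AutoModelForCausalLM

# Hypothetical local weight file; any ctransformers-supported model works.
llm = AutoModelForCausalLM.from_pretrained(
    "/path/to/starcoder.ggmlv3.q4_0.bin",
    model_type="starcoder",
)
for chunk, usage in generate_stream(llm, "def hello():", max_new_tokens=32, stop=["\n\n"]):
    print(chunk["choices"][0]["text"], end="", flush=True)
print(f"\nprompt={usage['prompt_tokens']} completion={usage['completion_tokens']} tokens")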