Configurable Chat Formats (ggml-org#711)
* Add configurable default chat completion format.

* Remove chat_template file to avoid circular import

* Update llama_types

* Add chat format
abetlen authored Sep 29, 2023
1 parent a945404 commit 3bca770
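
For reference, a minimal usage sketch of the new parameter (the model path and messages below are hypothetical, not part of this commit): chat_format is passed to the Llama constructor and picked up by create_chat_completion when it builds the prompt.

    from llama_cpp import Llama

    # Hypothetical model path; chat_format defaults to "llama-2" as introduced by this commit.
    llm = Llama(model_path="./models/llama-2-7b-chat.gguf", chat_format="llama-2")

    # Messages are rendered into a prompt using the configured chat format before completion.
    response = llm.create_chat_completion(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Name the planets in the solar system."},
        ],
    )
    print(response["choices"][0]["message"]["content"])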
Showing 2 changed files with 330 additions and 19 deletions.
57 changes: 38 additions & 19 deletions llama_cpp/llama.py
@@ -24,6 +24,7 @@
from . import llama_cpp
from .llama_types import *
from .llama_grammar import LlamaGrammar
from . import llama_chat_format

import numpy as np
import numpy.typing as npt
@@ -243,6 +244,8 @@ def __init__(
lora_path: Optional[str] = None,
# Backend Params
numa: bool = False,
# Chat Format Params
chat_format: str = "llama-2",
# Misc
verbose: bool = True,
# Extra Params
@@ -273,6 +276,7 @@ def __init__(
lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
lora_path: Path to a LoRA file to apply to the model.
numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
chat_format: String specifying the chat format to use when calling create_chat_completion.
verbose: Print verbose output to stderr.
kwargs: Unused keyword arguments (for additional backwards compatibility).
@@ -388,6 +392,8 @@ def __init__(

if self.verbose:
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)

self.chat_format = chat_format

self._n_vocab = self.n_vocab()
self._n_ctx = self.n_ctx()
@@ -1565,9 +1571,21 @@ def _convert_text_completion_chunks_to_chat(
],
}

def _convert_completion_to_chat(
self,
completion_or_chunks: Union[Completion, Iterator[CompletionChunk]],
stream: bool = False,
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
if stream:
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore
return self._convert_text_completion_chunks_to_chat(chunks)
else:
completion: Completion = completion_or_chunks # type: ignore
return self._convert_text_completion_to_chat(completion)

def create_chat_completion(
self,
messages: List[ChatCompletionMessage],
messages: List[ChatCompletionRequestMessage],
functions: Optional[List[ChatCompletionFunction]] = None,
function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
temperature: float = 0.2,
@@ -1602,26 +1620,28 @@ def create_chat_completion(
Returns:
Generated chat completion or a stream of chat completion chunks.
"""
stop = (
stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
)
chat_history = "".join(
f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
for message in messages

format = llama_chat_format.get_chat_format(self.chat_format)
result = format(
messages=messages,
)
PROMPT = chat_history + "### Assistant:"
PROMPT_STOP = ["### Assistant:", "### Human:"]
completion_or_chunks = self(
prompt=PROMPT,
stop=PROMPT_STOP + stop,
prompt = result.prompt
if result.stop is not None:
stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
stop = stop + rstop

completion_or_chunks = self.create_completion(
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream=stream,
stop=stop,
max_tokens=max_tokens,
repeat_penalty=repeat_penalty,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
@@ -1630,12 +1650,7 @@ def create_chat_completion(
logits_processor=logits_processor,
grammar=grammar,
)
if stream:
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore
return self._convert_text_completion_chunks_to_chat(chunks)
else:
completion: Completion = completion_or_chunks # type: ignore
return self._convert_text_completion_to_chat(completion)
return self._convert_completion_to_chat(completion_or_chunks, stream=stream) # type: ignore

def __del__(self):
if hasattr(self, "model") and self.model is not None:
@@ -1675,6 +1690,8 @@ def __getstate__(self):
lora_path=self.lora_path,
# Backend Params
numa=self.numa,
# Chat Format Params
chat_format=self.chat_format,
# Misc
verbose=self.verbose,
)
@@ -1708,6 +1725,8 @@ def __setstate__(self, state):
lora_path=state["lora_path"],
# Backend Params
numa=state["numa"],
# Chat Format Params
chat_format=state["chat_format"],
# Misc
verbose=state["verbose"],
)
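The diff for the second changed file (the chat format module imported above as llama_chat_format) is not rendered in this view. Judging only from how it is used in create_chat_completion — get_chat_format(name) returns a callable that accepts messages= and yields an object exposing .prompt and an optional .stop — a handler reproducing the previously hard-coded "### Human / ### Assistant" prompt might look roughly like the sketch below; the dataclass and registry names are illustrative assumptions, not taken from this commit.

    # Hypothetical sketch of a chat format handler compatible with the calls shown above.
    # Only get_chat_format and the .prompt / .stop attributes appear in this diff;
    # ChatFormatterResponse and the registry are illustrative assumptions.
    from dataclasses import dataclass
    from typing import Dict, List, Optional, Union


    @dataclass
    class ChatFormatterResponse:
        prompt: str
        stop: Optional[Union[str, List[str]]] = None


    def format_alpaca_style(messages: List[Dict[str, str]], **kwargs) -> ChatFormatterResponse:
        # Reproduces the prompt that create_chat_completion previously built inline.
        history = "".join(
            f'### {"Human" if m["role"] == "user" else "Assistant"}:{m["content"]}'
            for m in messages
        )
        return ChatFormatterResponse(
            prompt=history + "### Assistant:",
            stop=["### Assistant:", "### Human:"],
        )


    _FORMATS = {"alpaca-style": format_alpaca_style}


    def get_chat_format(name: str):
        # create_chat_completion only relies on this lookup returning a callable.
        return _FORMATS[name]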
