fix: more mistral v3 function calling nuances

zhudotexe committed May 30, 2024
1 parent 807d6c8 commit 65cf956

Showing 2 changed files with 37 additions and 12 deletions.
26 changes: 20 additions & 6 deletions kani/engines/huggingface/base.py
@@ -192,25 +192,35 @@ def _get_generate_args(self, prompt: str | torch.Tensor, **hyperparams):
         return input_toks, input_len, hyperparams
 
     async def predict(
-        self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
+        self,
+        messages: list[ChatMessage],
+        functions: list[AIFunction] | None = None,
+        *,
+        decode_kwargs: dict = None,
+        **hyperparams,
     ) -> Completion:
         """
         Given the current context of messages and available functions, get the next predicted chat message from the LM.
 
         :param messages: The messages in the current chat context. ``sum(message_len(m) for m in messages)`` is
             guaranteed to be less than max_context_size.
         :param functions: The functions the LM is allowed to call.
+        :param decode_kwargs: Any arguments to pass to AutoTokenizer.decode(). Defaults to
+            ``dict(skip_special_tokens=True)``.
         :param hyperparams: Any additional parameters to pass to GenerationMixin.generate(). (See
             https://huggingface.co/docs/transformers/main_classes/text_generation)
         """
+        if decode_kwargs is None:
+            decode_kwargs = dict(skip_special_tokens=True)
+
         prompt = self.build_prompt(messages, functions)
         input_toks, input_len, hyperparams = self._get_generate_args(prompt, **hyperparams)
 
         # run it through the model
         output = self.model.generate(input_toks, **hyperparams)
         # decode to tokens
         # the completion shouldn't include the prompt or stop token
-        content = self.tokenizer.decode(output[0][input_len:-1]).strip()
+        content = self.tokenizer.decode(output[0][input_len:], **decode_kwargs).strip()
         return Completion(
             ChatMessage.assistant(content), prompt_tokens=input_len, completion_tokens=len(output[0]) - (input_len + 1)
         )
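For callers, the new decode_kwargs parameter is the knob this hunk adds; the default (skip special tokens) matches the old behavior. A minimal usage sketch, assuming a HuggingEngine built elsewhere (the model id, `messages`, and `functions` are illustrative, not part of this diff):

```python
# Sketch: ask predict() to keep special tokens in the decoded completion.
# Run inside an async context; model id is an assumption for illustration.
from kani.engines.huggingface import HuggingEngine

engine = HuggingEngine(model_id="mistralai/Mistral-7B-Instruct-v0.3")
completion = await engine.predict(
    messages,   # list[ChatMessage] built elsewhere
    functions,  # list[AIFunction] | None
    decode_kwargs=dict(skip_special_tokens=False),  # keep [TOOL_CALLS] / </s> in the text
)
```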
@@ -220,7 +230,8 @@ async def stream(
         messages: list[ChatMessage],
         functions: list[AIFunction] | None = None,
         *,
-        streamer_timeout=None,
+        streamer_timeout: float | None = None,
+        decode_kwargs: dict = None,
         **hyperparams,
     ) -> AsyncIterable[str | BaseCompletion]:
         """
@@ -230,14 +241,17 @@ async def stream(
             guaranteed to be less than max_context_size.
         :param functions: The functions the LM is allowed to call.
         :param streamer_timeout: The maximum number of seconds to wait for the next token when streaming.
+        :param decode_kwargs: Any arguments to pass to AutoTokenizer.decode(). Defaults to
+            ``dict(skip_special_tokens=True)``.
         :param hyperparams: Any additional parameters to pass to GenerationMixin.generate(). (See
             https://huggingface.co/docs/transformers/main_classes/text_generation)
         """
+        if decode_kwargs is None:
+            decode_kwargs = dict(skip_special_tokens=True)
+
         prompt = self.build_prompt(messages, functions)
         input_toks, input_len, hyperparams = self._get_generate_args(prompt, **hyperparams)
-        streamer = TextIteratorStreamer(
-            self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=streamer_timeout
-        )
+        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, timeout=streamer_timeout, **decode_kwargs)
 
         # run it through the model in another thread so that we can get the tokens in this thread
         output_toks = None
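The streaming path gains the same parameter, now forwarded to TextIteratorStreamer instead of a hard-coded skip_special_tokens=True. A hedged sketch of consuming the stream, continuing the assumed engine from the sketch above:

```python
# Sketch: stream tokens with special tokens preserved; run inside an async context.
async for elem in engine.stream(
    messages,
    functions,
    streamer_timeout=60.0,  # give up if no new token arrives within 60s
    decode_kwargs=dict(skip_special_tokens=False),
):
    if isinstance(elem, str):  # token strings; a BaseCompletion is yielded last
        print(elem, end="")
```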
23 changes: 17 additions & 6 deletions kani/prompts/impl/mistral.py
@@ -122,9 +122,17 @@ def ensure_available_tools(msgs: list[ChatMessage], functions: list[AIFunction])
 class MixtralFunctionCallingAdapter(WrapperEngine):
     """Common Mixtral-8x22B function calling parsing wrapper."""
 
-    @staticmethod
-    def _parse_tool_calls(content: str) -> tuple[str, list[ToolCall]]:
-        tool_json = re.search(r"\[TOOL_CALLS]\s*(.+)</s>", content, re.IGNORECASE | re.DOTALL)
+    def __init__(self, *args, tool_call_token="[TOOL_CALLS]", eos_token="</s>", **kwargs):
+        super().__init__(*args, **kwargs)
+        self.tool_call_token = tool_call_token
+        self.eos_token = eos_token
+
+    def _parse_tool_calls(self, content: str) -> tuple[str, list[ToolCall]]:
+        tool_json = re.search(
+            rf"{re.escape(self.tool_call_token)}\s*(.+)(?:{re.escape(self.eos_token)})?",
+            content,
+            re.IGNORECASE | re.DOTALL,
+        )
         if tool_json is None:
             return content, []
         actions = json.loads(tool_json.group(1))
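A hedged sketch of what the parameterized parser does with a Mistral-v3-style completion; the wrapped engine, the call to a private method, and the raw string are illustrative assumptions, not part of this commit:

```python
# Sketch: _parse_tool_calls splits text content from the [TOOL_CALLS] payload.
adapter = MixtralFunctionCallingAdapter(inner_engine)  # inner_engine: any kani engine
raw = 'Let me check. [TOOL_CALLS] [{"name": "get_weather", "arguments": {"city": "Tokyo"}}]'
content, tool_calls = adapter._parse_tool_calls(raw)
# content    == "Let me check. "  (the text before the tool call token)
# tool_calls -> one ToolCall invoking get_weather with {"city": "Tokyo"}
```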
@@ -142,28 +150,31 @@ def _parse_tool_calls(content: str) -> tuple[str, list[ToolCall]]:
         return content[: tool_json.start()], tool_calls
 
     async def predict(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams):
+        hyperparams.setdefault("decode_kwargs", dict(skip_special_tokens=False))
         completion = await super().predict(messages, functions, **hyperparams)
 
         # if we have tools, parse
         if functions:
             completion.message.content, completion.message.tool_calls = self._parse_tool_calls(completion.message.text)
+            completion.message.content = completion.message.content.removesuffix(self.eos_token).strip()
 
         return completion
 
     async def stream(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams):
         content_parts = []
         in_tool_call = False
         inner_completion = None
+        hyperparams.setdefault("decode_kwargs", dict(skip_special_tokens=False))
 
         # consume from the inner iterator, yielding as normal until we see a tool call or a completion
         async for elem in super().stream(messages, functions, **hyperparams):
             if isinstance(elem, str):
                 content_parts.append(elem)
                 # if we see the start of a tool call, stop yielding and start buffering
-                if elem == "[TOOL_CALLS]":
+                if elem.startswith(self.tool_call_token):
                     in_tool_call = True
                 # otherwise yield the string
-                if not in_tool_call:
+                if not in_tool_call and elem != self.eos_token:
                     yield elem
             else:
                 # save the inner completion
@@ -179,7 +190,7 @@ async def stream(self, messages: list[ChatMessage], functions: list[AIFunction]
             content, tool_calls = self._parse_tool_calls(content)
             if inner_completion:
                 tool_calls = (inner_completion.message.tool_calls or []) + tool_calls
-            yield Completion(ChatMessage.assistant(content, tool_calls=tool_calls))
+            yield Completion(ChatMessage.assistant(content.removesuffix(self.eos_token).strip(), tool_calls=tool_calls))
 
 
 MistralFunctionCallingAdapter = MixtralFunctionCallingAdapter
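Since the tool call and stop tokens are now constructor parameters rather than hard-coded strings, the same adapter can presumably serve token conventions beyond Mixtral-8x22B. A sketch, with the inner engine assumed:

```python
# Sketch: the defaults shown are exactly the ones this commit parameterizes.
engine = MistralFunctionCallingAdapter(
    inner_engine,
    tool_call_token="[TOOL_CALLS]",  # token the model emits before the tool call JSON
    eos_token="</s>",                # stop token now stripped by the adapter itself
)
```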
