Added bind_tools support for ChatMLX along with small fix in _stream #28743

Merged 2 commits on Dec 16, 2024

85 changes: 79 additions & 6 deletions libs/community/langchain_community/chat_models/mlx.py
@@ -1,11 +1,23 @@
"""MLX Chat Wrapper."""

from typing import Any, Iterator, List, Optional
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Type,
Union,
)

from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
@@ -20,6 +32,9 @@
     ChatResult,
     LLMResult,
 )
+from langchain_core.runnables import Runnable
+from langchain_core.tools import BaseTool
+from langchain_core.utils.function_calling import convert_to_openai_tool
 
 from langchain_community.llms.mlx_pipeline import MLXPipeline
 
@@ -94,7 +109,6 @@ def _to_chat_prompt(
             raise ValueError("Last message must be a HumanMessage!")
 
         messages_dicts = [self._to_chatml_format(m) for m in messages]
-
         return self.tokenizer.apply_chat_template(
             messages_dicts,
             tokenize=tokenize,
@@ -173,15 +187,18 @@ def _stream(
             generate_step(
                 prompt_tokens,
                 self.llm.model,
-                temp,
-                repetition_penalty,
-                repetition_context_size,
+                temp=temp,
+                repetition_penalty=repetition_penalty,
+                repetition_context_size=repetition_context_size,
             ),
             range(max_new_tokens),
         ):
             # identify text to yield
             text: Optional[str] = None
-            text = self.tokenizer.decode(token.item())
+            if not isinstance(token, int):
+                text = self.tokenizer.decode(token.item())
+            else:
+                text = self.tokenizer.decode(token)
 
             # yield text, if any
             if text:
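
Note on the _stream change above: depending on the mlx-lm version, generate_step may yield each token as a plain Python int or as an array-like value that must be unwrapped with .item() before decoding, which is why the added branch checks the type first. A minimal standalone sketch of the same defensive decode (decode_token is a hypothetical name, not part of this PR):

from typing import Any

def decode_token(tokenizer: Any, token: Any) -> str:
    # Mirrors the branch added above: plain ints decode directly,
    # array-like tokens are unwrapped with .item() first.
    if isinstance(token, int):
        return tokenizer.decode(token)
    return tokenizer.decode(token.item())
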
@@ -193,3 +210,59 @@ def _stream(
             # break if stop sequence found
             if token == eos_token_id or (stop is not None and text in stop):
                 break
+
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
+        *,
+        tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        """Bind tool-like objects to this chat model.
+
+        Assumes model is compatible with OpenAI tool-calling API.
+
+        Args:
+            tools: A list of tool definitions to bind to this chat model.
+                Supports any tool definition handled by
+                :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`.
+            tool_choice: Which tool to require the model to call.
+                Must be the name of the single provided function or
+                "auto" to automatically determine which function to call
+                (if any), or a dict of the form:
+                {"type": "function", "function": {"name": <<tool_name>>}}.
+            **kwargs: Any additional parameters to pass to the
+                :class:`~langchain.runnable.Runnable` constructor.
+        """
+
+        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
+        if tool_choice is not None and tool_choice:
+            if len(formatted_tools) != 1:
+                raise ValueError(
+                    "When specifying `tool_choice`, you must provide exactly one "
+                    f"tool. Received {len(formatted_tools)} tools."
+                )
+            if isinstance(tool_choice, str):
+                if tool_choice not in ("auto", "none"):
+                    tool_choice = {
+                        "type": "function",
+                        "function": {"name": tool_choice},
+                    }
+            elif isinstance(tool_choice, bool):
+                tool_choice = formatted_tools[0]
+            elif isinstance(tool_choice, dict):
+                if (
+                    formatted_tools[0]["function"]["name"]
+                    != tool_choice["function"]["name"]
+                ):
+                    raise ValueError(
+                        f"Tool choice {tool_choice} was specified, but the only "
+                        f"provided tool was {formatted_tools[0]['function']['name']}."
+                    )
+            else:
+                raise ValueError(
+                    f"Unrecognized tool_choice type. Expected str, bool or dict. "
+                    f"Received: {tool_choice}"
+                )
+            kwargs["tool_choice"] = tool_choice
+        return super().bind(tools=formatted_tools, **kwargs)
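
A quick usage sketch of the new bind_tools method (the model id and tool are illustrative assumptions, not part of this PR; any MLX-converted chat model with a chat template should work):

from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline
from langchain_core.tools import tool

@tool
def get_weather(city: str) -> str:
    """Look up the current weather for a city."""
    return f"Sunny in {city}"

# Illustrative model id; substitute any MLX chat model.
llm = MLXPipeline.from_model_id("mlx-community/quantized-gemma-2b-it")
chat = ChatMLX(llm=llm)

# tool_choice=True forces the single bound tool; a tool name, "auto",
# or a {"type": "function", "function": {"name": ...}} dict also work.
chat_with_tools = chat.bind_tools([get_weather], tool_choice=True)
msg = chat_with_tools.invoke("What is the weather like in Paris?")

Because bind_tools validates tool_choice against the bound tools, passing more than one tool together with a truthy tool_choice raises a ValueError.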