From ebf9cbcf0e23659436045a92b94b2f667d071b7d Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Mon, 19 Aug 2024 10:55:58 -0400 Subject: [PATCH 1/9] update groq --- .../integration_tests/test_chat_models.py | 37 +++++++++++++++++++ .../tests/integration_tests/test_standard.py | 8 +++- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/libs/partners/groq/tests/integration_tests/test_chat_models.py b/libs/partners/groq/tests/integration_tests/test_chat_models.py index be8814bc3e7ad..2e5a9620b2287 100644 --- a/libs/partners/groq/tests/integration_tests/test_chat_models.py +++ b/libs/partners/groq/tests/integration_tests/test_chat_models.py @@ -14,6 +14,7 @@ ) from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.tools import tool from langchain_groq import ChatGroq from tests.unit_tests.fake.callbacks import ( @@ -393,6 +394,42 @@ class Joke(BaseModel): assert len(result.punchline) != 0 +def test_tool_calling_no_arguments() -> None: + # Note: this is a variant of a test in langchain_standard_tests + # that as of 2024-08-19 fails with "Failed to call a function. Please + # adjust your prompt." when `tool_choice="any"` is specified, but + # passes when `tool_choice` is not specified. + model = ChatGroq(model="llama-3.1-70b-versatile", temperature=0) # type: ignore[call-arg] + + @tool + def magic_function_no_args() -> int: + """Calculates a magic function.""" + return 5 + + model_with_tools = model.bind_tools([magic_function_no_args]) + query = "What is the value of magic_function()? Use the tool." + result = model_with_tools.invoke(query) + assert isinstance(result, AIMessage) + assert len(result.tool_calls) == 1 + tool_call = result.tool_calls[0] + assert tool_call["name"] == "magic_function_no_args" + assert tool_call["args"] == {} + assert tool_call["id"] is not None + assert tool_call["type"] == "tool_call" + + # Test streaming + full: Optional[BaseMessageChunk] = None + for chunk in model_with_tools.stream(query): + full = chunk if full is None else full + chunk # type: ignore + assert isinstance(full, AIMessage) + assert len(full.tool_calls) == 1 + tool_call = full.tool_calls[0] + assert tool_call["name"] == "magic_function_no_args" + assert tool_call["args"] == {} + assert tool_call["id"] is not None + assert tool_call["type"] == "tool_call" + + # Groq does not currently support N > 1 # @pytest.mark.scheduled # def test_chat_multiple_completions() -> None: diff --git a/libs/partners/groq/tests/integration_tests/test_standard.py b/libs/partners/groq/tests/integration_tests/test_standard.py index 6feab74f60677..d3483c9d650eb 100644 --- a/libs/partners/groq/tests/integration_tests/test_standard.py +++ b/libs/partners/groq/tests/integration_tests/test_standard.py @@ -28,11 +28,17 @@ class TestGroqLlama(BaseTestGroq): @property def chat_model_params(self) -> dict: return { - "model": "llama-3.1-70b-versatile", + "model": "llama-3.1-8b-instant", "temperature": 0, "rate_limiter": rate_limiter, } + @pytest.mark.xfail( + reason=("Fails with 'Failed to call a function. Please adjust your prompt.'") + ) + def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None: + super().test_tool_calling_with_no_arguments(model) + @pytest.mark.xfail( reason=("Fails with 'Failed to call a function. Please adjust your prompt.'") ) From 06b62eed3674b788d5606c4ce83e1242028720a3 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Mon, 19 Aug 2024 10:56:39 -0400 Subject: [PATCH 2/9] update ChatTogether.bind_tools --- .../langchain_together/chat_models.py | 53 ++++++++++++++++++- .../test_chat_models_standard.py | 4 -- 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/libs/partners/together/langchain_together/chat_models.py b/libs/partners/together/langchain_together/chat_models.py index 76d79d8d29d75..346774d0c0a30 100644 --- a/libs/partners/together/langchain_together/chat_models.py +++ b/libs/partners/together/langchain_together/chat_models.py @@ -2,18 +2,28 @@ from typing import ( Any, + Callable, Dict, List, + Literal, Optional, + Sequence, + Type, + Union, ) import openai +from langchain_core.language_models import LanguageModelInput from langchain_core.language_models.chat_models import LangSmithParams -from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.messages import BaseMessage +from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool from langchain_core.utils import ( from_env, secret_from_env, ) +from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_openai.chat_models.base import BaseChatOpenAI @@ -362,3 +372,44 @@ def validate_environment(cls, values: Dict) -> Dict: **client_params, **async_specific ).chat.completions return values + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + *, + tool_choice: Optional[Union[dict, str, Literal["auto", "any"], bool]] = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tool-like objects to this chat model. + + Assumes model is compatible with Together tool-calling API. + + Args: + tools: A list of tool definitions to bind to this chat model. + Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic + models, callables, and BaseTools will be automatically converted to + their schema dictionary representation. + tool_choice: Which tool to require the model to call. + Options are: + name of the tool (str): calls corresponding tool; + "auto": automatically selects a tool (including no tool); + "none": does not call a tool; + "any" or "required": force at least one tool to be called; + True: forces tool call (requires `tools` be length 1); + False: no effect; + + or a dict of the form: + {"type": "function", "function": {"name": <>}}. + **kwargs: Any additional parameters to pass to the + :class:`~langchain.runnable.Runnable` constructor. + """ + if tool_choice == "any" and len(tools) == 1: + # Together specifies tool_choice via "auto" or a dict. + # https://docs.together.ai/docs/tool-call-with-other-models#tool_choice + formatted_tool = convert_to_openai_tool(tools[0]) + tool_name = formatted_tool["function"]["name"] + tool_choice = {"type": "function", "function": {"name": tool_name}} + else: + pass + + return super().bind_tools(tools=tools, tool_choice=tool_choice, **kwargs) diff --git a/libs/partners/together/tests/integration_tests/test_chat_models_standard.py b/libs/partners/together/tests/integration_tests/test_chat_models_standard.py index 2250873f4b659..b2377870de58e 100644 --- a/libs/partners/together/tests/integration_tests/test_chat_models_standard.py +++ b/libs/partners/together/tests/integration_tests/test_chat_models_standard.py @@ -28,10 +28,6 @@ def chat_model_params(self) -> dict: "rate_limiter": rate_limiter, } - @pytest.mark.xfail(reason=("May not call a tool.")) - def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None: - super().test_tool_calling_with_no_arguments(model) - @pytest.mark.xfail(reason="Not yet supported.") def test_usage_metadata_streaming(self, model: BaseChatModel) -> None: super().test_usage_metadata_streaming(model) From 1b78a7b33cccf83da1f372d5c6e71abd01d21021 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Mon, 19 Aug 2024 11:02:57 -0400 Subject: [PATCH 3/9] specify tool_choice in tool calling tests --- .../integration_tests/chat_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py index bcb47a4c151a7..d6d08ea93edbd 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py @@ -170,7 +170,7 @@ def test_stop_sequence(self, model: BaseChatModel) -> None: def test_tool_calling(self, model: BaseChatModel) -> None: if not self.has_tool_calling: pytest.skip("Test requires tool calling.") - model_with_tools = model.bind_tools([magic_function]) + model_with_tools = model.bind_tools([magic_function], tool_choice="any") # Test invoke query = "What is the value of magic_function(3)? Use the tool." @@ -188,7 +188,7 @@ def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None: if not self.has_tool_calling: pytest.skip("Test requires tool calling.") - model_with_tools = model.bind_tools([magic_function_no_args]) + model_with_tools = model.bind_tools([magic_function_no_args], tool_choice="any") query = "What is the value of magic_function()? Use the tool." result = model_with_tools.invoke(query) _validate_tool_call_message_no_args(result) @@ -212,7 +212,7 @@ def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None: name="greeting_generator", description="Generate a greeting in a particular style of speaking.", ) - model_with_tools = model.bind_tools([tool_]) + model_with_tools = model.bind_tools([tool_], tool_choice="any") query = "Using the tool, generate a Pirate greeting." result = model_with_tools.invoke(query) assert isinstance(result, AIMessage) From 57d6497ed3219333a878bdb5e0c75f0edb48429b Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Mon, 19 Aug 2024 11:29:58 -0400 Subject: [PATCH 4/9] update docstring --- libs/partners/together/langchain_together/chat_models.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/libs/partners/together/langchain_together/chat_models.py b/libs/partners/together/langchain_together/chat_models.py index 346774d0c0a30..09678e5b5b47a 100644 --- a/libs/partners/together/langchain_together/chat_models.py +++ b/libs/partners/together/langchain_together/chat_models.py @@ -393,8 +393,7 @@ def bind_tools( Options are: name of the tool (str): calls corresponding tool; "auto": automatically selects a tool (including no tool); - "none": does not call a tool; - "any" or "required": force at least one tool to be called; + "any": force at least one tool to be called; True: forces tool call (requires `tools` be length 1); False: no effect; From 0ed918a9e917fe559ff19a186957f842a9a281da Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Mon, 19 Aug 2024 15:16:59 -0400 Subject: [PATCH 5/9] revert --- .../langchain_together/chat_models.py | 52 +------------------ .../test_chat_models_standard.py | 4 ++ 2 files changed, 5 insertions(+), 51 deletions(-) diff --git a/libs/partners/together/langchain_together/chat_models.py b/libs/partners/together/langchain_together/chat_models.py index 09678e5b5b47a..76d79d8d29d75 100644 --- a/libs/partners/together/langchain_together/chat_models.py +++ b/libs/partners/together/langchain_together/chat_models.py @@ -2,28 +2,18 @@ from typing import ( Any, - Callable, Dict, List, - Literal, Optional, - Sequence, - Type, - Union, ) import openai -from langchain_core.language_models import LanguageModelInput from langchain_core.language_models.chat_models import LangSmithParams -from langchain_core.messages import BaseMessage -from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator -from langchain_core.runnables import Runnable -from langchain_core.tools import BaseTool +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.utils import ( from_env, secret_from_env, ) -from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_openai.chat_models.base import BaseChatOpenAI @@ -372,43 +362,3 @@ def validate_environment(cls, values: Dict) -> Dict: **client_params, **async_specific ).chat.completions return values - - def bind_tools( - self, - tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], - *, - tool_choice: Optional[Union[dict, str, Literal["auto", "any"], bool]] = None, - **kwargs: Any, - ) -> Runnable[LanguageModelInput, BaseMessage]: - """Bind tool-like objects to this chat model. - - Assumes model is compatible with Together tool-calling API. - - Args: - tools: A list of tool definitions to bind to this chat model. - Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic - models, callables, and BaseTools will be automatically converted to - their schema dictionary representation. - tool_choice: Which tool to require the model to call. - Options are: - name of the tool (str): calls corresponding tool; - "auto": automatically selects a tool (including no tool); - "any": force at least one tool to be called; - True: forces tool call (requires `tools` be length 1); - False: no effect; - - or a dict of the form: - {"type": "function", "function": {"name": <>}}. - **kwargs: Any additional parameters to pass to the - :class:`~langchain.runnable.Runnable` constructor. - """ - if tool_choice == "any" and len(tools) == 1: - # Together specifies tool_choice via "auto" or a dict. - # https://docs.together.ai/docs/tool-call-with-other-models#tool_choice - formatted_tool = convert_to_openai_tool(tools[0]) - tool_name = formatted_tool["function"]["name"] - tool_choice = {"type": "function", "function": {"name": tool_name}} - else: - pass - - return super().bind_tools(tools=tools, tool_choice=tool_choice, **kwargs) diff --git a/libs/partners/together/tests/integration_tests/test_chat_models_standard.py b/libs/partners/together/tests/integration_tests/test_chat_models_standard.py index b2377870de58e..2250873f4b659 100644 --- a/libs/partners/together/tests/integration_tests/test_chat_models_standard.py +++ b/libs/partners/together/tests/integration_tests/test_chat_models_standard.py @@ -28,6 +28,10 @@ def chat_model_params(self) -> dict: "rate_limiter": rate_limiter, } + @pytest.mark.xfail(reason=("May not call a tool.")) + def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None: + super().test_tool_calling_with_no_arguments(model) + @pytest.mark.xfail(reason="Not yet supported.") def test_usage_metadata_streaming(self, model: BaseChatModel) -> None: super().test_usage_metadata_streaming(model) From 4470674e6784d1a2b300f1735a798bd6e7d5185d Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Mon, 19 Aug 2024 15:30:36 -0400 Subject: [PATCH 6/9] allow specification of tool_choice_value --- .../tests/integration_tests/test_standard.py | 7 ++++- .../test_chat_models_standard.py | 9 ++++--- .../integration_tests/chat_models.py | 26 ++++++++++++++++--- .../unit_tests/chat_models.py | 5 ++++ 4 files changed, 39 insertions(+), 8 deletions(-) diff --git a/libs/partners/groq/tests/integration_tests/test_standard.py b/libs/partners/groq/tests/integration_tests/test_standard.py index d3483c9d650eb..38fe554c5c779 100644 --- a/libs/partners/groq/tests/integration_tests/test_standard.py +++ b/libs/partners/groq/tests/integration_tests/test_standard.py @@ -1,6 +1,6 @@ """Standard LangChain interface tests""" -from typing import Type +from typing import Optional, Type import pytest from langchain_core.language_models import BaseChatModel @@ -33,6 +33,11 @@ def chat_model_params(self) -> dict: "rate_limiter": rate_limiter, } + @property + def tool_choice_value(self) -> Optional[str]: + """Value to use for tool choice when used in tests.""" + return "any" + @pytest.mark.xfail( reason=("Fails with 'Failed to call a function. Please adjust your prompt.'") ) diff --git a/libs/partners/together/tests/integration_tests/test_chat_models_standard.py b/libs/partners/together/tests/integration_tests/test_chat_models_standard.py index 2250873f4b659..bbf9f502c0304 100644 --- a/libs/partners/together/tests/integration_tests/test_chat_models_standard.py +++ b/libs/partners/together/tests/integration_tests/test_chat_models_standard.py @@ -1,6 +1,6 @@ """Standard LangChain interface tests""" -from typing import Type +from typing import Optional, Type import pytest from langchain_core.language_models import BaseChatModel @@ -28,9 +28,10 @@ def chat_model_params(self) -> dict: "rate_limiter": rate_limiter, } - @pytest.mark.xfail(reason=("May not call a tool.")) - def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None: - super().test_tool_calling_with_no_arguments(model) + @property + def tool_choice_value(self) -> Optional[str]: + """Value to use for tool choice when used in tests.""" + return "dict" @pytest.mark.xfail(reason="Not yet supported.") def test_usage_metadata_streaming(self, model: BaseChatModel) -> None: diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py index d6d08ea93edbd..14df1ba91a2e5 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py @@ -170,7 +170,11 @@ def test_stop_sequence(self, model: BaseChatModel) -> None: def test_tool_calling(self, model: BaseChatModel) -> None: if not self.has_tool_calling: pytest.skip("Test requires tool calling.") - model_with_tools = model.bind_tools([magic_function], tool_choice="any") + if self.tool_choice_value == "dict": + tool_choice = {"type": "function", "function": {"name": "magic_function"}} + else: + tool_choice = self.tool_choice_value + model_with_tools = model.bind_tools([magic_function], tool_choice=tool_choice) # Test invoke query = "What is the value of magic_function(3)? Use the tool." @@ -188,7 +192,16 @@ def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None: if not self.has_tool_calling: pytest.skip("Test requires tool calling.") - model_with_tools = model.bind_tools([magic_function_no_args], tool_choice="any") + if self.tool_choice_value == "dict": + tool_choice = { + "type": "function", + "function": {"name": "magic_function_no_args"}, + } + else: + tool_choice = self.tool_choice_value + model_with_tools = model.bind_tools( + [magic_function_no_args], tool_choice=tool_choice + ) query = "What is the value of magic_function()? Use the tool." result = model_with_tools.invoke(query) _validate_tool_call_message_no_args(result) @@ -212,7 +225,14 @@ def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None: name="greeting_generator", description="Generate a greeting in a particular style of speaking.", ) - model_with_tools = model.bind_tools([tool_], tool_choice="any") + if self.tool_choice_value == "dict": + tool_choice = { + "type": "function", + "function": {"name": "greeting_generator"}, + } + else: + tool_choice = self.tool_choice_value + model_with_tools = model.bind_tools([tool_], tool_choice=tool_choice) query = "Using the tool, generate a Pirate greeting." result = model_with_tools.invoke(query) assert isinstance(result, AIMessage) diff --git a/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py index ed73771dbdae0..6597b16177be4 100644 --- a/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py @@ -96,6 +96,11 @@ def model(self) -> BaseChatModel: def has_tool_calling(self) -> bool: return self.chat_model_class.bind_tools is not BaseChatModel.bind_tools + @property + def tool_choice_value(self) -> Optional[str]: + """Value to use for tool choice when used in tests.""" + return None + @property def has_structured_output(self) -> bool: return ( From a68fa35a90b75c2e92a7ccaf97e64d18eb85fda0 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Mon, 19 Aug 2024 15:35:24 -0400 Subject: [PATCH 7/9] add type hint --- .../integration_tests/chat_models.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py index 14df1ba91a2e5..28b9d81bb19c5 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py @@ -1,6 +1,6 @@ import base64 import json -from typing import List, Optional +from typing import List, Optional, Union import httpx import pytest @@ -171,7 +171,10 @@ def test_tool_calling(self, model: BaseChatModel) -> None: if not self.has_tool_calling: pytest.skip("Test requires tool calling.") if self.tool_choice_value == "dict": - tool_choice = {"type": "function", "function": {"name": "magic_function"}} + tool_choice: Union[dict, str, None] = { + "type": "function", + "function": {"name": "magic_function"}, + } else: tool_choice = self.tool_choice_value model_with_tools = model.bind_tools([magic_function], tool_choice=tool_choice) @@ -193,7 +196,7 @@ def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None: pytest.skip("Test requires tool calling.") if self.tool_choice_value == "dict": - tool_choice = { + tool_choice: Union[dict, str, None] = { "type": "function", "function": {"name": "magic_function_no_args"}, } @@ -226,7 +229,7 @@ def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None: description="Generate a greeting in a particular style of speaking.", ) if self.tool_choice_value == "dict": - tool_choice = { + tool_choice: Union[dict, str, None] = { "type": "function", "function": {"name": "greeting_generator"}, } From 7202cdd1621946a118a59d0113fcc8e6a803fafa Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Mon, 19 Aug 2024 15:54:31 -0400 Subject: [PATCH 8/9] dict -> tool name --- .../test_chat_models_standard.py | 2 +- .../integration_tests/chat_models.py | 23 ++++++------------- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/libs/partners/together/tests/integration_tests/test_chat_models_standard.py b/libs/partners/together/tests/integration_tests/test_chat_models_standard.py index bbf9f502c0304..18c167f8a91dc 100644 --- a/libs/partners/together/tests/integration_tests/test_chat_models_standard.py +++ b/libs/partners/together/tests/integration_tests/test_chat_models_standard.py @@ -31,7 +31,7 @@ def chat_model_params(self) -> dict: @property def tool_choice_value(self) -> Optional[str]: """Value to use for tool choice when used in tests.""" - return "dict" + return "tool_name" @pytest.mark.xfail(reason="Not yet supported.") def test_usage_metadata_streaming(self, model: BaseChatModel) -> None: diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py index 28b9d81bb19c5..32d922a3e4a36 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py @@ -1,6 +1,6 @@ import base64 import json -from typing import List, Optional, Union +from typing import List, Optional import httpx import pytest @@ -170,11 +170,8 @@ def test_stop_sequence(self, model: BaseChatModel) -> None: def test_tool_calling(self, model: BaseChatModel) -> None: if not self.has_tool_calling: pytest.skip("Test requires tool calling.") - if self.tool_choice_value == "dict": - tool_choice: Union[dict, str, None] = { - "type": "function", - "function": {"name": "magic_function"}, - } + if self.tool_choice_value == "tool_name": + tool_choice: Optional[str] = "magic_function" else: tool_choice = self.tool_choice_value model_with_tools = model.bind_tools([magic_function], tool_choice=tool_choice) @@ -195,11 +192,8 @@ def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None: if not self.has_tool_calling: pytest.skip("Test requires tool calling.") - if self.tool_choice_value == "dict": - tool_choice: Union[dict, str, None] = { - "type": "function", - "function": {"name": "magic_function_no_args"}, - } + if self.tool_choice_value == "tool_name": + tool_choice: Optional[str] = "magic_function_no_args" else: tool_choice = self.tool_choice_value model_with_tools = model.bind_tools( @@ -228,11 +222,8 @@ def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None: name="greeting_generator", description="Generate a greeting in a particular style of speaking.", ) - if self.tool_choice_value == "dict": - tool_choice: Union[dict, str, None] = { - "type": "function", - "function": {"name": "greeting_generator"}, - } + if self.tool_choice_value == "tool_name": + tool_choice: Optional[str] = "greeting_generator" else: tool_choice = self.tool_choice_value model_with_tools = model.bind_tools([tool_], tool_choice=tool_choice) From fc11b472182726b30f6bd6abf2166f923454cc00 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Mon, 19 Aug 2024 16:03:44 -0400 Subject: [PATCH 9/9] updat --- .../mistralai/tests/integration_tests/test_standard.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/libs/partners/mistralai/tests/integration_tests/test_standard.py b/libs/partners/mistralai/tests/integration_tests/test_standard.py index 965cd03c4b178..cea6399ee4cd8 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_standard.py +++ b/libs/partners/mistralai/tests/integration_tests/test_standard.py @@ -1,6 +1,6 @@ """Standard LangChain interface tests""" -from typing import Type +from typing import Optional, Type from langchain_core.language_models import BaseChatModel from langchain_standard_tests.integration_tests import ( # type: ignore[import-not-found] @@ -18,3 +18,8 @@ def chat_model_class(self) -> Type[BaseChatModel]: @property def chat_model_params(self) -> dict: return {"model": "mistral-large-latest", "temperature": 0} + + @property + def tool_choice_value(self) -> Optional[str]: + """Value to use for tool choice when used in tests.""" + return "any"