diff --git a/libs/community/langchain_community/utilities/notdiamond.py b/libs/community/langchain_community/utilities/notdiamond.py index 87e0a0bc3604e..41819e0034c1d 100644 --- a/libs/community/langchain_community/utilities/notdiamond.py +++ b/libs/community/langchain_community/utilities/notdiamond.py @@ -36,7 +36,15 @@ def __init__( nd_llm_configs: Optional[List] = None, nd_api_key: Optional[str] = None, nd_client: Optional[Any] = None, + nd_kwargs: Optional[Dict[str, Any]] = None, ): + """ + Params: + nd_llm_configs: List of LLM configs to use. + nd_api_key: Not Diamond API key. + nd_client: Not Diamond client. + nd_kwargs: Keyword arguments to pass directly to model_select. + """ if not nd_client: if not nd_api_key or not nd_llm_configs: raise ValueError( @@ -68,10 +76,13 @@ def __init__( self.client = nd_client self.api_key = nd_client.api_key self.llm_configs = nd_client.llm_configs + self.nd_kwargs = nd_kwargs or dict() def _model_select(self, input: LanguageModelInput) -> str: messages = _convert_input_to_message_dicts(input) - _, provider = self.client.chat.completions.model_select(messages=messages) + _, provider = self.client.chat.completions.model_select( + messages=messages, **self.nd_kwargs + ) provider_str = _nd_provider_to_langchain_provider(str(provider)) return provider_str @@ -137,15 +148,26 @@ def __init__( nd_llm_configs: Optional[List] = None, nd_api_key: Optional[str] = None, nd_client: Optional[Any] = None, + nd_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Optional[Dict[Any, Any]], ) -> None: + """ + Params: + nd_llm_configs: List of LLM configs to use. + nd_api_key: Not Diamond API key. + nd_client: Not Diamond client. + nd_kwargs: Keyword arguments to pass directly to model_select. + """ + _nd_kwargs = {kw: kwargs[kw] for kw in kwargs.keys() if kw.startswith("nd_")} + if nd_kwargs: + _nd_kwargs.update(nd_kwargs) + self._ndrunnable = NotDiamondRunnable( nd_api_key=nd_api_key, nd_llm_configs=nd_llm_configs, nd_client=nd_client, + nd_kwargs=_nd_kwargs, ) - _nd_kwargs = {kw for kw in kwargs.keys() if kw.startswith("nd_")} - _routed_fields = ["model", "model_provider"] if configurable_fields is None: configurable_fields = [] @@ -179,7 +201,7 @@ def invoke( def batch( self, - inputs: List[LanguageModelInput], + inputs: Sequence[LanguageModelInput], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, **kwargs: Optional[Any], ) -> List[Any]: @@ -194,7 +216,7 @@ def batch( for i, ps in enumerate(provider_strs) ] - return self._configurable_model.batch(inputs, config=_configs) + return self._configurable_model.batch([i for i in inputs], config=_configs) async def astream( self, @@ -219,7 +241,7 @@ async def ainvoke( async def abatch( self, - inputs: List[LanguageModelInput], + inputs: Sequence[LanguageModelInput], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, **kwargs: Optional[Any], ) -> List[Any]: @@ -236,7 +258,9 @@ async def abatch( for i, ps in enumerate(provider_strs) ] - return await self._configurable_model.abatch(inputs, config=_configs) + return await self._configurable_model.abatch( + [i for i in inputs], config=_configs + ) def _build_model_config( self, provider_str: str, config: Optional[RunnableConfig] = None diff --git a/libs/community/tests/unit_tests/utilities/test_notdiamond.py b/libs/community/tests/unit_tests/utilities/test_notdiamond.py index 75e222b89100f..8fd50f3137391 100644 --- a/libs/community/tests/unit_tests/utilities/test_notdiamond.py +++ b/libs/community/tests/unit_tests/utilities/test_notdiamond.py @@ -43,7 +43,9 @@ def nd_client(llm_configs: List[Any]) -> Any: def not_diamond_runnable(nd_client: Any) -> NotDiamondRunnable: with patch("langchain_community.utilities.notdiamond.LLMConfig") as mock_llm_config: mock_llm_config.from_string.return_value = MagicMock(provider="openai") - runnable = NotDiamondRunnable(nd_client=nd_client) + runnable = NotDiamondRunnable( + nd_client=nd_client, nd_kwargs={"tradeoff": "cost"} + ) return runnable @@ -51,38 +53,52 @@ def not_diamond_runnable(nd_client: Any) -> NotDiamondRunnable: def not_diamond_routed_runnable(nd_client: Any) -> NotDiamondRoutedRunnable: with patch("langchain_community.utilities.notdiamond.LLMConfig") as mock_llm_config: mock_llm_config.from_string.return_value = MagicMock(provider="openai") - routed_runnable = NotDiamondRoutedRunnable(nd_client=nd_client) + routed_runnable = NotDiamondRoutedRunnable( + nd_client=nd_client, nd_kwargs={"tradeoff": "cost"} + ) routed_runnable._configurable_model = MagicMock(spec=_ConfigurableModel) return routed_runnable class TestNotDiamondRunnable: def test_model_select( - self, not_diamond_runnable: NotDiamondRunnable, llm_configs: List + self, + not_diamond_runnable: NotDiamondRunnable, + llm_configs: List, + nd_client: Any, ) -> None: - actual_select = not_diamond_runnable._model_select("Hello, world!") + prompt = "Hello, world!" + actual_select = not_diamond_runnable._model_select(prompt) assert str(actual_select) in [ _nd_provider_to_langchain_provider(str(config)) for config in llm_configs ] + assert nd_client.model_select.called_with(prompt, tradeoff="cost") @pytest.mark.asyncio async def test_amodel_select( - self, not_diamond_runnable: NotDiamondRunnable, llm_configs: List + self, + not_diamond_runnable: NotDiamondRunnable, + llm_configs: List, + nd_client: Any, ) -> None: - actual_select = await not_diamond_runnable._amodel_select("Hello, world!") + prompt = "Hello, world!" + actual_select = await not_diamond_runnable._amodel_select(prompt) assert str(actual_select) in [ _nd_provider_to_langchain_provider(str(config)) for config in llm_configs ] + assert nd_client.amodel_select.called_with(prompt, tradeoff="cost") class TestNotDiamondRoutedRunnable: def test_invoke( - self, not_diamond_routed_runnable: NotDiamondRoutedRunnable + self, not_diamond_routed_runnable: NotDiamondRoutedRunnable, nd_client: Any ) -> None: - not_diamond_routed_runnable.invoke("Hello, world!") + prompt = "Hello, world!" + not_diamond_routed_runnable.invoke(prompt) assert ( not_diamond_routed_runnable._configurable_model.invoke.called # type: ignore[attr-defined] ), f"{not_diamond_routed_runnable._configurable_model}" + assert nd_client.model_select.called_with(prompt, tradeoff="cost") # Check the call list call_list = ( @@ -93,19 +109,25 @@ def test_invoke( assert args[0] == "Hello, world!" def test_stream( - self, not_diamond_routed_runnable: NotDiamondRoutedRunnable + self, not_diamond_routed_runnable: NotDiamondRoutedRunnable, nd_client: Any ) -> None: - for result in not_diamond_routed_runnable.stream("Hello, world!"): + prompt = "Hello, world!" + for result in not_diamond_routed_runnable.stream(prompt): assert result is not None assert ( not_diamond_routed_runnable._configurable_model.stream.called # type: ignore[attr-defined] ), f"{not_diamond_routed_runnable._configurable_model}" + assert nd_client.model_select.called_with(prompt, tradeoff="cost") - def test_batch(self, not_diamond_routed_runnable: NotDiamondRoutedRunnable) -> None: - not_diamond_routed_runnable.batch(["Hello, world!", "How are you today?"]) + def test_batch( + self, not_diamond_routed_runnable: NotDiamondRoutedRunnable, nd_client: Any + ) -> None: + prompts = ["Hello, world!", "How are you today?"] + not_diamond_routed_runnable.batch(prompts) assert ( not_diamond_routed_runnable._configurable_model.batch.called # type: ignore[attr-defined] ), f"{not_diamond_routed_runnable._configurable_model}" + assert nd_client.model_select.called_with(prompts, tradeoff="cost") # Check the call list call_list = ( @@ -113,16 +135,18 @@ def test_batch(self, not_diamond_routed_runnable: NotDiamondRoutedRunnable) -> N ) assert len(call_list) == 1 args, kwargs = call_list[0] - assert args[0] == ["Hello, world!", "How are you today?"] + assert args[0] == prompts @pytest.mark.asyncio async def test_ainvoke( - self, not_diamond_routed_runnable: NotDiamondRoutedRunnable + self, not_diamond_routed_runnable: NotDiamondRoutedRunnable, nd_client: Any ) -> None: - await not_diamond_routed_runnable.ainvoke("Hello, world!") + prompt = "Hello, world!" + await not_diamond_routed_runnable.ainvoke(prompt) assert ( not_diamond_routed_runnable._configurable_model.ainvoke.called # type: ignore[attr-defined] ), f"{not_diamond_routed_runnable._configurable_model}" + assert nd_client.amodel_select.called_with(prompt, tradeoff="cost") # Check the call list call_list = ( @@ -134,24 +158,26 @@ async def test_ainvoke( @pytest.mark.asyncio async def test_astream( - self, not_diamond_routed_runnable: NotDiamondRoutedRunnable + self, not_diamond_routed_runnable: NotDiamondRoutedRunnable, nd_client: Any ) -> None: - async for result in not_diamond_routed_runnable.astream("Hello, world!"): + prompt = "Hello, world!" + async for result in not_diamond_routed_runnable.astream(prompt): assert result is not None assert ( not_diamond_routed_runnable._configurable_model.astream.called # type: ignore[attr-defined] ), f"{not_diamond_routed_runnable._configurable_model}" + assert nd_client.amodel_select.called_with(prompt, tradeoff="cost") @pytest.mark.asyncio async def test_abatch( - self, not_diamond_routed_runnable: NotDiamondRoutedRunnable + self, not_diamond_routed_runnable: NotDiamondRoutedRunnable, nd_client: Any ) -> None: - await not_diamond_routed_runnable.abatch( - ["Hello, world!", "How are you today?"] - ) + prompts = ["Hello, world!", "How are you today?"] + await not_diamond_routed_runnable.abatch(prompts) assert ( not_diamond_routed_runnable._configurable_model.abatch.called # type: ignore[attr-defined] ), f"{not_diamond_routed_runnable._configurable_model}" + assert nd_client.amodel_select.called_with(prompts, tradeoff="cost") # Check the call list call_list = ( @@ -159,7 +185,7 @@ async def test_abatch( ) assert len(call_list) == 1 args, kwargs = call_list[0] - assert args[0] == ["Hello, world!", "How are you today?"] + assert args[0] == prompts def test_invokable_mock(self) -> None: target_model = "openai/gpt-4o"