Skip to content

Commit

Permalink
Adding optional nd_kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
acompa committed Sep 25, 2024
1 parent 187a673 commit 89bdcc6
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 29 deletions.
38 changes: 31 additions & 7 deletions libs/community/langchain_community/utilities/notdiamond.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,15 @@ def __init__(
nd_llm_configs: Optional[List] = None,
nd_api_key: Optional[str] = None,
nd_client: Optional[Any] = None,
nd_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Params:
nd_llm_configs: List of LLM configs to use.
nd_api_key: Not Diamond API key.
nd_client: Not Diamond client.
nd_kwargs: Keyword arguments to pass directly to model_select.
"""
if not nd_client:
if not nd_api_key or not nd_llm_configs:
raise ValueError(
Expand Down Expand Up @@ -68,10 +76,13 @@ def __init__(
self.client = nd_client
self.api_key = nd_client.api_key
self.llm_configs = nd_client.llm_configs
self.nd_kwargs = nd_kwargs or dict()

def _model_select(self, input: LanguageModelInput) -> str:
    """Route *input* to a provider via Not Diamond's ``model_select``.

    Converts the LangChain input into message dicts, asks the Not Diamond
    client to pick a provider — forwarding any ``nd_kwargs`` captured at
    construction time (e.g. ``tradeoff="cost"``) — and maps the chosen
    provider into LangChain's provider string form.

    Params:
        input: Any LanguageModelInput accepted by the runnable.

    Returns:
        The selected provider as a LangChain-style provider string.
    """
    messages = _convert_input_to_message_dicts(input)
    # nd_kwargs lets callers tune routing (tradeoff, etc.) without
    # changing this call site; it defaults to an empty dict.
    _, provider = self.client.chat.completions.model_select(
        messages=messages, **self.nd_kwargs
    )
    provider_str = _nd_provider_to_langchain_provider(str(provider))
    return provider_str

Expand Down Expand Up @@ -137,15 +148,26 @@ def __init__(
nd_llm_configs: Optional[List] = None,
nd_api_key: Optional[str] = None,
nd_client: Optional[Any] = None,
nd_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Optional[Dict[Any, Any]],
) -> None:
"""
Params:
nd_llm_configs: List of LLM configs to use.
nd_api_key: Not Diamond API key.
nd_client: Not Diamond client.
nd_kwargs: Keyword arguments to pass directly to model_select.
"""
_nd_kwargs = {kw: kwargs[kw] for kw in kwargs.keys() if kw.startswith("nd_")}
if nd_kwargs:
_nd_kwargs.update(nd_kwargs)

self._ndrunnable = NotDiamondRunnable(
nd_api_key=nd_api_key,
nd_llm_configs=nd_llm_configs,
nd_client=nd_client,
nd_kwargs=_nd_kwargs,
)
_nd_kwargs = {kw for kw in kwargs.keys() if kw.startswith("nd_")}

_routed_fields = ["model", "model_provider"]
if configurable_fields is None:
configurable_fields = []
Expand Down Expand Up @@ -179,7 +201,7 @@ def invoke(

def batch(
self,
inputs: List[LanguageModelInput],
inputs: Sequence[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Optional[Any],
) -> List[Any]:
Expand All @@ -194,7 +216,7 @@ def batch(
for i, ps in enumerate(provider_strs)
]

return self._configurable_model.batch(inputs, config=_configs)
return self._configurable_model.batch([i for i in inputs], config=_configs)

async def astream(
self,
Expand All @@ -219,7 +241,7 @@ async def ainvoke(

async def abatch(
self,
inputs: List[LanguageModelInput],
inputs: Sequence[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Optional[Any],
) -> List[Any]:
Expand All @@ -236,7 +258,9 @@ async def abatch(
for i, ps in enumerate(provider_strs)
]

return await self._configurable_model.abatch(inputs, config=_configs)
return await self._configurable_model.abatch(
[i for i in inputs], config=_configs
)

def _build_model_config(
self, provider_str: str, config: Optional[RunnableConfig] = None
Expand Down
70 changes: 48 additions & 22 deletions libs/community/tests/unit_tests/utilities/test_notdiamond.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,46 +43,62 @@ def nd_client(llm_configs: List[Any]) -> Any:
def not_diamond_runnable(nd_client: Any) -> NotDiamondRunnable:
    """Fixture: a ``NotDiamondRunnable`` wired to the mocked client.

    Constructed with ``nd_kwargs={"tradeoff": "cost"}`` so tests can verify
    that the extra kwargs are forwarded to ``model_select``.
    """
    with patch("langchain_community.utilities.notdiamond.LLMConfig") as mock_llm_config:
        mock_llm_config.from_string.return_value = MagicMock(provider="openai")
        runnable = NotDiamondRunnable(
            nd_client=nd_client, nd_kwargs={"tradeoff": "cost"}
        )
    return runnable


@pytest.fixture
def not_diamond_routed_runnable(nd_client: Any) -> NotDiamondRoutedRunnable:
    """Fixture: a ``NotDiamondRoutedRunnable`` with a mocked routed model.

    The inner ``_configurable_model`` is replaced with a MagicMock so tests
    can assert on invoke/stream/batch calls without touching real LLMs,
    while routing still flows through the (mocked) ``nd_client``.
    """
    with patch("langchain_community.utilities.notdiamond.LLMConfig") as mock_llm_config:
        mock_llm_config.from_string.return_value = MagicMock(provider="openai")
        routed_runnable = NotDiamondRoutedRunnable(
            nd_client=nd_client, nd_kwargs={"tradeoff": "cost"}
        )
    routed_runnable._configurable_model = MagicMock(spec=_ConfigurableModel)
    return routed_runnable


class TestNotDiamondRunnable:
def test_model_select(
    self,
    not_diamond_runnable: NotDiamondRunnable,
    llm_configs: List,
    nd_client: Any,
) -> None:
    """``_model_select`` returns a known provider and forwards nd_kwargs."""
    prompt = "Hello, world!"
    actual_select = not_diamond_runnable._model_select(prompt)
    assert str(actual_select) in [
        _nd_provider_to_langchain_provider(str(config)) for config in llm_configs
    ]
    # Bug fix: ``Mock.called_with(...)`` is not an assertion -- it creates a
    # child mock (always truthy), so the old assert could never fail, and it
    # targeted ``nd_client.model_select``, which is never called (routing goes
    # through ``nd_client.chat.completions.model_select``). Inspect the call
    # that actually happened instead.
    select_kwargs = nd_client.chat.completions.model_select.call_args.kwargs
    assert select_kwargs.get("tradeoff") == "cost"

@pytest.mark.asyncio
async def test_amodel_select(
    self,
    not_diamond_runnable: NotDiamondRunnable,
    llm_configs: List,
    nd_client: Any,
) -> None:
    """``_amodel_select`` returns a known provider and forwards nd_kwargs."""
    prompt = "Hello, world!"
    actual_select = await not_diamond_runnable._amodel_select(prompt)
    assert str(actual_select) in [
        _nd_provider_to_langchain_provider(str(config)) for config in llm_configs
    ]
    # Bug fix: ``Mock.called_with`` is always truthy, so the old assert was a
    # no-op. NOTE(review): assumes the async path goes through
    # ``client.chat.completions.amodel_select`` -- confirm against
    # ``_amodel_select`` in the utility module.
    acall = nd_client.chat.completions.amodel_select.call_args
    assert acall is not None and acall.kwargs.get("tradeoff") == "cost"


class TestNotDiamondRoutedRunnable:
def test_invoke(
self, not_diamond_routed_runnable: NotDiamondRoutedRunnable
self, not_diamond_routed_runnable: NotDiamondRoutedRunnable, nd_client: Any
) -> None:
not_diamond_routed_runnable.invoke("Hello, world!")
prompt = "Hello, world!"
not_diamond_routed_runnable.invoke(prompt)
assert (
not_diamond_routed_runnable._configurable_model.invoke.called # type: ignore[attr-defined]
), f"{not_diamond_routed_runnable._configurable_model}"
assert nd_client.model_select.called_with(prompt, tradeoff="cost")

# Check the call list
call_list = (
Expand All @@ -93,36 +109,44 @@ def test_invoke(
assert args[0] == "Hello, world!"

def test_stream(
    self, not_diamond_routed_runnable: NotDiamondRoutedRunnable, nd_client: Any
) -> None:
    """``stream()`` routes via model_select (with nd_kwargs) then streams."""
    prompt = "Hello, world!"
    for result in not_diamond_routed_runnable.stream(prompt):
        assert result is not None
    assert (
        not_diamond_routed_runnable._configurable_model.stream.called  # type: ignore[attr-defined]
    ), f"{not_diamond_routed_runnable._configurable_model}"
    # Bug fix: ``Mock.called_with`` is always truthy (no-op assert); check the
    # routing call actually recorded on ``chat.completions.model_select``.
    select_kwargs = nd_client.chat.completions.model_select.call_args.kwargs
    assert select_kwargs.get("tradeoff") == "cost"

def test_batch(
    self, not_diamond_routed_runnable: NotDiamondRoutedRunnable, nd_client: Any
) -> None:
    """``batch()`` routes each input and forwards nd_kwargs on every select."""
    prompts = ["Hello, world!", "How are you today?"]
    not_diamond_routed_runnable.batch(prompts)
    assert (
        not_diamond_routed_runnable._configurable_model.batch.called  # type: ignore[attr-defined]
    ), f"{not_diamond_routed_runnable._configurable_model}"
    # Bug fix: ``Mock.called_with`` is always truthy (no-op assert); verify
    # every recorded routing call carried the nd_kwargs.
    select_calls = nd_client.chat.completions.model_select.call_args_list
    assert select_calls
    assert all(c.kwargs.get("tradeoff") == "cost" for c in select_calls)

    # Check the call list
    call_list = (
        not_diamond_routed_runnable._configurable_model.batch.call_args_list  # type: ignore[attr-defined]
    )
    assert len(call_list) == 1
    args, kwargs = call_list[0]
    assert args[0] == prompts

@pytest.mark.asyncio
async def test_ainvoke(
self, not_diamond_routed_runnable: NotDiamondRoutedRunnable
self, not_diamond_routed_runnable: NotDiamondRoutedRunnable, nd_client: Any
) -> None:
await not_diamond_routed_runnable.ainvoke("Hello, world!")
prompt = "Hello, world!"
await not_diamond_routed_runnable.ainvoke(prompt)
assert (
not_diamond_routed_runnable._configurable_model.ainvoke.called # type: ignore[attr-defined]
), f"{not_diamond_routed_runnable._configurable_model}"
assert nd_client.amodel_select.called_with(prompt, tradeoff="cost")

# Check the call list
call_list = (
Expand All @@ -134,32 +158,34 @@ async def test_ainvoke(

@pytest.mark.asyncio
async def test_astream(
    self, not_diamond_routed_runnable: NotDiamondRoutedRunnable, nd_client: Any
) -> None:
    """``astream()`` routes asynchronously then streams from the routed model."""
    prompt = "Hello, world!"
    async for result in not_diamond_routed_runnable.astream(prompt):
        assert result is not None
    assert (
        not_diamond_routed_runnable._configurable_model.astream.called  # type: ignore[attr-defined]
    ), f"{not_diamond_routed_runnable._configurable_model}"
    # Bug fix: ``Mock.called_with`` is always truthy (no-op assert).
    # NOTE(review): assumes async routing uses chat.completions.amodel_select
    # -- confirm against the utility module.
    acall = nd_client.chat.completions.amodel_select.call_args
    assert acall is not None and acall.kwargs.get("tradeoff") == "cost"

@pytest.mark.asyncio
async def test_abatch(
    self, not_diamond_routed_runnable: NotDiamondRoutedRunnable, nd_client: Any
) -> None:
    """``abatch()`` routes every input asynchronously and forwards nd_kwargs."""
    prompts = ["Hello, world!", "How are you today?"]
    await not_diamond_routed_runnable.abatch(prompts)
    assert (
        not_diamond_routed_runnable._configurable_model.abatch.called  # type: ignore[attr-defined]
    ), f"{not_diamond_routed_runnable._configurable_model}"
    # Bug fix: ``Mock.called_with`` is always truthy (no-op assert).
    # NOTE(review): assumes async routing goes through
    # chat.completions.amodel_select -- confirm against the utility module.
    select_calls = nd_client.chat.completions.amodel_select.call_args_list
    assert select_calls
    assert all(c.kwargs.get("tradeoff") == "cost" for c in select_calls)

    # Check the call list
    call_list = (
        not_diamond_routed_runnable._configurable_model.abatch.call_args_list  # type: ignore[attr-defined]
    )
    assert len(call_list) == 1
    args, kwargs = call_list[0]
    assert args[0] == prompts

def test_invokable_mock(self) -> None:
target_model = "openai/gpt-4o"
Expand Down

0 comments on commit 89bdcc6

Please sign in to comment.