From 825840718129e35a2cb3c4756154429aa26cfc60 Mon Sep 17 00:00:00 2001
From: Eugene Yurtsev
Date: Wed, 15 May 2024 16:12:46 -0400
Subject: [PATCH] x

---
 tests/unit_tests/test_server_client.py       | 20 ++++---
 tests/unit_tests/utils/stubs.py              |  6 ++
 .../unit_tests/utils/test_fake_chat_model.py | 56 +++++++++++--------
 3 files changed, 52 insertions(+), 30 deletions(-)
 create mode 100644 tests/unit_tests/utils/stubs.py

diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py
index 927346bb..ff6089e1 100644
--- a/tests/unit_tests/test_server_client.py
+++ b/tests/unit_tests/test_server_client.py
@@ -71,6 +71,7 @@
 from langserve.client import RemoteRunnable
 from langserve.lzstring import LZString
 from langserve.schema import CustomUserType
+from tests.unit_tests.utils.stubs import AnyStr
 
 try:
     from pydantic.v1 import BaseModel, Field
@@ -2558,7 +2559,7 @@ async def test_astream_events_with_prompt_model_parser_chain(
             "tags": ["seq:step:2"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="Hello")},
+            "data": {"chunk": AIMessageChunk(content="Hello", id=AnyStr())},
             "event": "on_chat_model_stream",
             "name": "GenericFakeChatModel",
             "tags": ["seq:step:2"],
@@ -2582,7 +2583,7 @@ async def test_astream_events_with_prompt_model_parser_chain(
             "tags": [],
         },
         {
-            "data": {"chunk": AIMessageChunk(content=" ")},
+            "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
             "event": "on_chat_model_stream",
             "name": "GenericFakeChatModel",
             "tags": ["seq:step:2"],
@@ -2600,7 +2601,7 @@ async def test_astream_events_with_prompt_model_parser_chain(
             "tags": [],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="World!")},
+            "data": {"chunk": AIMessageChunk(content="World!", id=AnyStr())},
             "event": "on_chat_model_stream",
             "name": "GenericFakeChatModel",
             "tags": ["seq:step:2"],
@@ -2632,7 +2633,9 @@ async def test_astream_events_with_prompt_model_parser_chain(
                         [
                             ChatGenerationChunk(
                                 text="Hello World!",
-                                message=AIMessageChunk(content="Hello World!"),
+                                message=AIMessageChunk(
+                                    content="Hello World!", id=AnyStr()
+                                ),
                             )
                         ]
                     ],
@@ -2646,7 +2649,7 @@ async def test_astream_events_with_prompt_model_parser_chain(
         },
         {
             "data": {
-                "input": AIMessageChunk(content="Hello World!"),
+                "input": AIMessageChunk(content="Hello World!", id=AnyStr()),
                 "output": "Hello World!",
             },
             "event": "on_parser_end",
@@ -2784,10 +2787,13 @@ def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
     result = await chain_with_history.ainvoke(
         {"question": "hi"}, {"configurable": {"session_id": "1"}}
     )
-    assert result == AIMessage(content="Hello World!")
+    assert result == AIMessage(content="Hello World!", id=AnyStr())
     assert store == {
         "1": InMemoryHistory(
-            messages=[HumanMessage(content="hi"), AIMessage(content="Hello World!")]
+            messages=[
+                HumanMessage(content="hi"),
+                AIMessage(content="Hello World!", id=AnyStr()),
+            ]
         )
     }
 
diff --git a/tests/unit_tests/utils/stubs.py b/tests/unit_tests/utils/stubs.py
new file mode 100644
index 00000000..38e84a3a
--- /dev/null
+++ b/tests/unit_tests/utils/stubs.py
@@ -0,0 +1,6 @@
+from typing import Any
+
+
+class AnyStr(str):
+    def __eq__(self, other: Any) -> bool:
+        return isinstance(other, str)
diff --git a/tests/unit_tests/utils/test_fake_chat_model.py b/tests/unit_tests/utils/test_fake_chat_model.py
index 89d0cd8b..a11813ba 100644
--- a/tests/unit_tests/utils/test_fake_chat_model.py
+++ b/tests/unit_tests/utils/test_fake_chat_model.py
@@ -8,6 +8,7 @@
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
 
 from tests.unit_tests.utils.llms import GenericFakeChatModel
+from tests.unit_tests.utils.stubs import AnyStr
 
 
 def test_generic_fake_chat_model_invoke() -> None:
@@ -15,11 +16,11 @@ def test_generic_fake_chat_model_invoke() -> None:
     infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
     model = GenericFakeChatModel(messages=infinite_cycle)
     response = model.invoke("meow")
-    assert response == AIMessage(content="hello")
+    assert response == AIMessage(content="hello", id=AnyStr())
     response = model.invoke("kitty")
-    assert response == AIMessage(content="goodbye")
+    assert response == AIMessage(content="goodbye", id=AnyStr())
     response = model.invoke("meow")
-    assert response == AIMessage(content="hello")
+    assert response == AIMessage(content="hello", id=AnyStr())
 
 
 async def test_generic_fake_chat_model_ainvoke() -> None:
@@ -27,11 +28,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
     infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
     model = GenericFakeChatModel(messages=infinite_cycle)
     response = await model.ainvoke("meow")
-    assert response == AIMessage(content="hello")
+    assert response == AIMessage(content="hello", id=AnyStr())
     response = await model.ainvoke("kitty")
-    assert response == AIMessage(content="goodbye")
+    assert response == AIMessage(content="goodbye", id=AnyStr())
     response = await model.ainvoke("meow")
-    assert response == AIMessage(content="hello")
+    assert response == AIMessage(content="hello", id=AnyStr())
 
 
 async def test_generic_fake_chat_model_stream() -> None:
@@ -44,26 +45,28 @@ async def test_generic_fake_chat_model_stream() -> None:
     model = GenericFakeChatModel(messages=infinite_cycle)
     chunks = [chunk async for chunk in model.astream("meow")]
     assert chunks == [
-        AIMessageChunk(content="hello"),
-        AIMessageChunk(content=" "),
-        AIMessageChunk(content="goodbye"),
+        AIMessageChunk(content="hello", id=AnyStr()),
+        AIMessageChunk(content=" ", id=AnyStr()),
+        AIMessageChunk(content="goodbye", id=AnyStr()),
     ]
 
     chunks = [chunk for chunk in model.stream("meow")]
     assert chunks == [
-        AIMessageChunk(content="hello"),
-        AIMessageChunk(content=" "),
-        AIMessageChunk(content="goodbye"),
+        AIMessageChunk(content="hello", id=AnyStr()),
+        AIMessageChunk(content=" ", id=AnyStr()),
+        AIMessageChunk(content="goodbye", id=AnyStr()),
     ]
 
     # Test streaming of additional kwargs.
     # Relying on insertion order of the additional kwargs dict
-    message = AIMessage(content="", additional_kwargs={"foo": 42, "bar": 24})
+    message = AIMessage(
+        content="", additional_kwargs={"foo": 42, "bar": 24}, id=AnyStr()
+    )
     model = GenericFakeChatModel(messages=cycle([message]))
     chunks = [chunk async for chunk in model.astream("meow")]
     assert chunks == [
-        AIMessageChunk(content="", additional_kwargs={"foo": 42}),
-        AIMessageChunk(content="", additional_kwargs={"bar": 24}),
+        AIMessageChunk(content="", additional_kwargs={"foo": 42}, id=AnyStr()),
+        AIMessageChunk(content="", additional_kwargs={"bar": 24}, id=AnyStr()),
     ]
 
     message = AIMessage(
@@ -81,22 +84,28 @@ async def test_generic_fake_chat_model_stream() -> None:
 
     assert chunks == [
         AIMessageChunk(
-            content="", additional_kwargs={"function_call": {"name": "move_file"}}
+            content="",
+            additional_kwargs={"function_call": {"name": "move_file"}},
+            id=AnyStr(),
         ),
         AIMessageChunk(
             content="",
             additional_kwargs={
                 "function_call": {"arguments": '{\n "source_path": "foo"'}
             },
+            id=AnyStr(),
         ),
         AIMessageChunk(
-            content="", additional_kwargs={"function_call": {"arguments": ","}}
+            content="",
+            additional_kwargs={"function_call": {"arguments": ","}},
+            id=AnyStr(),
         ),
         AIMessageChunk(
             content="",
             additional_kwargs={
                 "function_call": {"arguments": '\n "destination_path": "bar"\n}'}
             },
+            id=AnyStr(),
        ),
     ]
 
@@ -116,6 +125,7 @@ async def test_generic_fake_chat_model_stream() -> None:
                 'destination_path": "bar"\n}',
             }
         },
+        id=AnyStr(),
     )
 
 
@@ -128,9 +138,9 @@ async def test_generic_fake_chat_model_astream_log() -> None:
     ]
     final = log_patches[-1]
     assert final.state["streamed_output"] == [
-        AIMessageChunk(content="hello"),
-        AIMessageChunk(content=" "),
-        AIMessageChunk(content="goodbye"),
+        AIMessageChunk(content="hello", id=AnyStr()),
+        AIMessageChunk(content=" ", id=AnyStr()),
+        AIMessageChunk(content="goodbye", id=AnyStr()),
     ]
 
 
@@ -178,8 +188,8 @@ async def on_llm_new_token(
     # New model
     results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
     assert results == [
-        AIMessageChunk(content="hello"),
-        AIMessageChunk(content=" "),
-        AIMessageChunk(content="goodbye"),
+        AIMessageChunk(content="hello", id=AnyStr()),
+        AIMessageChunk(content=" ", id=AnyStr()),
+        AIMessageChunk(content="goodbye", id=AnyStr()),
     ]
     assert tokens == ["hello", " ", "goodbye"]
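Note on the technique (commentary, not part of the patch): the new tests/unit_tests/utils/stubs.py helper exists because the messages produced in these tests now carry an auto-assigned id that the expected values cannot predict. AnyStr subclasses str and overrides __eq__ to accept any string, so it acts as a wildcard for that field. Below is a minimal standalone sketch of the idea; the concrete id "run-0" is invented purely for illustration.

from typing import Any

from langchain_core.messages import AIMessage


class AnyStr(str):
    """Stub string that compares equal to any str (same shape as the patch's stub)."""

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, str)


# Because AnyStr is a str subclass that overrides __eq__, Python tries its
# (reflected) __eq__ first even when the plain str is on the left-hand side.
assert AnyStr() == "run-0"
assert "run-0" == AnyStr()

# Mirrors the pattern used in the patched tests: the produced message carries an
# unpredictable id, and the expected message uses AnyStr() in its place.
produced = AIMessage(content="hello", id="run-0")
assert produced == AIMessage(content="hello", id=AnyStr())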