eyurtsev committed May 15, 2024
1 parent c58ad19 commit 8258407
Showing 3 changed files with 52 additions and 30 deletions.
20 changes: 13 additions & 7 deletions tests/unit_tests/test_server_client.py
@@ -71,6 +71,7 @@
 from langserve.client import RemoteRunnable
 from langserve.lzstring import LZString
 from langserve.schema import CustomUserType
+from tests.unit_tests.utils.stubs import AnyStr
 
 try:
     from pydantic.v1 import BaseModel, Field
@@ -2558,7 +2559,7 @@ async def test_astream_events_with_prompt_model_parser_chain(
             "tags": ["seq:step:2"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="Hello")},
+            "data": {"chunk": AIMessageChunk(content="Hello", id=AnyStr())},
             "event": "on_chat_model_stream",
             "name": "GenericFakeChatModel",
             "tags": ["seq:step:2"],
@@ -2582,7 +2583,7 @@ async def test_astream_events_with_prompt_model_parser_chain(
             "tags": [],
         },
         {
-            "data": {"chunk": AIMessageChunk(content=" ")},
+            "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
             "event": "on_chat_model_stream",
             "name": "GenericFakeChatModel",
             "tags": ["seq:step:2"],
@@ -2600,7 +2601,7 @@ async def test_astream_events_with_prompt_model_parser_chain(
             "tags": [],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="World!")},
+            "data": {"chunk": AIMessageChunk(content="World!", id=AnyStr())},
             "event": "on_chat_model_stream",
             "name": "GenericFakeChatModel",
             "tags": ["seq:step:2"],
@@ -2632,7 +2633,9 @@ async def test_astream_events_with_prompt_model_parser_chain(
                         [
                             ChatGenerationChunk(
                                 text="Hello World!",
-                                message=AIMessageChunk(content="Hello World!"),
+                                message=AIMessageChunk(
+                                    content="Hello World!", id=AnyStr()
+                                ),
                             )
                         ]
                     ],
@@ -2646,7 +2649,7 @@ async def test_astream_events_with_prompt_model_parser_chain(
         },
         {
             "data": {
-                "input": AIMessageChunk(content="Hello World!"),
+                "input": AIMessageChunk(content="Hello World!", id=AnyStr()),
                 "output": "Hello World!",
             },
             "event": "on_parser_end",
@@ -2784,10 +2787,13 @@ def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
     result = await chain_with_history.ainvoke(
         {"question": "hi"}, {"configurable": {"session_id": "1"}}
     )
-    assert result == AIMessage(content="Hello World!")
+    assert result == AIMessage(content="Hello World!", id=AnyStr())
     assert store == {
         "1": InMemoryHistory(
-            messages=[HumanMessage(content="hi"), AIMessage(content="Hello World!")]
+            messages=[
+                HumanMessage(content="hi"),
+                AIMessage(content="Hello World!", id=AnyStr()),
+            ]
         )
     }

6 changes: 6 additions & 0 deletions tests/unit_tests/utils/stubs.py
@@ -0,0 +1,6 @@
+from typing import Any
+
+
+class AnyStr(str):
+    def __eq__(self, other: Any) -> bool:
+        return isinstance(other, str)
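This stub is what lets the updated assertions ignore auto-generated message ids: an AnyStr() instance compares equal to any concrete string. A minimal usage sketch (the id values below are illustrative, not from the commit):

    from tests.unit_tests.utils.stubs import AnyStr

    # AnyStr subclasses str but overrides __eq__ to accept any string,
    # so it acts as a wildcard for unpredictable values such as run ids.
    assert AnyStr() == "run-3f2a"  # any concrete string matches
    # Because AnyStr is a str subclass, Python tries its __eq__ first
    # even when the wildcard sits on the right-hand side:
    assert "run-3f2a" == AnyStr()
    # A non-string never matches:
    assert AnyStr() != 42

Note that overriding __eq__ without defining __hash__ leaves AnyStr instances unhashable, which is acceptable here since the stub is only used in equality assertions.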
56 changes: 33 additions & 23 deletions tests/unit_tests/utils/test_fake_chat_model.py
@@ -8,30 +8,31 @@
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
 
 from tests.unit_tests.utils.llms import GenericFakeChatModel
+from tests.unit_tests.utils.stubs import AnyStr
 
 
 def test_generic_fake_chat_model_invoke() -> None:
     # Will alternate between responding with hello and goodbye
     infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
     model = GenericFakeChatModel(messages=infinite_cycle)
     response = model.invoke("meow")
-    assert response == AIMessage(content="hello")
+    assert response == AIMessage(content="hello", id=AnyStr())
     response = model.invoke("kitty")
-    assert response == AIMessage(content="goodbye")
+    assert response == AIMessage(content="goodbye", id=AnyStr())
     response = model.invoke("meow")
-    assert response == AIMessage(content="hello")
+    assert response == AIMessage(content="hello", id=AnyStr())
 
 
 async def test_generic_fake_chat_model_ainvoke() -> None:
     # Will alternate between responding with hello and goodbye
     infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
     model = GenericFakeChatModel(messages=infinite_cycle)
     response = await model.ainvoke("meow")
-    assert response == AIMessage(content="hello")
+    assert response == AIMessage(content="hello", id=AnyStr())
     response = await model.ainvoke("kitty")
-    assert response == AIMessage(content="goodbye")
+    assert response == AIMessage(content="goodbye", id=AnyStr())
     response = await model.ainvoke("meow")
-    assert response == AIMessage(content="hello")
+    assert response == AIMessage(content="hello", id=AnyStr())
 
 
 async def test_generic_fake_chat_model_stream() -> None:
@@ -44,26 +45,28 @@ async def test_generic_fake_chat_model_stream() -> None:
     model = GenericFakeChatModel(messages=infinite_cycle)
     chunks = [chunk async for chunk in model.astream("meow")]
     assert chunks == [
-        AIMessageChunk(content="hello"),
-        AIMessageChunk(content=" "),
-        AIMessageChunk(content="goodbye"),
+        AIMessageChunk(content="hello", id=AnyStr()),
+        AIMessageChunk(content=" ", id=AnyStr()),
+        AIMessageChunk(content="goodbye", id=AnyStr()),
     ]
 
     chunks = [chunk for chunk in model.stream("meow")]
     assert chunks == [
-        AIMessageChunk(content="hello"),
-        AIMessageChunk(content=" "),
-        AIMessageChunk(content="goodbye"),
+        AIMessageChunk(content="hello", id=AnyStr()),
+        AIMessageChunk(content=" ", id=AnyStr()),
+        AIMessageChunk(content="goodbye", id=AnyStr()),
     ]
 
     # Test streaming of additional kwargs.
     # Relying on insertion order of the additional kwargs dict
-    message = AIMessage(content="", additional_kwargs={"foo": 42, "bar": 24})
+    message = AIMessage(
+        content="", additional_kwargs={"foo": 42, "bar": 24}, id=AnyStr()
+    )
     model = GenericFakeChatModel(messages=cycle([message]))
     chunks = [chunk async for chunk in model.astream("meow")]
     assert chunks == [
-        AIMessageChunk(content="", additional_kwargs={"foo": 42}),
-        AIMessageChunk(content="", additional_kwargs={"bar": 24}),
+        AIMessageChunk(content="", additional_kwargs={"foo": 42}, id=AnyStr()),
+        AIMessageChunk(content="", additional_kwargs={"bar": 24}, id=AnyStr()),
     ]
 
     message = AIMessage(
@@ -81,22 +84,28 @@ async def test_generic_fake_chat_model_stream() -> None:
 
     assert chunks == [
         AIMessageChunk(
-            content="", additional_kwargs={"function_call": {"name": "move_file"}}
+            content="",
+            additional_kwargs={"function_call": {"name": "move_file"}},
+            id=AnyStr(),
         ),
         AIMessageChunk(
             content="",
             additional_kwargs={
                 "function_call": {"arguments": '{\n "source_path": "foo"'}
             },
+            id=AnyStr(),
         ),
         AIMessageChunk(
-            content="", additional_kwargs={"function_call": {"arguments": ","}}
+            content="",
+            additional_kwargs={"function_call": {"arguments": ","}},
+            id=AnyStr(),
         ),
         AIMessageChunk(
             content="",
             additional_kwargs={
                 "function_call": {"arguments": '\n "destination_path": "bar"\n}'}
             },
+            id=AnyStr(),
         ),
     ]
 
@@ -116,6 +125,7 @@ async def test_generic_fake_chat_model_stream() -> None:
                 'destination_path": "bar"\n}',
             }
         },
+        id=AnyStr(),
     )


@@ -128,9 +138,9 @@ async def test_generic_fake_chat_model_astream_log() -> None:
     ]
     final = log_patches[-1]
     assert final.state["streamed_output"] == [
-        AIMessageChunk(content="hello"),
-        AIMessageChunk(content=" "),
-        AIMessageChunk(content="goodbye"),
+        AIMessageChunk(content="hello", id=AnyStr()),
+        AIMessageChunk(content=" ", id=AnyStr()),
+        AIMessageChunk(content="goodbye", id=AnyStr()),
     ]


@@ -178,8 +188,8 @@ async def on_llm_new_token(
     # New model
     results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
     assert results == [
-        AIMessageChunk(content="hello"),
-        AIMessageChunk(content=" "),
-        AIMessageChunk(content="goodbye"),
+        AIMessageChunk(content="hello", id=AnyStr()),
+        AIMessageChunk(content=" ", id=AnyStr()),
+        AIMessageChunk(content="goodbye", id=AnyStr()),
     ]
     assert tokens == ["hello", " ", "goodbye"]
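Taken together, these assertions pass because message equality compares field values pairwise, so an expected message built with id=AnyStr() matches an actual message carrying any generated id. A minimal sketch of that interaction, assuming pydantic-style field-by-field equality (the concrete id below is illustrative):

    from langchain_core.messages import AIMessage

    from tests.unit_tests.utils.stubs import AnyStr

    # The actual message carries an auto-generated-style id; the expected
    # one uses the AnyStr wildcard, and field-by-field comparison invokes
    # AnyStr.__eq__ for the id field.
    actual = AIMessage(content="hello", id="run-3f2a")
    assert actual == AIMessage(content="hello", id=AnyStr())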
