Skip to content

Commit

Permalink
Adding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
NotBioWaste905 committed Nov 20, 2024
1 parent 3d79cec commit 8e22b97
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 4 deletions.
1 change: 1 addition & 0 deletions chatsky/llm/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class BaseFilter(BaseModel, abc.ABC):
"""
Base class for all message history filters.
"""

def __call__(self, ctx: Context, request: Message, response: Message, model_name: str) -> bool:
"""
:param ctx: Context object.
Expand Down
81 changes: 77 additions & 4 deletions tests/llm/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@
from chatsky.llm._langchain_imports import HumanMessage, AIMessage, langchain_available
from chatsky.llm.llm_api import LLM_API
from chatsky.responses.llm import LLMResponse
from chatsky.llm.utils import message_to_langchain, attachment_to_content
from chatsky.llm.utils import message_to_langchain, attachment_to_content, context_to_history
from chatsky.llm.filters import IsImportant, FromTheModel
from chatsky.llm.methods import Contains, LogProb
from chatsky.core.message import Message, Image
from chatsky.core.context import Context
from chatsky.core.script import Node
from chatsky.core.node_label import AbsoluteNodeLabel
from chatsky.llm._langchain_imports import LLMResult, HumanMessage
from langchain_core.outputs.chat_generation import ChatGeneration

from chatsky import (
TRANSITIONS,
RESPONSE,
Expand Down Expand Up @@ -54,7 +58,7 @@ class MessageSchema(BaseModel):
history: list[str]

def __call__(self):
return {"history": self.history}
return self.model_dump()


@pytest.fixture
Expand Down Expand Up @@ -186,7 +190,30 @@ async def test_attachments(img, expected):
)
async def test_history(context, pipeline, hist, expected):
res = await LLMResponse(model_name="test_model", history=hist)(context)
assert res.text == expected
assert res == Message(expected, annotations={"__generated_by_model__": "test_model"})


async def test_context_to_history(context):
res = await context_to_history(
ctx=context, length=-1, filter_func=lambda *args: True, model_name="test_model", max_size=100
)
expected = [
HumanMessage(content=[{"type": "text", "text": "Request 0"}]),
AIMessage(content=[{"type": "text", "text": "Response 0"}]),
HumanMessage(content=[{"type": "text", "text": "Request 1"}]),
AIMessage(content=[{"type": "text", "text": "Response 1"}]),
HumanMessage(content=[{"type": "text", "text": "Request 2"}]),
AIMessage(content=[{"type": "text", "text": "Response 2"}]),
]
assert res == expected
res = await context_to_history(
ctx=context, length=1, filter_func=lambda *args: True, model_name="test_model", max_size=100
)
expected = [
HumanMessage(content=[{"type": "text", "text": "Request 2"}]),
AIMessage(content=[{"type": "text", "text": "Response 2"}]),
]
assert res == expected


def test_is_important_filter(filter_context):
Expand All @@ -205,8 +232,54 @@ def test_is_important_filter(filter_context):
def test_model_filter(filter_context):
filter_func = FromTheModel()
ctx = filter_context
# Test filtering important messages
# Test filtering messages from a certain model
assert filter_func(ctx, ctx.requests[1], ctx.responses[1], model_name="test_model")
assert not filter_func(ctx, ctx.requests[2], ctx.responses[2], model_name="test_model")
assert filter_func(ctx, ctx.requests[3], ctx.responses[3], model_name="test_model")
assert filter_func(ctx, ctx.requests[2], ctx.responses[3], model_name="test_model")


@pytest.fixture
def llmresult():
return LLMResult(
generations=[
[
ChatGeneration(
message=HumanMessage(content="this is a very IMPORTANT message"),
generation_info={
"logprobs": {
"content": [
{
"top_logprobs": [
{"token": "true", "logprob": 0.1},
{"token": "false", "logprob": 0.5},
]
}
]
}
},
)
]
]
)


async def test_base_method(llmresult):
c = Contains(pattern="")
assert await c.model_result_to_text(llmresult) == "this is a very IMPORTANT message"


async def test_contains_method(filter_context, llmresult):
ctx = filter_context
c = Contains(pattern="important")
assert await c(ctx, llmresult)
c = Contains(pattern="test")
assert not await c(ctx, llmresult)


async def test_logprob_method(filter_context, llmresult):
ctx = filter_context
c = LogProb(target_token="false", threshold=0.3)
assert await c(ctx, llmresult)
c = LogProb(target_token="true", threshold=0.3)
assert not await c(ctx, llmresult)

0 comments on commit 8e22b97

Please sign in to comment.