Skip to content

Commit

Permalink
test(LAB-3258): add unit and e2e tests for create_conversation with…
Browse files Browse the repository at this point in the history
… system prompt
  • Loading branch information
paulruelle committed Dec 3, 2024
1 parent 72e9f6a commit 8a3d722
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 11 deletions.
16 changes: 13 additions & 3 deletions tests/e2e/test_e2e_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# Shared fixtures for the e2e model tests.
MODEL_NAME = "E2E Test Model"
UPDATED_MODEL_NAME = "E2E Test Model Updated"
PROMPT = "Hello, world !"
# Fixed typos in the fixture text: "an helpful" -> "a helpful", "cutsom" -> "custom".
# Safe to change: the tests only forward this string, they never assert on its content.
SYSTEM_PROMPT = "You're a helpful AI assistant with custom instructions"

INTERFACE = {
"jobs": {
Expand Down Expand Up @@ -106,13 +107,22 @@ def get_project_model_by_id(models, model_id):
assert updated_project_model is not None
assert updated_project_model["configuration"] == updated_project_model_config_1

chat_items = kili.llm.create_conversation(project_id=project_id, prompt=PROMPT)
chat_items = kili.llm.create_conversation(project_id=project_id, initial_prompt=PROMPT)

assert len(chat_items) == 3
assert chat_items[0]["content"] == PROMPT
assert chat_items[0]["role"] == ChatItemRole.USER
assert chat_items[1]["role"] == ChatItemRole.ASSISTANT
assert chat_items[2]["role"] == ChatItemRole.ASSISTANT
assert chat_items[1]["role"] == chat_items[2]["role"] == ChatItemRole.ASSISTANT

chat_items = kili.llm.create_conversation(
project_id=project_id, initial_prompt=PROMPT, system_prompt=SYSTEM_PROMPT
)

assert len(chat_items) == 4
assert chat_items[0]["content"] == PROMPT
assert chat_items[0]["role"] == ChatItemRole.SYSTEM
assert chat_items[1]["role"] == ChatItemRole.USER
assert chat_items[2]["role"] == chat_items[3]["role"] == ChatItemRole.ASSISTANT

assets = kili.assets(project_id)
assert len(assets) == 1
Expand Down
80 changes: 72 additions & 8 deletions tests/unit/llm/test_create_conversation.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from kili.domain.llm import ChatItemRole
from kili.llm.presentation.client.llm import LlmClientMethods


def test_create_conversation(mocker):
mock_get_current_user = {"id": "user_id"}
mock_llm_asset = {"id": "asset_id", "latestLabel": {"id": "label_id"}}
mock_chat_item = {
"id": "chat_item_id",
"asset_id": "asset_id",
"label_id": "label_id",
"prompt": "prompt text",
}
mock_chat_item = [
{
"asset_id": "asset_id",
"label_id": "label_id",
"parent_id": None,
"prompt": "prompt text",
"role": ChatItemRole.USER,
}
]

kili_api_gateway = mocker.MagicMock()
kili_api_gateway.get_current_user.return_value = mock_get_current_user
Expand All @@ -18,7 +22,7 @@ def test_create_conversation(mocker):

kili_llm = LlmClientMethods(kili_api_gateway)

result = kili_llm.create_conversation(project_id="project_id", prompt="prompt text")
result = kili_llm.create_conversation(project_id="project_id", initial_prompt="prompt text")

assert result == mock_chat_item

Expand All @@ -27,5 +31,65 @@ def test_create_conversation(mocker):
project_id="project_id", author_id="user_id", label_type="PREDICTION"
)
kili_api_gateway.create_chat_item.assert_called_once_with(
asset_id="asset_id", label_id="label_id", prompt="prompt text"
asset_id="asset_id",
label_id="label_id",
prompt="prompt text",
parent_id=None,
role=ChatItemRole.USER,
)


def test_create_conversation_with_system_prompt(mocker):
    """create_conversation with a system prompt yields a system item then a user item.

    Checks that the gateway is asked for the current user once, that a single
    LLM asset is created, and that the user prompt item is linked to the
    system item via ``parent_id``.
    """
    current_user = {"id": "user_id"}
    llm_asset = {"id": "asset_id", "latestLabel": {"id": "label_id"}}
    system_items = [
        {
            "id": "system_id",
            "asset_id": "asset_id",
            "label_id": "label_id",
            "prompt": "system prompt text",
            "role": ChatItemRole.SYSTEM,
        }
    ]
    prompt_items = [
        {
            "asset_id": "asset_id",
            "label_id": "label_id",
            "parent_id": "system_id",
            "prompt": "user prompt text",
            "role": ChatItemRole.USER,
        }
    ]

    gateway = mocker.MagicMock()
    gateway.get_current_user.return_value = current_user
    gateway.create_llm_asset.return_value = llm_asset
    # First create_chat_item call returns the system item, second the user prompt item.
    gateway.create_chat_item.side_effect = [system_items, prompt_items]

    client = LlmClientMethods(gateway)
    result = client.create_conversation(
        project_id="project_id",
        initial_prompt="user prompt text",
        system_prompt="system prompt text",
    )

    # The conversation is the concatenation of both created chat items, in order.
    assert result == system_items + prompt_items

    gateway.get_current_user.assert_called_once_with(["id"])
    gateway.create_llm_asset.assert_called_once_with(
        project_id="project_id", author_id="user_id", label_type="PREDICTION"
    )
    gateway.create_chat_item.assert_any_call(
        asset_id="asset_id",
        label_id="label_id",
        prompt="system prompt text",
        role=ChatItemRole.SYSTEM,
    )
    gateway.create_chat_item.assert_any_call(
        asset_id="asset_id",
        label_id="label_id",
        prompt="user prompt text",
        role=ChatItemRole.USER,
        parent_id="system_id",
    )

0 comments on commit 8a3d722

Please sign in to comment.