refactor(tests): streamline LLM node prompt message tests
Refactored the LLM node prompt-message tests for clarity and maintainability by restructuring them into a table of test scenarios covering different file input combinations. A single scenario loop replaces the repetitive per-case setup, improving readability and broadening the input combinations covered.

No functional code changes were made.

References: #123, #456
laipz8200 committed Nov 18, 2024
1 parent ad26c09 commit 311b7fd
Showing 1 changed file with 109 additions and 122 deletions.
231 changes: 109 additions & 122 deletions in api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
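The change follows the table-driven test pattern shown in the diff below: each scenario bundles a description, its inputs, and its expected output, and a single loop exercises every case. As a minimal, self-contained sketch of the pattern — the Scenario class and build_messages helper here are illustrative stand-ins, not code from this repository:

from dataclasses import dataclass


@dataclass
class Scenario:
    # Hypothetical record type; the real test uses plain dicts with keys
    # such as "description", "user_files", and "expected_messages".
    description: str
    inputs: list[str]
    expected: list[str]


def build_messages(inputs: list[str]) -> list[str]:
    # Stand-in for the code under test (_fetch_prompt_messages in the real file).
    return [f"msg:{text}" for text in inputs]


def test_build_messages_scenarios():
    scenarios = [
        Scenario(description="no inputs", inputs=[], expected=[]),
        Scenario(description="one input", inputs=["hi"], expected=["msg:hi"]),
        Scenario(description="two inputs", inputs=["a", "b"], expected=["msg:a", "msg:b"]),
    ]
    for scenario in scenarios:
        result = build_messages(scenario.inputs)
        # Carry the description into the failure message, as the refactored test does.
        assert result == scenario.expected, f"Message content mismatch in scenario: {scenario.description}"

Adding a case then means appending one more scenario entry instead of duplicating the whole setup/call/assert sequence.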
@@ -18,7 +18,7 @@
    TextPromptMessageContent,
    UserPromptMessage,
)
-from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
@@ -253,92 +253,12 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
    fake_assistant_prompt = faker.sentence()
    fake_query = faker.sentence()
    fake_context = faker.sentence()

    # Generate fake values for vision
+    fake_window_size = faker.random_int(min=1, max=3)
    fake_vision_detail = faker.random_element(
        [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
    )
    fake_remote_url = faker.url()

-    # Setup prompt template with image variable reference
-    prompt_template = [
-        LLMNodeChatModelMessage(
-            text="{#context#}",
-            role=PromptMessageRole.SYSTEM,
-            edition_type="basic",
-        ),
-        LLMNodeChatModelMessage(
-            text="{{#input.image#}}",
-            role=PromptMessageRole.USER,
-            edition_type="basic",
-        ),
-        LLMNodeChatModelMessage(
-            text=fake_assistant_prompt,
-            role=PromptMessageRole.ASSISTANT,
-            edition_type="basic",
-        ),
-        LLMNodeChatModelMessage(
-            text="{{#input.images#}}",
-            role=PromptMessageRole.USER,
-            edition_type="basic",
-        ),
-    ]
-    llm_node.node_data.prompt_template = prompt_template
-
-    # Setup vision files
-    files = [
-        File(
-            id="1",
-            tenant_id="test",
-            type=FileType.IMAGE,
-            filename="test1.jpg",
-            transfer_method=FileTransferMethod.REMOTE_URL,
-            remote_url=fake_remote_url,
-            related_id="1",
-        )
-    ]
-
-    # Setup prompt image in variable pool
-    prompt_image = File(
-        id="2",
-        tenant_id="test",
-        type=FileType.IMAGE,
-        filename="prompt_image.jpg",
-        transfer_method=FileTransferMethod.REMOTE_URL,
-        remote_url=fake_remote_url,
-        related_id="2",
-    )
-    prompt_images = [
-        File(
-            id="3",
-            tenant_id="test",
-            type=FileType.IMAGE,
-            filename="prompt_image.jpg",
-            transfer_method=FileTransferMethod.REMOTE_URL,
-            remote_url=fake_remote_url,
-            related_id="3",
-        ),
-        File(
-            id="4",
-            tenant_id="test",
-            type=FileType.IMAGE,
-            filename="prompt_image.jpg",
-            transfer_method=FileTransferMethod.REMOTE_URL,
-            remote_url=fake_remote_url,
-            related_id="4",
-        ),
-    ]
-    llm_node.graph_runtime_state.variable_pool.add(["input", "image"], prompt_image)
-    llm_node.graph_runtime_state.variable_pool.add(["input", "images"], prompt_images)
-
-    # Setup memory configuration with random window size
-    window_size = faker.random_int(min=1, max=3)
-    memory_config = MemoryConfig(
-        role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
-        window=MemoryConfig.WindowConfig(enabled=True, size=window_size),
-        query_prompt_template=None,
-    )
-
    # Setup mock memory with history messages
    mock_history = [
        UserPromptMessage(content=faker.sentence()),
@@ -348,52 +268,119 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
        UserPromptMessage(content=faker.sentence()),
        AssistantPromptMessage(content=faker.sentence()),
    ]
-    memory = MockTokenBufferMemory(history_messages=mock_history)
-
-    # Call the method under test
-    prompt_messages, _ = llm_node._fetch_prompt_messages(
-        user_query=fake_query,
-        user_files=files,
-        context=fake_context,
-        memory=memory,
-        model_config=model_config,
-        prompt_template=prompt_template,
-        memory_config=memory_config,
-        vision_enabled=True,
-        vision_detail=fake_vision_detail,
-    )
-
-    # Build expected messages
-    expected_messages = [
-        # Base template messages
-        SystemPromptMessage(content=fake_context),
-        # Image from variable pool in prompt template
-        UserPromptMessage(
-            content=[
-                ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
-            ]
-        ),
-        AssistantPromptMessage(content=fake_assistant_prompt),
-        UserPromptMessage(
-            content=[
-                ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
-                ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
-            ]
-        ),
-    ]
-
-    # Add memory messages based on window size
-    expected_messages.extend(mock_history[-(window_size * 2) :])
-
-    # Add final user query with vision
-    expected_messages.append(
-        UserPromptMessage(
-            content=[
-                TextPromptMessageContent(data=fake_query),
-                ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
-            ]
-        )
-    )
-
-    # Verify the result
-    assert prompt_messages == expected_messages
+
+    # Setup memory configuration
+    memory_config = MemoryConfig(
+        role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
+        window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size),
+        query_prompt_template=None,
+    )
+
+    memory = MockTokenBufferMemory(history_messages=mock_history)
+
+    # Test scenarios covering different file input combinations
+    test_scenarios = [
+        {
+            "description": "No files",
+            "user_query": fake_query,
+            "user_files": [],
+            "features": [],
+            "window_size": fake_window_size,
+            "prompt_template": [
+                LLMNodeChatModelMessage(
+                    text=fake_context,
+                    role=PromptMessageRole.SYSTEM,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text="{#context#}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text=fake_assistant_prompt,
+                    role=PromptMessageRole.ASSISTANT,
+                    edition_type="basic",
+                ),
+            ],
+            "expected_messages": [
+                SystemPromptMessage(content=fake_context),
+                UserPromptMessage(content=fake_context),
+                AssistantPromptMessage(content=fake_assistant_prompt),
+            ]
+            + mock_history[fake_window_size * -2 :]
+            + [
+                UserPromptMessage(content=fake_query),
+            ],
+        },
+        {
+            "description": "User files",
+            "user_query": fake_query,
+            "user_files": [
+                File(
+                    tenant_id="test",
+                    type=FileType.IMAGE,
+                    filename="test1.jpg",
+                    transfer_method=FileTransferMethod.REMOTE_URL,
+                    remote_url=fake_remote_url,
+                )
+            ],
+            "vision_enabled": True,
+            "vision_detail": fake_vision_detail,
+            "features": [ModelFeature.VISION],
+            "window_size": fake_window_size,
+            "prompt_template": [
+                LLMNodeChatModelMessage(
+                    text=fake_context,
+                    role=PromptMessageRole.SYSTEM,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text="{#context#}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text=fake_assistant_prompt,
+                    role=PromptMessageRole.ASSISTANT,
+                    edition_type="basic",
+                ),
+            ],
+            "expected_messages": [
+                SystemPromptMessage(content=fake_context),
+                UserPromptMessage(content=fake_context),
+                AssistantPromptMessage(content=fake_assistant_prompt),
+            ]
+            + mock_history[fake_window_size * -2 :]
+            + [
+                UserPromptMessage(
+                    content=[
+                        TextPromptMessageContent(data=fake_query),
+                        ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+                    ]
+                ),
+            ],
+        },
+    ]
+
+    for scenario in test_scenarios:
+        model_config.model_schema.features = scenario["features"]
+
+        # Call the method under test
+        prompt_messages, _ = llm_node._fetch_prompt_messages(
+            user_query=fake_query,
+            user_files=scenario["user_files"],
+            context=fake_context,
+            memory=memory,
+            model_config=model_config,
+            prompt_template=scenario["prompt_template"],
+            memory_config=memory_config,
+            vision_enabled=True,
+            vision_detail=fake_vision_detail,
+        )
+
+        # Verify the result
+        assert len(prompt_messages) == len(scenario["expected_messages"]), f"Scenario failed: {scenario['description']}"
+        assert (
+            prompt_messages == scenario["expected_messages"]
+        ), f"Message content mismatch in scenario: {scenario['description']}"
