
feat: Allow using file variables directly in the LLM node and support more file types. #10679

Merged · 36 commits · Nov 22, 2024

Commits
3c33c39
chore(deps): add faker
laipz8200 Nov 13, 2024
c8330e0
refactor(converter): simplify model credentials validation logic
laipz8200 Nov 13, 2024
61ea2dd
refactor: update stop parameter type to use Sequence instead of list
laipz8200 Nov 13, 2024
3687ea6
refactor: update jinja2_variables and prompt_config to use Sequence a…
laipz8200 Nov 13, 2024
223e03a
feat(errors): add new error classes for unsupported prompt types and …
laipz8200 Nov 13, 2024
bd60d0f
fix(tests): update Azure Rerank Model usage and clean imports
laipz8200 Nov 13, 2024
37e0a38
refactor(prompt): enhance type flexibility for prompt messages
laipz8200 Nov 14, 2024
9819825
refactor(model_runtime): use Sequence for content in PromptMessage
laipz8200 Nov 14, 2024
062c495
chore(config): remove unnecessary 'frozen' parameter for test
laipz8200 Nov 14, 2024
37b1347
fix(dependencies): update Faker version constraint
laipz8200 Nov 14, 2024
a018002
refactor(memory): use Sequence instead of list for prompt messages
laipz8200 Nov 14, 2024
6810529
refactor(model_manager): update parameter type for flexibility
laipz8200 Nov 14, 2024
070dc2d
Remove unnecessary data from log and text properties
laipz8200 Nov 14, 2024
fb506be
feat(llm_node): allow to use image file directly in the prompt.
laipz8200 Nov 14, 2024
651f584
Simplify test setup in LLM node tests
laipz8200 Nov 14, 2024
cd0a8ea
refactor(tests): streamline LLM node prompt message tests
laipz8200 Nov 14, 2024
b1a60bf
feat(tests): refactor LLMNode tests for clarity
laipz8200 Nov 14, 2024
8b1b81b
fix(node): handle empty text segments gracefully
laipz8200 Nov 14, 2024
1bfdbaf
feat: enhance image handling in prompt processing
laipz8200 Nov 14, 2024
ad9152f
fix: ensure workflow run persistence before refresh
laipz8200 Nov 14, 2024
800d64c
feat: add support for document, video, and audio content
laipz8200 Nov 14, 2024
7876d64
fix(file-manager): enforce file extension presence
laipz8200 Nov 14, 2024
f4bdff1
feat(config-prompt): add support for file variables
laipz8200 Nov 14, 2024
009c7c7
refactor(node.py): streamline template rendering
laipz8200 Nov 14, 2024
bbcf184
refactor(core): decouple LLMNode prompt handling
laipz8200 Nov 14, 2024
9e23313
feat(llm-panel): refine variable filtering logic
laipz8200 Nov 14, 2024
8039511
fix(api/core/app/task_pipeline/workflow_cycle_manage.py) workflow ses…
laipz8200 Nov 18, 2024
f83b775
feat(entities): add document prompt message content
laipz8200 Nov 18, 2024
313454e
feat(api): add document support in prompt message content
laipz8200 Nov 18, 2024
72258db
feat(api): support document file type in message handling
laipz8200 Nov 18, 2024
9692d57
feat(llm): add document support for message prompts
laipz8200 Nov 18, 2024
ed00f7b
fix: remove redundant exception message
laipz8200 Nov 18, 2024
02ec334
feat: code editor and code input can not insert file type vars
iamjoel Nov 19, 2024
5882cdc
chore: jinja import not choose file
iamjoel Nov 19, 2024
3528b2d
chore: use query not support file var
iamjoel Nov 19, 2024
0619e9a
chore: use query prompt support file type var
iamjoel Nov 19, 2024
1 change: 0 additions & 1 deletion api/configs/app_config.py
@@ -27,7 +27,6 @@ class DifyConfig(
         # read from dotenv format config file
         env_file=".env",
         env_file_encoding="utf-8",
-        frozen=True,
         # ignore extra attributes
         extra="ignore",
     )
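Removing `frozen=True` makes `DifyConfig` mutable again, which the test suite appears to rely on. A minimal sketch of the kind of override this enables (hypothetical test, assuming pytest's `monkeypatch`):

```python
from configs import dify_config


def test_send_images_as_urls(monkeypatch):
    # With frozen=True removed, a test can patch a setting in place
    # instead of constructing a whole new DifyConfig instance.
    monkeypatch.setattr(dify_config, "MULTIMODAL_SEND_IMAGE_FORMAT", "url")
    assert dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url"
```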
42 changes: 19 additions & 23 deletions api/core/app/app_config/easy_ui_based_app/model_config/converter.py
@@ -11,7 +11,7 @@

 class ModelConfigConverter:
     @classmethod
-    def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
+    def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
         """
         Convert app model config dict to entity.
         :param app_config: app config
@@ -38,27 +38,23 @@ def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) ->
         )

         if model_credentials is None:
-            if not skip_check:
-                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-            else:
-                model_credentials = {}
-
-        if not skip_check:
-            # check model
-            provider_model = provider_model_bundle.configuration.get_provider_model(
-                model=model_config.model, model_type=ModelType.LLM
-            )
-
-            if provider_model is None:
-                model_name = model_config.model
-                raise ValueError(f"Model {model_name} not exist.")
-
-            if provider_model.status == ModelStatus.NO_CONFIGURE:
-                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-            elif provider_model.status == ModelStatus.NO_PERMISSION:
-                raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
-            elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
-                raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+
+        # check model
+        provider_model = provider_model_bundle.configuration.get_provider_model(
+            model=model_config.model, model_type=ModelType.LLM
+        )
+
+        if provider_model is None:
+            model_name = model_config.model
+            raise ValueError(f"Model {model_name} not exist.")
+
+        if provider_model.status == ModelStatus.NO_CONFIGURE:
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+        elif provider_model.status == ModelStatus.NO_PERMISSION:
+            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
+        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

         # model config
         completion_params = model_config.parameters
@@ -76,7 +72,7 @@ def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) ->

         model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)

-        if not skip_check and not model_schema:
+        if not model_schema:
             raise ValueError(f"Model {model_name} not exist.")

         return ModelConfigWithCredentialsEntity(
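With `skip_check` gone, `convert()` validates unconditionally. A sketch of the resulting contract (the caller code is illustrative; the exception name comes from this file's existing imports):

```python
# Illustrative caller: misconfigured credentials now always surface as an
# exception instead of being silently skipped with skip_check=True.
try:
    model_config = ModelConfigConverter.convert(app_config)
except ProviderTokenNotInitError:
    # e.g. prompt the workspace owner to configure the provider first
    raise
```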
5 changes: 4 additions & 1 deletion api/core/app/task_pipeline/workflow_cycle_manage.py
@@ -217,9 +217,12 @@ def _handle_workflow_run_failed(
         ).total_seconds()
         db.session.commit()

-        db.session.refresh(workflow_run)
         db.session.close()

+        with Session(db.engine, expire_on_commit=False) as session:
+            session.add(workflow_run)
+            session.refresh(workflow_run)
+
         if trace_manager:
             trace_manager.add_trace_task(
                 TraceTask(
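The failing `db.session.refresh` on a possibly-detached `workflow_run` is replaced by re-attaching it to a short-lived session. A generic sketch of the pattern in plain SQLAlchemy (engine and instance here are stand-ins, not Dify code):

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

engine = create_engine("sqlite:///:memory:")  # stand-in engine for illustration


def reload_detached(instance):
    # expire_on_commit=False keeps attribute values readable after the
    # session exits, so the refreshed instance is safe to use downstream.
    with Session(engine, expire_on_commit=False) as session:
        session.add(instance)      # re-attach the detached instance
        session.refresh(instance)  # re-read current column values from the DB
    return instance
```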
69 changes: 25 additions & 44 deletions api/core/file/file_manager.py
@@ -3,7 +3,12 @@
 from configs import dify_config
 from core.file import file_repository
 from core.helper import ssrf_proxy
-from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent
+from core.model_runtime.entities import (
+    AudioPromptMessageContent,
+    DocumentPromptMessageContent,
+    ImagePromptMessageContent,
+    VideoPromptMessageContent,
+)
 from extensions.ext_database import db
 from extensions.ext_storage import storage

@@ -29,43 +34,25 @@ def get_attr(*, file: File, attr: FileAttribute):
             return file.remote_url
         case FileAttribute.EXTENSION:
             return file.extension
-        case _:
-            raise ValueError(f"Invalid file attribute: {attr}")


 def to_prompt_message_content(
     f: File,
     /,
     *,
-    image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
+    image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
 ):
-    """
-    Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.
-
-    This function takes a File object and converts it to an appropriate PromptMessageContent
-    object, which can be used as a prompt for image or audio-based AI models.
-
-    Args:
-        f (File): The File object to convert.
-        detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
-            If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.
-
-    Returns:
-        Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level
-
-    Raises:
-        ValueError: If the file type is not supported or if required data is missing.
-    """
     match f.type:
         case FileType.IMAGE:
+            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
             if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
                 data = _to_url(f)
             else:
                 data = _to_base64_data_string(f)

             return ImagePromptMessageContent(data=data, detail=image_detail_config)
         case FileType.AUDIO:
-            encoded_string = _file_to_encoded_string(f)
+            encoded_string = _get_encoded_string(f)
             if f.extension is None:
                 raise ValueError("Missing file extension")
             return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
@@ -74,9 +61,20 @@ def to_prompt_message_content(
                 data = _to_url(f)
             else:
                 data = _to_base64_data_string(f)
+            if f.extension is None:
+                raise ValueError("Missing file extension")
             return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
+        case FileType.DOCUMENT:
+            data = _get_encoded_string(f)
+            if f.mime_type is None:
+                raise ValueError("Missing file mime_type")
+            return DocumentPromptMessageContent(
+                encode_format="base64",
+                mime_type=f.mime_type,
+                data=data,
+            )
         case _:
-            raise ValueError("file type f.type is not supported")
+            raise ValueError(f"file type {f.type} is not supported")


 def download(f: File, /):
@@ -118,40 +116,23 @@ def _get_encoded_string(f: File, /):
         case FileTransferMethod.REMOTE_URL:
             response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
             response.raise_for_status()
-            content = response.content
-            encoded_string = base64.b64encode(content).decode("utf-8")
-            return encoded_string
+            data = response.content
         case FileTransferMethod.LOCAL_FILE:
             upload_file = file_repository.get_upload_file(session=db.session(), file=f)
             data = _download_file_content(upload_file.key)
-            encoded_string = base64.b64encode(data).decode("utf-8")
-            return encoded_string
         case FileTransferMethod.TOOL_FILE:
             tool_file = file_repository.get_tool_file(session=db.session(), file=f)
             data = _download_file_content(tool_file.file_key)
-            encoded_string = base64.b64encode(data).decode("utf-8")
-            return encoded_string
         case _:
             raise ValueError(f"Unsupported transfer method: {f.transfer_method}")

+    encoded_string = base64.b64encode(data).decode("utf-8")
+    return encoded_string


 def _to_base64_data_string(f: File, /):
     encoded_string = _get_encoded_string(f)
     return f"data:{f.mime_type};base64,{encoded_string}"


-def _file_to_encoded_string(f: File, /):
-    match f.type:
-        case FileType.IMAGE:
-            return _to_base64_data_string(f)
-        case FileType.VIDEO:
-            return _to_base64_data_string(f)
-        case FileType.AUDIO:
-            return _get_encoded_string(f)
-        case _:
-            raise ValueError(f"file type {f.type} is not supported")


 def _to_url(f: File, /):
     if f.transfer_method == FileTransferMethod.REMOTE_URL:
         if f.remote_url is None:
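Taken together, `to_prompt_message_content` now covers images, audio, video, and documents. A sketch of the new document branch from the caller's side (the `File` value and its mime type are hypothetical):

```python
# Hypothetical document file flowing through the new DOCUMENT branch.
content = to_prompt_message_content(document_file)   # File with type=FileType.DOCUMENT
assert content.encode_format == "base64"             # documents are always base64-encoded
assert content.mime_type == document_file.mime_type  # e.g. "application/pdf"
```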
3 changes: 2 additions & 1 deletion api/core/memory/token_buffer_memory.py
@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from typing import Optional

 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -27,7 +28,7 @@ def __init__(self, conversation: Conversation, model_instance: ModelInstance) ->

     def get_history_prompt_messages(
         self, max_token_limit: int = 2000, message_limit: Optional[int] = None
-    ) -> list[PromptMessage]:
+    ) -> Sequence[PromptMessage]:
         """
         Get history prompt messages.
         :param max_token_limit: max token limit
4 changes: 2 additions & 2 deletions api/core/model_manager.py
@@ -100,10 +100,10 @@ def _get_load_balancing_manager(

     def invoke_llm(
         self,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: Optional[dict] = None,
         tools: Sequence[PromptMessageTool] | None = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
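Accepting `Sequence` instead of `list` only widens the contract: callers can now pass tuples or other read-only sequences without a copy or a type error. A minimal illustration (generic Python, not Dify code):

```python
from collections.abc import Sequence


def render(stop: Sequence[str] | None = None) -> str:
    # A Sequence parameter accepts list, tuple, or any ordered read-only
    # collection; the function promises not to mutate it.
    return ", ".join(stop or ())


render(["\n\n"])          # lists still work
render(("User:", "AI:"))  # tuples now type-check as well
```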
9 changes: 5 additions & 4 deletions api/core/model_runtime/callbacks/base_callback.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from typing import Optional

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@@ -31,7 +32,7 @@ def on_before_invoke(
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
@@ -60,7 +61,7 @@ def on_new_chunk(
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ):
@@ -90,7 +91,7 @@ def on_after_invoke(
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
@@ -120,7 +121,7 @@ def on_invoke_error(
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
2 changes: 2 additions & 0 deletions api/core/model_runtime/entities/__init__.py
@@ -2,6 +2,7 @@
 from .message_entities import (
     AssistantPromptMessage,
     AudioPromptMessageContent,
+    DocumentPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContent,
@@ -37,4 +38,5 @@
     "LLMResultChunk",
     "LLMResultChunkDelta",
     "AudioPromptMessageContent",
+    "DocumentPromptMessageContent",
 ]
13 changes: 11 additions & 2 deletions api/core/model_runtime/entities/message_entities.py
@@ -1,6 +1,7 @@
 from abc import ABC
+from collections.abc import Sequence
 from enum import Enum
-from typing import Optional
+from typing import Literal, Optional

 from pydantic import BaseModel, Field, field_validator

@@ -57,6 +58,7 @@ class PromptMessageContentType(Enum):
     IMAGE = "image"
     AUDIO = "audio"
     VIDEO = "video"
+    DOCUMENT = "document"


 class PromptMessageContent(BaseModel):
@@ -101,13 +103,20 @@ class DETAIL(str, Enum):
     detail: DETAIL = DETAIL.LOW


+class DocumentPromptMessageContent(PromptMessageContent):
+    type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
+    encode_format: Literal["base64"]
+    mime_type: str
+    data: str
+
+
 class PromptMessage(ABC, BaseModel):
     """
     Model class for prompt message.
     """

     role: PromptMessageRole
-    content: Optional[str | list[PromptMessageContent]] = None
+    content: Optional[str | Sequence[PromptMessageContent]] = None
     name: Optional[str] = None

     def is_empty(self) -> bool:
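A sketch of the new content type in use; `UserPromptMessage` and `TextPromptMessageContent` are assumed to be the existing entities in this module, and the base64 payload is a placeholder:

```python
# Placeholder base64 payload; a real caller would get this from file_manager.
doc = DocumentPromptMessageContent(
    encode_format="base64",
    mime_type="application/pdf",
    data="JVBERi0xLjcK...",
)
message = UserPromptMessage(
    content=[
        TextPromptMessageContent(data="Summarize the attached file."),
        doc,
    ]
)
```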
3 changes: 3 additions & 0 deletions api/core/model_runtime/entities/model_entities.py
@@ -87,6 +87,9 @@ class ModelFeature(Enum):
     AGENT_THOUGHT = "agent-thought"
     VISION = "vision"
     STREAM_TOOL_CALL = "stream-tool-call"
+    DOCUMENT = "document"
+    VIDEO = "video"
+    AUDIO = "audio"


 class DefaultParameterName(str, Enum):
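These enum members let a model declare which file modalities it accepts. A hypothetical gate on the caller side might look like this (the `features` attribute on the model schema is an assumption, not shown in this diff):

```python
# Hypothetical feature check before attaching document content to a prompt.
features = model_schema.features or []
if ModelFeature.DOCUMENT in features:
    contents.append(to_prompt_message_content(document_file))
else:
    raise ValueError("model does not accept document input")
```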