feat: Allow using file variables directly in the LLM node and support more file types. (#10679)

Co-authored-by: Joel <iamjoel007@gmail.com>
laipz8200 and iamjoel authored Nov 22, 2024
1 parent 535c72c commit c5f7d65
Showing 36 changed files with 1,036 additions and 268 deletions.
1 change: 0 additions & 1 deletion api/configs/app_config.py

@@ -27,7 +27,6 @@ class DifyConfig(
         # read from dotenv format config file
         env_file=".env",
         env_file_encoding="utf-8",
-        frozen=True,
         # ignore extra attributes
         extra="ignore",
     )
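
Removing frozen=True makes the DifyConfig settings object mutable again; in pydantic-settings, assigning to a field of a frozen model raises a ValidationError. A minimal standalone sketch of the behavior this deletion changes (toy settings classes, not Dify's actual config):

from pydantic import ValidationError
from pydantic_settings import BaseSettings, SettingsConfigDict

class FrozenSettings(BaseSettings):
    model_config = SettingsConfigDict(frozen=True)
    debug: bool = False

class MutableSettings(BaseSettings):
    model_config = SettingsConfigDict()  # no frozen=True, as after this commit
    debug: bool = False

frozen = FrozenSettings()
try:
    frozen.debug = True  # rejected while the model is frozen
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # "frozen_instance"

mutable = MutableSettings()
mutable.debug = True  # allowed once frozen=True is removed
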
42 changes: 19 additions & 23 deletions api/core/app/app_config/easy_ui_based_app/model_config/converter.py

@@ -11,7 +11,7 @@

 class ModelConfigConverter:
     @classmethod
-    def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
+    def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
         """
         Convert app model config dict to entity.
         :param app_config: app config
@@ -38,27 +38,23 @@ def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
         )

         if model_credentials is None:
-            if not skip_check:
-                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-            else:
-                model_credentials = {}
-
-        if not skip_check:
-            # check model
-            provider_model = provider_model_bundle.configuration.get_provider_model(
-                model=model_config.model, model_type=ModelType.LLM
-            )
-
-            if provider_model is None:
-                model_name = model_config.model
-                raise ValueError(f"Model {model_name} not exist.")
-
-            if provider_model.status == ModelStatus.NO_CONFIGURE:
-                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-            elif provider_model.status == ModelStatus.NO_PERMISSION:
-                raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
-            elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
-                raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+
+        # check model
+        provider_model = provider_model_bundle.configuration.get_provider_model(
+            model=model_config.model, model_type=ModelType.LLM
+        )
+
+        if provider_model is None:
+            model_name = model_config.model
+            raise ValueError(f"Model {model_name} not exist.")
+
+        if provider_model.status == ModelStatus.NO_CONFIGURE:
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+        elif provider_model.status == ModelStatus.NO_PERMISSION:
+            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
+        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

         # model config
         completion_params = model_config.parameters
@@ -76,7 +72,7 @@ def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:

         model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)

-        if not skip_check and not model_schema:
+        if not model_schema:
             raise ValueError(f"Model {model_name} not exist.")

         return ModelConfigWithCredentialsEntity(
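
With the skip_check escape hatch removed, convert always validates credentials, model existence, and provider status. A hedged sketch of the call-site shape after this change (app_config is assumed to be an EasyUIBasedAppConfig built elsewhere; that setup is not part of this diff):

# Hypothetical call site.
model_config_entity = ModelConfigConverter.convert(app_config)
# Callers that previously passed skip_check=True now get the full checks:
# missing credentials raise ProviderTokenNotInitError, an unknown model
# raises ValueError, and exhausted quota raises QuotaExceededError.
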
5 changes: 4 additions & 1 deletion api/core/app/task_pipeline/workflow_cycle_manage.py

@@ -217,9 +217,12 @@ def _handle_workflow_run_failed(
         ).total_seconds()
         db.session.commit()

-        db.session.refresh(workflow_run)
         db.session.close()

+        with Session(db.engine, expire_on_commit=False) as session:
+            session.add(workflow_run)
+            session.refresh(workflow_run)
+
         if trace_manager:
             trace_manager.add_trace_task(
                 TraceTask(
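
The deleted db.session.refresh ran against a session that was about to be closed; the replacement re-attaches workflow_run to a short-lived session and refreshes it there, and expire_on_commit=False keeps the loaded attributes readable after that session exits. A minimal sketch of the same SQLAlchemy 2.0 pattern with a toy model (not Dify's schema):

from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Run(Base):  # toy stand-in for WorkflowRun
    __tablename__ = "runs"
    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(default="running")

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    run = Run(status="failed")
    session.add(run)
    session.commit()
# `run` is now detached from the closed session.

with Session(engine, expire_on_commit=False) as session:
    session.add(run)      # re-attach the detached instance
    session.refresh(run)  # reload its attributes from the database

print(run.status)  # still readable after the session closes: "failed"
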
69 changes: 25 additions & 44 deletions api/core/file/file_manager.py

@@ -3,7 +3,12 @@
 from configs import dify_config
 from core.file import file_repository
 from core.helper import ssrf_proxy
-from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent
+from core.model_runtime.entities import (
+    AudioPromptMessageContent,
+    DocumentPromptMessageContent,
+    ImagePromptMessageContent,
+    VideoPromptMessageContent,
+)
 from extensions.ext_database import db
 from extensions.ext_storage import storage

@@ -29,43 +34,25 @@ def get_attr(*, file: File, attr: FileAttribute):
             return file.remote_url
         case FileAttribute.EXTENSION:
             return file.extension
-        case _:
-            raise ValueError(f"Invalid file attribute: {attr}")


 def to_prompt_message_content(
     f: File,
     /,
     *,
-    image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
+    image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
 ):
-    """
-    Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.
-
-    This function takes a File object and converts it to an appropriate PromptMessageContent
-    object, which can be used as a prompt for image or audio-based AI models.
-
-    Args:
-        f (File): The File object to convert.
-        detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
-            If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.
-
-    Returns:
-        Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level
-
-    Raises:
-        ValueError: If the file type is not supported or if required data is missing.
-    """
     match f.type:
         case FileType.IMAGE:
+            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
            if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
                 data = _to_url(f)
             else:
                 data = _to_base64_data_string(f)

             return ImagePromptMessageContent(data=data, detail=image_detail_config)
         case FileType.AUDIO:
-            encoded_string = _file_to_encoded_string(f)
+            encoded_string = _get_encoded_string(f)
             if f.extension is None:
                 raise ValueError("Missing file extension")
             return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
@@ -74,9 +61,20 @@ def to_prompt_message_content(
                 data = _to_url(f)
             else:
                 data = _to_base64_data_string(f)
+            if f.extension is None:
+                raise ValueError("Missing file extension")
             return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
+        case FileType.DOCUMENT:
+            data = _get_encoded_string(f)
+            if f.mime_type is None:
+                raise ValueError("Missing file mime_type")
+            return DocumentPromptMessageContent(
+                encode_format="base64",
+                mime_type=f.mime_type,
+                data=data,
+            )
         case _:
-            raise ValueError("file type f.type is not supported")
+            raise ValueError(f"file type {f.type} is not supported")


 def download(f: File, /):
@@ -118,40 +116,23 @@ def _get_encoded_string(f: File, /):
         case FileTransferMethod.REMOTE_URL:
             response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
             response.raise_for_status()
-            content = response.content
-            encoded_string = base64.b64encode(content).decode("utf-8")
-            return encoded_string
+            data = response.content
         case FileTransferMethod.LOCAL_FILE:
             upload_file = file_repository.get_upload_file(session=db.session(), file=f)
             data = _download_file_content(upload_file.key)
-            encoded_string = base64.b64encode(data).decode("utf-8")
-            return encoded_string
         case FileTransferMethod.TOOL_FILE:
             tool_file = file_repository.get_tool_file(session=db.session(), file=f)
             data = _download_file_content(tool_file.file_key)
-            encoded_string = base64.b64encode(data).decode("utf-8")
-            return encoded_string
         case _:
             raise ValueError(f"Unsupported transfer method: {f.transfer_method}")

+    encoded_string = base64.b64encode(data).decode("utf-8")
+    return encoded_string


 def _to_base64_data_string(f: File, /):
     encoded_string = _get_encoded_string(f)
     return f"data:{f.mime_type};base64,{encoded_string}"


-def _file_to_encoded_string(f: File, /):
-    match f.type:
-        case FileType.IMAGE:
-            return _to_base64_data_string(f)
-        case FileType.VIDEO:
-            return _to_base64_data_string(f)
-        case FileType.AUDIO:
-            return _get_encoded_string(f)
-        case _:
-            raise ValueError(f"file type {f.type} is not supported")
-
-
 def _to_url(f: File, /):
     if f.transfer_method == FileTransferMethod.REMOTE_URL:
         if f.remote_url is None:
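
For FileType.DOCUMENT, the file body is downloaded, base64-encoded by _get_encoded_string, and wrapped in the new DocumentPromptMessageContent. A hedged sketch of that encode-and-wrap step in isolation (the raw bytes stand in for the downloaded file; the import path and field names come straight from this diff):

import base64

from core.model_runtime.entities import DocumentPromptMessageContent

pdf_bytes = b"%PDF-1.4 minimal example"  # stand-in for _download_file_content(...)

content = DocumentPromptMessageContent(
    encode_format="base64",  # the only value the Literal["base64"] field accepts
    mime_type="application/pdf",
    data=base64.b64encode(pdf_bytes).decode("utf-8"),
)
print(content.mime_type, len(content.data))
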
3 changes: 2 additions & 1 deletion api/core/memory/token_buffer_memory.py

@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from typing import Optional

 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -27,7 +28,7 @@ def __init__(self, conversation: Conversation, model_instance: ModelInstance) ->

     def get_history_prompt_messages(
         self, max_token_limit: int = 2000, message_limit: Optional[int] = None
-    ) -> list[PromptMessage]:
+    ) -> Sequence[PromptMessage]:
         """
         Get history prompt messages.
         :param max_token_limit: max token limit
4 changes: 2 additions & 2 deletions api/core/model_manager.py

@@ -100,10 +100,10 @@ def _get_load_balancing_manager(

     def invoke_llm(
         self,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: Optional[dict] = None,
         tools: Sequence[PromptMessageTool] | None = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
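
Widening list[...] parameters to Sequence[...] lets callers hand invoke_llm a tuple, or any other read-only sequence, without first copying into a list. A standalone sketch of the typing difference (illustration only, not Dify code):

from collections.abc import Sequence

def join_stop(stop: Sequence[str] | None = None) -> str:
    # Sequence[str] accepts a list, a tuple, or any read-only sequence of str.
    return " | ".join(stop or ())

print(join_stop(["\n\n", "Human:"]))  # a list type-checks
print(join_stop(("\n\n", "Human:")))  # and so does a tuple
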
9 changes: 5 additions & 4 deletions api/core/model_runtime/callbacks/base_callback.py

@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from typing import Optional

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@@ -31,7 +32,7 @@ def on_before_invoke(
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
@@ -60,7 +61,7 @@ def on_new_chunk(
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ):
@@ -90,7 +91,7 @@ def on_after_invoke(
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
@@ -120,7 +121,7 @@ def on_invoke_error(
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
2 changes: 2 additions & 0 deletions api/core/model_runtime/entities/__init__.py

@@ -2,6 +2,7 @@
 from .message_entities import (
     AssistantPromptMessage,
     AudioPromptMessageContent,
+    DocumentPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContent,
@@ -37,4 +38,5 @@
     "LLMResultChunk",
     "LLMResultChunkDelta",
     "AudioPromptMessageContent",
+    "DocumentPromptMessageContent",
 ]
13 changes: 11 additions & 2 deletions api/core/model_runtime/entities/message_entities.py

@@ -1,6 +1,7 @@
 from abc import ABC
+from collections.abc import Sequence
 from enum import Enum
-from typing import Optional
+from typing import Literal, Optional

 from pydantic import BaseModel, Field, field_validator

@@ -57,6 +58,7 @@ class PromptMessageContentType(Enum):
     IMAGE = "image"
     AUDIO = "audio"
     VIDEO = "video"
+    DOCUMENT = "document"


 class PromptMessageContent(BaseModel):
@@ -101,13 +103,20 @@ class DETAIL(str, Enum):
     detail: DETAIL = DETAIL.LOW


+class DocumentPromptMessageContent(PromptMessageContent):
+    type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
+    encode_format: Literal["base64"]
+    mime_type: str
+    data: str
+
+
 class PromptMessage(ABC, BaseModel):
     """
     Model class for prompt message.
     """

     role: PromptMessageRole
-    content: Optional[str | list[PromptMessageContent]] = None
+    content: Optional[str | Sequence[PromptMessageContent]] = None
     name: Optional[str] = None

     def is_empty(self) -> bool:
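
The encode_format field is typed Literal["base64"], so pydantic rejects any other value at validation time. A minimal standalone mirror of that constraint (re-declared here rather than imported, to keep the sketch self-contained):

from typing import Literal

from pydantic import BaseModel, ValidationError

class DocContent(BaseModel):  # mirrors DocumentPromptMessageContent's fields
    encode_format: Literal["base64"]
    mime_type: str
    data: str

DocContent(encode_format="base64", mime_type="text/plain", data="aGVsbG8=")  # ok

try:
    DocContent(encode_format="hex", mime_type="text/plain", data="68656c6c6f")
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # "literal_error"
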
3 changes: 3 additions & 0 deletions api/core/model_runtime/entities/model_entities.py

@@ -87,6 +87,9 @@ class ModelFeature(Enum):
     AGENT_THOUGHT = "agent-thought"
     VISION = "vision"
     STREAM_TOOL_CALL = "stream-tool-call"
+    DOCUMENT = "document"
+    VIDEO = "video"
+    AUDIO = "audio"


 class DefaultParameterName(str, Enum):
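
The new DOCUMENT, VIDEO, and AUDIO members let a model's schema advertise which multimodal inputs it accepts, presumably so the LLM node can gate which file contents it attaches. A toy sketch of that kind of feature check (hypothetical helper; the real dispatch lives elsewhere in the codebase):

from enum import Enum

class ModelFeature(Enum):  # trimmed mirror of the diff's enum
    VISION = "vision"
    DOCUMENT = "document"
    VIDEO = "video"
    AUDIO = "audio"

def can_attach(features: set[ModelFeature], needed: ModelFeature) -> bool:
    # Hypothetical gate: only send content types the model declares support for.
    return needed in features

print(can_attach({ModelFeature.VISION, ModelFeature.DOCUMENT}, ModelFeature.DOCUMENT))  # True
print(can_attach({ModelFeature.VISION}, ModelFeature.AUDIO))  # False
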
(The remaining 26 of the 36 changed files are not shown here.)