Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: the consistency of MultiModalPromptMessageContent #11721

Merged
merged 6 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,7 @@ UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50

# Model configuration
MULTIMODAL_SEND_IMAGE_FORMAT=base64
MULTIMODAL_SEND_VIDEO_FORMAT=base64
MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024

Expand Down
9 changes: 2 additions & 7 deletions api/configs/feature/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,13 +666,8 @@ class IndexingConfig(BaseSettings):


class VisionFormatConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
default="base64",
)

MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64",
MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending files in multimodal contexts ('base64' or 'url'), default is base64",
default="base64",
)

Expand Down
57 changes: 25 additions & 32 deletions api/core/file/file_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,33 +42,31 @@ def to_prompt_message_content(
*,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
):
match f.type:
case FileType.IMAGE:
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)

return ImagePromptMessageContent(data=data, detail=image_detail_config, format=f.extension.lstrip("."))
case FileType.AUDIO:
data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.VIDEO:
if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.DOCUMENT:
data = _to_base64_data_string(f)
return DocumentPromptMessageContent(encode_format="base64", data=data, format=f.extension.lstrip("."))
case _:
raise ValueError(f"file type {f.type} is not supported")
if f.extension is None:
raise ValueError("Missing file extension")
if f.mime_type is None:
raise ValueError("Missing file mime_type")

params = {
"b64data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
"format": f.extension.removeprefix("."),
"mime_type": f.mime_type,
}
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

prompt_class_map = {
FileType.IMAGE: ImagePromptMessageContent,
FileType.AUDIO: AudioPromptMessageContent,
FileType.VIDEO: VideoPromptMessageContent,
FileType.DOCUMENT: DocumentPromptMessageContent,
}

try:
return prompt_class_map[f.type](**params)
except KeyError:
raise ValueError(f"file type {f.type} is not supported")


def download(f: File, /):
Expand Down Expand Up @@ -122,11 +120,6 @@ def _get_encoded_string(f: File, /):
return encoded_string


def _to_base64_data_string(f: File, /):
encoded_string = _get_encoded_string(f)
return f"data:{f.mime_type};base64,{encoded_string}"


def _to_url(f: File, /):
if f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None:
Expand Down
39 changes: 24 additions & 15 deletions api/core/model_runtime/entities/message_entities.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from abc import ABC
from collections.abc import Sequence
from enum import Enum, StrEnum
from typing import Literal, Optional
from typing import Optional

from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, computed_field, field_validator


class PromptMessageRole(Enum):
Expand Down Expand Up @@ -67,7 +67,6 @@ class PromptMessageContent(BaseModel):
"""

type: PromptMessageContentType
data: str


class TextPromptMessageContent(PromptMessageContent):
Expand All @@ -76,21 +75,35 @@ class TextPromptMessageContent(PromptMessageContent):
"""

type: PromptMessageContentType = PromptMessageContentType.TEXT
data: str


class MultiModalPromptMessageContent(PromptMessageContent):
"""
Model class for multi-modal prompt message content.
"""

type: PromptMessageContentType
format: str = Field(..., description="the format of multi-modal file")
b64data: str = Field("", description="the base64 data of multi-modal file")
laipz8200 marked this conversation as resolved.
Show resolved Hide resolved
url: str = Field("", description="the url of multi-modal file")
mime_type: str = Field(..., description="the mime type of multi-modal file")

@computed_field(return_type=str)
@property
def data(self):
return self.url or f"data:{self.mime_type};base64,{self.b64data}"

class VideoPromptMessageContent(PromptMessageContent):

class VideoPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.VIDEO
data: str = Field(..., description="Base64 encoded video data")
format: str = Field(..., description="Video format")


class AudioPromptMessageContent(PromptMessageContent):
class AudioPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.AUDIO
data: str = Field(..., description="Base64 encoded audio data")
format: str = Field(..., description="Audio format")


class ImagePromptMessageContent(PromptMessageContent):
class ImagePromptMessageContent(MultiModalPromptMessageContent):
"""
Model class for image prompt message content.
"""
Expand All @@ -101,14 +114,10 @@ class DETAIL(StrEnum):

type: PromptMessageContentType = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW
format: str = Field("jpg", description="Image format")


class DocumentPromptMessageContent(PromptMessageContent):
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
encode_format: Literal["base64"]
data: str
format: str = Field(..., description="Document format")


class PromptMessage(ABC, BaseModel):
Expand Down
27 changes: 10 additions & 17 deletions api/core/model_runtime/model_providers/anthropic/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import base64
import io
import json
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast
Expand All @@ -18,7 +17,6 @@
)
from anthropic.types.beta.tools import ToolsBetaMessage
from httpx import Timeout
from PIL import Image

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities import (
Expand Down Expand Up @@ -498,22 +496,19 @@ def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) ->
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"):
if not message_content.b64data:
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
image_content = requests.get(message_content.url).content
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(
f"Failed to fetch image data from url {message_content.data}, {ex}"
)
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
base64_data = message_content.b64data

mime_type = message_content.mime_type
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError(
f"Unsupported image type {mime_type}, "
Expand All @@ -526,19 +521,17 @@ def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) ->
}
sub_messages.append(sub_message_dict)
elif isinstance(message_content, DocumentPromptMessageContent):
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type != "application/pdf":
if message_content.mime_type != "application/pdf":
raise ValueError(
f"Unsupported document type {mime_type}, " "only support application/pdf"
f"Unsupported document type {message_content.mime_type}, "
"only support application/pdf"
)
sub_message_dict = {
"type": "document",
"source": {
"type": message_content.encode_format,
"media_type": mime_type,
"data": base64_data,
"type": "base64",
"media_type": message_content.mime_type,
"data": message_content.data,
},
}
sub_messages.append(sub_message_dict)
Expand Down
6 changes: 3 additions & 3 deletions api/core/model_runtime/model_providers/tongyi/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,9 @@ def _convert_prompt_messages_to_tongyi_messages(
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.VIDEO:
message_content = cast(VideoPromptMessageContent, message_content)
video_url = message_content.data
if message_content.data.startswith("data:"):
raise InvokeError("not support base64, please set MULTIMODAL_SEND_VIDEO_FORMAT to url")
video_url = message_content.url
if not video_url:
raise InvokeError("not support base64, please set MULTIMODAL_SEND_FORMAT to url")

sub_message_dict = {"video": video_url}
sub_messages.append(sub_message_dict)
Expand Down
Loading
Loading