Skip to content

Commit

Permalink
chore: the consistency of MultiModalPromptMessageContent (#11721)
Browse files Browse the repository at this point in the history
  • Loading branch information
hjlarry authored Dec 17, 2024
1 parent 78c3051 commit c9b4029
Show file tree
Hide file tree
Showing 14 changed files with 109 additions and 100 deletions.
3 changes: 1 addition & 2 deletions api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,7 @@ UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50

# Model configuration
MULTIMODAL_SEND_IMAGE_FORMAT=base64
MULTIMODAL_SEND_VIDEO_FORMAT=base64
MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024

Expand Down
13 changes: 4 additions & 9 deletions api/configs/feature/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,14 +665,9 @@ class IndexingConfig(BaseSettings):
)


class VisionFormatConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
default="base64",
)

MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64",
class MultiModalTransferConfig(BaseSettings):
MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending files in multimodal contexts ('base64' or 'url'), default is base64",
default="base64",
)

Expand Down Expand Up @@ -778,13 +773,13 @@ class FeatureConfig(
FileAccessConfig,
FileUploadConfig,
HttpConfig,
VisionFormatConfig,
InnerAPIConfig,
IndexingConfig,
LoggingConfig,
MailConfig,
ModelLoadBalanceConfig,
ModerationConfig,
MultiModalTransferConfig,
PositionConfig,
RagEtlConfig,
SecurityConfig,
Expand Down
57 changes: 25 additions & 32 deletions api/core/file/file_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,33 +42,31 @@ def to_prompt_message_content(
*,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
):
match f.type:
case FileType.IMAGE:
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)

return ImagePromptMessageContent(data=data, detail=image_detail_config, format=f.extension.lstrip("."))
case FileType.AUDIO:
data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.VIDEO:
if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.DOCUMENT:
data = _to_base64_data_string(f)
return DocumentPromptMessageContent(encode_format="base64", data=data, format=f.extension.lstrip("."))
case _:
raise ValueError(f"file type {f.type} is not supported")
if f.extension is None:
raise ValueError("Missing file extension")
if f.mime_type is None:
raise ValueError("Missing file mime_type")

params = {
"base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
"format": f.extension.removeprefix("."),
"mime_type": f.mime_type,
}
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

prompt_class_map = {
FileType.IMAGE: ImagePromptMessageContent,
FileType.AUDIO: AudioPromptMessageContent,
FileType.VIDEO: VideoPromptMessageContent,
FileType.DOCUMENT: DocumentPromptMessageContent,
}

try:
return prompt_class_map[f.type](**params)
except KeyError:
raise ValueError(f"file type {f.type} is not supported")


def download(f: File, /):
Expand Down Expand Up @@ -122,11 +120,6 @@ def _get_encoded_string(f: File, /):
return encoded_string


def _to_base64_data_string(f: File, /):
encoded_string = _get_encoded_string(f)
return f"data:{f.mime_type};base64,{encoded_string}"


def _to_url(f: File, /):
if f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None:
Expand Down
39 changes: 24 additions & 15 deletions api/core/model_runtime/entities/message_entities.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from abc import ABC
from collections.abc import Sequence
from enum import Enum, StrEnum
from typing import Literal, Optional
from typing import Optional

from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, computed_field, field_validator


class PromptMessageRole(Enum):
Expand Down Expand Up @@ -67,7 +67,6 @@ class PromptMessageContent(BaseModel):
"""

type: PromptMessageContentType
data: str


class TextPromptMessageContent(PromptMessageContent):
Expand All @@ -76,21 +75,35 @@ class TextPromptMessageContent(PromptMessageContent):
"""

type: PromptMessageContentType = PromptMessageContentType.TEXT
data: str


class MultiModalPromptMessageContent(PromptMessageContent):
"""
Model class for multi-modal prompt message content.
"""

type: PromptMessageContentType
format: str = Field(..., description="the format of multi-modal file")
base64_data: str = Field("", description="the base64 data of multi-modal file")
url: str = Field("", description="the url of multi-modal file")
mime_type: str = Field(..., description="the mime type of multi-modal file")

@computed_field(return_type=str)
@property
def data(self):
return self.url or f"data:{self.mime_type};base64,{self.base64_data}"

class VideoPromptMessageContent(PromptMessageContent):

class VideoPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.VIDEO
data: str = Field(..., description="Base64 encoded video data")
format: str = Field(..., description="Video format")


class AudioPromptMessageContent(PromptMessageContent):
class AudioPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.AUDIO
data: str = Field(..., description="Base64 encoded audio data")
format: str = Field(..., description="Audio format")


class ImagePromptMessageContent(PromptMessageContent):
class ImagePromptMessageContent(MultiModalPromptMessageContent):
"""
Model class for image prompt message content.
"""
Expand All @@ -101,14 +114,10 @@ class DETAIL(StrEnum):

type: PromptMessageContentType = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW
format: str = Field("jpg", description="Image format")


class DocumentPromptMessageContent(PromptMessageContent):
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
encode_format: Literal["base64"]
data: str
format: str = Field(..., description="Document format")


class PromptMessage(ABC, BaseModel):
Expand Down
27 changes: 10 additions & 17 deletions api/core/model_runtime/model_providers/anthropic/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import base64
import io
import json
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast
Expand All @@ -18,7 +17,6 @@
)
from anthropic.types.beta.tools import ToolsBetaMessage
from httpx import Timeout
from PIL import Image

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities import (
Expand Down Expand Up @@ -498,22 +496,19 @@ def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) ->
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"):
if not message_content.base64_data:
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
image_content = requests.get(message_content.url).content
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(
f"Failed to fetch image data from url {message_content.data}, {ex}"
)
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
base64_data = message_content.base64_data

mime_type = message_content.mime_type
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError(
f"Unsupported image type {mime_type}, "
Expand All @@ -526,19 +521,17 @@ def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) ->
}
sub_messages.append(sub_message_dict)
elif isinstance(message_content, DocumentPromptMessageContent):
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type != "application/pdf":
if message_content.mime_type != "application/pdf":
raise ValueError(
f"Unsupported document type {mime_type}, " "only support application/pdf"
f"Unsupported document type {message_content.mime_type}, "
"only support application/pdf"
)
sub_message_dict = {
"type": "document",
"source": {
"type": message_content.encode_format,
"media_type": mime_type,
"data": base64_data,
"type": "base64",
"media_type": message_content.mime_type,
"data": message_content.data,
},
}
sub_messages.append(sub_message_dict)
Expand Down
6 changes: 3 additions & 3 deletions api/core/model_runtime/model_providers/tongyi/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,9 @@ def _convert_prompt_messages_to_tongyi_messages(
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.VIDEO:
message_content = cast(VideoPromptMessageContent, message_content)
video_url = message_content.data
if message_content.data.startswith("data:"):
raise InvokeError("not support base64, please set MULTIMODAL_SEND_VIDEO_FORMAT to url")
video_url = message_content.url
if not video_url:
raise InvokeError("not support base64, please set MULTIMODAL_SEND_FORMAT to url")

sub_message_dict = {"video": video_url}
sub_messages.append(sub_message_dict)
Expand Down
Loading

0 comments on commit c9b4029

Please sign in to comment.