Merge branch 'main' into fix/document-word-count-is-incorrect
# Conflicts:
#	api/services/dataset_service.py
JohnJyong committed Nov 8, 2024
2 parents 47af468 + d52c750 commit d44a337
Showing 28 changed files with 557 additions and 343 deletions.
3 changes: 2 additions & 1 deletion api/.env.example
@@ -285,8 +285,9 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
 UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
 UPLOAD_AUDIO_FILE_SIZE_LIMIT=50

-# Model Configuration
+# Model configuration
 MULTIMODAL_SEND_IMAGE_FORMAT=base64
+MULTIMODAL_SEND_VIDEO_FORMAT=base64
 PROMPT_GENERATION_MAX_TOKENS=512
 CODE_GENERATION_MAX_TOKENS=1024

9 changes: 7 additions & 2 deletions api/configs/feature/__init__.py
@@ -634,12 +634,17 @@ class IndexingConfig(BaseSettings):
     )


-class ImageFormatConfig(BaseSettings):
+class VisionFormatConfig(BaseSettings):
     MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
         description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
         default="base64",
     )

+    MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(
+        description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64",
+        default="base64",
+    )
+

 class CeleryBeatConfig(BaseSettings):
     CELERY_BEAT_SCHEDULER_TIME: int = Field(
@@ -742,7 +747,7 @@ class FeatureConfig(
     FileAccessConfig,
     FileUploadConfig,
     HttpConfig,
-    ImageFormatConfig,
+    VisionFormatConfig,
     InnerAPIConfig,
     IndexingConfig,
     LoggingConfig,
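
Note: the snippet below is a minimal standalone sketch, not Dify's actual config wiring, assuming pydantic-settings v2 is installed; it only illustrates that a BaseSettings class like the renamed VisionFormatConfig picks up the new MULTIMODAL_SEND_VIDEO_FORMAT value from the environment.

# Minimal sketch, standalone; assumes pydantic-settings v2, not Dify's config stack.
import os
from typing import Literal

from pydantic import Field
from pydantic_settings import BaseSettings


class VisionFormatConfig(BaseSettings):
    MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(default="base64")
    MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(default="base64")


os.environ["MULTIMODAL_SEND_VIDEO_FORMAT"] = "url"
print(VisionFormatConfig().MULTIMODAL_SEND_VIDEO_FORMAT)  # -> "url"
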
7 changes: 5 additions & 2 deletions api/controllers/console/datasets/datasets_document.py
@@ -317,8 +317,11 @@ def post(self):
                 raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
             try:
                 model_manager = ModelManager()
-                model_manager.get_default_model_instance(
-                    tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
+                model_manager.get_model_instance(
+                    tenant_id=current_user.current_tenant_id,
+                    provider=args["embedding_model_provider"],
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=args["embedding_model"],
                 )
             except InvokeAuthorizationError:
                 raise ProviderNotInitializeError(
3 changes: 2 additions & 1 deletion api/controllers/service_api/app/conversation.py
@@ -62,9 +62,10 @@ def delete(self, app_model: App, end_user: EndUser, c_id):
         conversation_id = str(c_id)

         try:
-            return ConversationService.delete(app_model, conversation_id, end_user)
+            ConversationService.delete(app_model, conversation_id, end_user)
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
+        return {"result": "success"}, 200


 class ConversationRenameApi(Resource):
12 changes: 10 additions & 2 deletions api/core/file/file_manager.py
@@ -3,7 +3,7 @@
 from configs import dify_config
 from core.file import file_repository
 from core.helper import ssrf_proxy
-from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent
+from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent
 from extensions.ext_database import db
 from extensions.ext_storage import storage

@@ -71,6 +71,12 @@ def to_prompt_message_content(f: File, /):
             if f.extension is None:
                 raise ValueError("Missing file extension")
             return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
+        case FileType.VIDEO:
+            if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
+                data = _to_url(f)
+            else:
+                data = _to_base64_data_string(f)
+            return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
         case _:
             raise ValueError(f"file type {f.type} is not supported")

@@ -112,7 +118,7 @@ def _download_file_content(path: str, /):
 def _get_encoded_string(f: File, /):
     match f.transfer_method:
         case FileTransferMethod.REMOTE_URL:
-            response = ssrf_proxy.get(f.remote_url)
+            response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
             response.raise_for_status()
             content = response.content
             encoded_string = base64.b64encode(content).decode("utf-8")
@@ -140,6 +146,8 @@ def _file_to_encoded_string(f: File, /):
     match f.type:
         case FileType.IMAGE:
             return _to_base64_data_string(f)
+        case FileType.VIDEO:
+            return _to_base64_data_string(f)
         case FileType.AUDIO:
             return _get_encoded_string(f)
         case _:
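
Note: a rough standalone sketch of the url-vs-base64 switch the new FileType.VIDEO branch performs; the helper and the data-URI shape produced by _to_base64_data_string are stand-ins and assumptions, not Dify's implementations.

# Standalone sketch; the data-URI shape is an assumption, not Dify's _to_base64_data_string.
import base64

MULTIMODAL_SEND_VIDEO_FORMAT = "base64"  # stands in for dify_config.MULTIMODAL_SEND_VIDEO_FORMAT


def video_prompt_data(raw_bytes: bytes, public_url: str, mime_type: str = "video/mp4") -> str:
    # mirrors the branch above: send a URL when configured, otherwise a base64 data string
    if MULTIMODAL_SEND_VIDEO_FORMAT == "url":
        return public_url
    encoded = base64.b64encode(raw_bytes).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"


print(video_prompt_data(b"\x00\x00\x00\x18ftypmp42", "https://example.com/clip.mp4")[:40])
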
2 changes: 2 additions & 0 deletions api/core/model_runtime/entities/__init__.py
@@ -12,11 +12,13 @@
     TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
+    VideoPromptMessageContent,
 )
 from .model_entities import ModelPropertyKey

 __all__ = [
     "ImagePromptMessageContent",
+    "VideoPromptMessageContent",
     "PromptMessage",
     "PromptMessageRole",
     "LLMUsage",
7 changes: 7 additions & 0 deletions api/core/model_runtime/entities/message_entities.py
@@ -56,6 +56,7 @@ class PromptMessageContentType(Enum):
     TEXT = "text"
     IMAGE = "image"
     AUDIO = "audio"
+    VIDEO = "video"


 class PromptMessageContent(BaseModel):
@@ -75,6 +76,12 @@ class TextPromptMessageContent(PromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.TEXT


+class VideoPromptMessageContent(PromptMessageContent):
+    type: PromptMessageContentType = PromptMessageContentType.VIDEO
+    data: str = Field(..., description="Base64 encoded video data")
+    format: str = Field(..., description="Video format")
+
+
 class AudioPromptMessageContent(PromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.AUDIO
     data: str = Field(..., description="Base64 encoded audio data")
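
Note: a hypothetical usage sketch of the new entity (import path as in the __init__.py change above); that UserPromptMessage.content may be a list of content parts is inferred from the provider conversions further down, and the example values are invented.

# Hypothetical usage sketch; runs inside the Dify codebase only.
from core.model_runtime.entities import UserPromptMessage, VideoPromptMessageContent

video = VideoPromptMessageContent(
    data="data:video/mp4;base64,AAAA",  # or an https URL when MULTIMODAL_SEND_VIDEO_FORMAT=url
    format="mp4",
)
message = UserPromptMessage(content=[video])
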
9 changes: 9 additions & 0 deletions api/core/model_runtime/model_providers/tongyi/llm/llm.py
@@ -29,6 +29,7 @@
     TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
+    VideoPromptMessageContent,
 )
 from core.model_runtime.entities.model_entities import (
     AIModelEntity,
@@ -431,6 +432,14 @@ def _convert_prompt_messages_to_tongyi_messages(

                             sub_message_dict = {"image": image_url}
                             sub_messages.append(sub_message_dict)
+                        elif message_content.type == PromptMessageContentType.VIDEO:
+                            message_content = cast(VideoPromptMessageContent, message_content)
+                            video_url = message_content.data
+                            if message_content.data.startswith("data:"):
+                                raise InvokeError("not support base64, please set MULTIMODAL_SEND_VIDEO_FORMAT to url")
+
+                            sub_message_dict = {"video": video_url}
+                            sub_messages.append(sub_message_dict)

                     # resort sub_messages to ensure text is always at last
                     sub_messages = sorted(sub_messages, key=lambda x: "text" in x)
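
Note: purely illustrative (invented values) of the sub-message shape this conversion builds for Tongyi and of the text-last reordering; videos must be passed by URL, since base64 input raises InvokeError in the branch above.

# Illustrative only; mirrors the dict shape and sorted() call in the diff above.
sub_messages = [
    {"text": "Describe what happens in this clip."},
    {"video": "https://example.com/clip.mp4"},
]
sub_messages = sorted(sub_messages, key=lambda x: "text" in x)  # text parts go last
print(sub_messages)  # [{'video': ...}, {'text': ...}]
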
38 changes: 26 additions & 12 deletions api/core/model_runtime/model_providers/zhipuai/llm/llm.py
@@ -313,21 +313,35 @@ def _construct_glm_4v_parameter(self, model: str, prompt_messages: list[PromptMe
         return params

     def _construct_glm_4v_messages(self, prompt_message: Union[str, list[PromptMessageContent]]) -> list[dict]:
-        if isinstance(prompt_message, str):
+        if isinstance(prompt_message, list):
+            sub_messages = []
+            for item in prompt_message:
+                if item.type == PromptMessageContentType.IMAGE:
+                    sub_messages.append(
+                        {
+                            "type": "image_url",
+                            "image_url": {"url": self._remove_base64_header(item.data)},
+                        }
+                    )
+                elif item.type == PromptMessageContentType.VIDEO:
+                    sub_messages.append(
+                        {
+                            "type": "video_url",
+                            "video_url": {"url": self._remove_base64_header(item.data)},
+                        }
+                    )
+                else:
+                    sub_messages.append({"type": "text", "text": item.data})
+            return sub_messages
+        else:
             return [{"type": "text", "text": prompt_message}]

-        return [
-            {"type": "image_url", "image_url": {"url": self._remove_image_header(item.data)}}
-            if item.type == PromptMessageContentType.IMAGE
-            else {"type": "text", "text": item.data}
-            for item in prompt_message
-        ]
-
-    def _remove_image_header(self, image: str) -> str:
-        if image.startswith("data:image"):
-            return image.split(",")[1]
+    def _remove_base64_header(self, file_content: str) -> str:
+        if file_content.startswith("data:"):
+            data_split = file_content.split(";base64,")
+            return data_split[1]

-        return image
+        return file_content

     def _handle_generate_response(
         self,
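
Note: a standalone re-implementation for demonstration only (the real method lives on the ZhipuAI model class above), showing that the renamed helper now strips any data:<mime>;base64, prefix instead of only data:image ones.

# Demonstration only; mirrors the _remove_base64_header logic introduced above.
def remove_base64_header(file_content: str) -> str:
    if file_content.startswith("data:"):
        return file_content.split(";base64,")[1]
    return file_content


print(remove_base64_header("data:video/mp4;base64,AAAA"))    # -> "AAAA"
print(remove_base64_header("https://example.com/clip.mp4"))  # unchanged
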
4 changes: 4 additions & 0 deletions api/core/ops/entities/config_entity.py
@@ -54,3 +54,7 @@ def set_value(cls, v, info: ValidationInfo):
             raise ValueError("endpoint must start with https://")

         return v
+
+
+OPS_FILE_PATH = "ops_trace/"
+OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
11 changes: 11 additions & 0 deletions api/core/ops/entities/trace_entity.py
@@ -23,6 +23,11 @@ def ensure_type(cls, v):
             return v
         return ""

+    class Config:
+        json_encoders = {
+            datetime: lambda v: v.isoformat(),
+        }
+

 class WorkflowTraceInfo(BaseTraceInfo):
     workflow_data: Any
@@ -100,6 +105,12 @@ class GenerateNameTraceInfo(BaseTraceInfo):
     tenant_id: str


+class TaskData(BaseModel):
+    app_id: str
+    trace_info_type: str
+    trace_info: Any
+
+
 trace_info_info_map = {
     "WorkflowTraceInfo": WorkflowTraceInfo,
     "MessageTraceInfo": MessageTraceInfo,
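
Note: a self-contained sketch (pydantic v2, invented values) of the TaskData payload that send_to_celery in ops_trace_manager.py below serializes and writes to storage.

# Self-contained sketch; field names follow the TaskData model above, values are made up.
from typing import Any

from pydantic import BaseModel


class TaskData(BaseModel):
    app_id: str
    trace_info_type: str
    trace_info: Any


payload = TaskData(
    app_id="app-123",
    trace_info_type="MessageTraceInfo",
    trace_info={"inputs": "hello", "outputs": "hi there"},
)
print(payload.model_dump_json())
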
20 changes: 15 additions & 5 deletions api/core/ops/ops_trace_manager.py
@@ -6,12 +6,13 @@
 import time
 from datetime import timedelta
 from typing import Any, Optional, Union
-from uuid import UUID
+from uuid import UUID, uuid4

 from flask import current_app

 from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
 from core.ops.entities.config_entity import (
+    OPS_FILE_PATH,
     LangfuseConfig,
     LangSmithConfig,
     TracingProviderEnum,
@@ -22,6 +23,7 @@
     MessageTraceInfo,
     ModerationTraceInfo,
     SuggestedQuestionTraceInfo,
+    TaskData,
     ToolTraceInfo,
     TraceTaskName,
     WorkflowTraceInfo,
@@ -30,6 +32,7 @@
 from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
 from core.ops.utils import get_message_data
 from extensions.ext_database import db
+from extensions.ext_storage import storage
 from models.model import App, AppModelConfig, Conversation, Message, MessageAgentThought, MessageFile, TraceAppConfig
 from models.workflow import WorkflowAppLog, WorkflowRun
 from tasks.ops_trace_task import process_trace_tasks
@@ -740,10 +743,17 @@ def start_timer(self):
     def send_to_celery(self, tasks: list[TraceTask]):
         with self.flask_app.app_context():
             for task in tasks:
+                file_id = uuid4().hex
                 trace_info = task.execute()
-                task_data = {
+                task_data = TaskData(
+                    app_id=task.app_id,
+                    trace_info_type=type(trace_info).__name__,
+                    trace_info=trace_info.model_dump() if trace_info else None,
+                )
+                file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
+                storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
+                file_info = {
+                    "file_id": file_id,
                     "app_id": task.app_id,
-                    "trace_info_type": type(trace_info).__name__,
-                    "trace_info": trace_info.model_dump() if trace_info else {},
                 }
-                process_trace_tasks.delay(task_data)
+                process_trace_tasks.delay(file_info)
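
Note: a minimal local sketch of the new hand-off with stand-ins (plain files instead of Dify's storage extension, a print instead of the Celery call): the full payload is persisted under ops_trace/<app_id>/<file_id>.json and only a small reference is enqueued, which keeps Celery task arguments small.

# Local sketch with stand-ins; not Dify's storage/Celery APIs.
import json
from pathlib import Path
from uuid import uuid4

app_id = "app-123"
file_id = uuid4().hex
task_data = {"app_id": app_id, "trace_info_type": "MessageTraceInfo", "trace_info": {"inputs": "hi"}}

file_path = Path("ops_trace") / app_id / f"{file_id}.json"
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.write_bytes(json.dumps(task_data).encode("utf-8"))  # what storage.save persists

file_info = {"file_id": file_id, "app_id": app_id}  # all the Celery worker receives
print("enqueue:", file_info)  # stand-in for process_trace_tasks.delay(file_info)
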
24 changes: 24 additions & 0 deletions api/core/tools/provider/builtin/cogview/tools/cogvideo.py
@@ -0,0 +1,24 @@
from typing import Any, Union

from zhipuai import ZhipuAI

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class CogVideoTool(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        client = ZhipuAI(
            base_url=self.runtime.credentials["zhipuai_base_url"],
            api_key=self.runtime.credentials["zhipuai_api_key"],
        )
        if not tool_parameters.get("prompt") and not tool_parameters.get("image_url"):
            return self.create_text_message("require at least one of prompt and image_url")

        response = client.videos.generations(
            model="cogvideox", prompt=tool_parameters.get("prompt"), image_url=tool_parameters.get("image_url")
        )

        return self.create_json_message(response.dict())
32 changes: 32 additions & 0 deletions api/core/tools/provider/builtin/cogview/tools/cogvideo.yaml
@@ -0,0 +1,32 @@
identity:
  name: cogvideo
  author: hjlarry
  label:
    en_US: CogVideo
    zh_Hans: CogVideo 视频生成
description:
  human:
    en_US: Use the CogVideox model provided by ZhipuAI to generate videos based on user prompts and images.
    zh_Hans: 使用智谱cogvideox模型,根据用户输入的提示词和图片,生成视频。
  llm: A tool for generating videos. The input is user's prompt or image url or both of them, the output is a task id. You can use another tool with this task id to check the status and get the video.
parameters:
  - name: prompt
    type: string
    label:
      en_US: prompt
      zh_Hans: 提示词
    human_description:
      en_US: The prompt text used to generate video.
      zh_Hans: 用于生成视频的提示词。
    llm_description: The prompt text used to generate video. Optional.
    form: llm
  - name: image_url
    type: string
    label:
      en_US: image url
      zh_Hans: 图片链接
    human_description:
      en_US: The image url used to generate video.
      zh_Hans: 输入一个图片链接,生成的视频将基于该图片和提示词。
    llm_description: The image url used to generate video. Optional.
    form: llm
30 changes: 30 additions & 0 deletions api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py
@@ -0,0 +1,30 @@
from typing import Any, Union

import httpx
from zhipuai import ZhipuAI

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class CogVideoJobTool(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        client = ZhipuAI(
            api_key=self.runtime.credentials["zhipuai_api_key"],
            base_url=self.runtime.credentials["zhipuai_base_url"],
        )

        response = client.videos.retrieve_videos_result(id=tool_parameters.get("id"))
        result = [self.create_json_message(response.dict())]
        if response.task_status == "SUCCESS":
            for item in response.video_result:
                video_cover_image = self.create_image_message(item.cover_image_url)
                result.append(video_cover_image)
                video = self.create_blob_message(
                    blob=httpx.get(item.url).content, meta={"mime_type": "video/mp4"}, save_as=self.VariableKey.VIDEO
                )
                result.append(video)

        return result
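
Note: a sketch of the two-step CogVideo flow these two tools implement (submit a generation job, then poll its result by id); the client calls mirror the code above, but the .id attribute on the submission response and the credential value are assumptions.

# Sketch only; `job.id` and the credential are assumptions, the other attribute
# names mirror cogvideo.py / cogvideo_job.py above.
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-zhipuai-api-key")  # hypothetical credential

job = client.videos.generations(model="cogvideox", prompt="a cat surfing a wave")
result = client.videos.retrieve_videos_result(id=job.id)

if result.task_status == "SUCCESS":
    for item in result.video_result:
        print(item.cover_image_url, item.url)
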
(Diffs for the remaining changed files are not loaded on this page.)
