From 641958cd7ae951c9f793dd8ac85e3d214fe3b891 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?=
Date: Tue, 17 Dec 2024 12:05:13 +0800
Subject: [PATCH] feat: enhance gemini models (#11497)

---
 api/core/file/file_manager.py                 |  16 +--
 .../entities/message_entities.py              |   3 +-
 .../model_providers/anthropic/llm/llm.py      |  12 +-
 .../google/llm/gemini-1.5-flash-001.yaml      |   2 +
 .../google/llm/gemini-1.5-flash-002.yaml      |   2 +
 .../llm/gemini-1.5-flash-8b-exp-0827.yaml     |   2 +
 .../llm/gemini-1.5-flash-8b-exp-0924.yaml     |   2 +
 .../google/llm/gemini-1.5-flash-exp-0827.yaml |   2 +
 .../google/llm/gemini-1.5-flash-latest.yaml   |   2 +
 .../google/llm/gemini-1.5-flash.yaml          |   2 +
 .../google/llm/gemini-1.5-pro-001.yaml        |   2 +
 .../google/llm/gemini-1.5-pro-002.yaml        |   2 +
 .../google/llm/gemini-1.5-pro-exp-0801.yaml   |   2 +
 .../google/llm/gemini-1.5-pro-exp-0827.yaml   |   2 +
 .../google/llm/gemini-1.5-pro-latest.yaml     |   2 +
 .../google/llm/gemini-1.5-pro.yaml            |   2 +
 .../google/llm/gemini-exp-1114.yaml           |   2 +
 .../google/llm/gemini-exp-1121.yaml           |   3 +
 .../model_providers/google/llm/llm.py         | 119 +++++++++---------
 .../model_providers/openai/llm/llm.py         |   4 +-
 .../model_runtime/__mock/google.py            |  58 ++++-----
 .../model_runtime/google/test_llm.py          |   6 +-
 .../core/workflow/nodes/llm/test_node.py      |   2 +
 23 files changed, 138 insertions(+), 113 deletions(-)

diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py
index 3b83683755960e..9df605b4938ed5 100644
--- a/api/core/file/file_manager.py
+++ b/api/core/file/file_manager.py
@@ -50,12 +50,12 @@ def to_prompt_message_content(
             else:
                 data = _to_base64_data_string(f)
 
-            return ImagePromptMessageContent(data=data, detail=image_detail_config)
+            return ImagePromptMessageContent(data=data, detail=image_detail_config, format=f.extension.lstrip("."))
         case FileType.AUDIO:
-            encoded_string = _get_encoded_string(f)
+            data = _to_base64_data_string(f)
             if f.extension is None:
                 raise ValueError("Missing file extension")
-            return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
+            return AudioPromptMessageContent(data=data, format=f.extension.lstrip("."))
         case FileType.VIDEO:
             if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
                 data = _to_url(f)
@@ -65,14 +65,8 @@ def to_prompt_message_content(
                 raise ValueError("Missing file extension")
             return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
         case FileType.DOCUMENT:
-            data = _get_encoded_string(f)
-            if f.mime_type is None:
-                raise ValueError("Missing file mime_type")
-            return DocumentPromptMessageContent(
-                encode_format="base64",
-                mime_type=f.mime_type,
-                data=data,
-            )
+            data = _to_base64_data_string(f)
+            return DocumentPromptMessageContent(encode_format="base64", data=data, format=f.extension.lstrip("."))
         case _:
             raise ValueError(f"file type {f.type} is not supported")
 
diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py
index f2870209bb5e00..26af522ea6ef86 100644
--- a/api/core/model_runtime/entities/message_entities.py
+++ b/api/core/model_runtime/entities/message_entities.py
@@ -101,13 +101,14 @@ class DETAIL(StrEnum):
 
     type: PromptMessageContentType = PromptMessageContentType.IMAGE
     detail: DETAIL = DETAIL.LOW
+    format: str = Field("jpg", description="Image format")
 
 
 class DocumentPromptMessageContent(PromptMessageContent):
    type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
    encode_format: Literal["base64"]
-    mime_type: str
    data: str
+    format: str = Field(..., description="Document format")
 
 
 class PromptMessage(ABC, BaseModel):
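A note on the schema change just above: DocumentPromptMessageContent now carries a full base64 data URI in `data`, plus a `format` field, instead of an explicit `mime_type`. Providers recover the MIME type by parsing the URI, as the Anthropic change below does. A minimal sketch of that convention (`split_data_uri` is a hypothetical helper name, not part of this patch):

    def split_data_uri(data: str) -> tuple[str, str]:
        # Split "data:<mime>;base64,<payload>" into (mime_type, payload).
        header, _, payload = data.partition(";base64,")
        return header.removeprefix("data:"), payload

    mime_type, payload = split_data_uri("data:application/pdf;base64,JVBERi0xLjcK")
    assert mime_type == "application/pdf"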
format") class PromptMessage(ABC, BaseModel): diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 3faf5abbe87f58..edf56591f0d3c5 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -526,17 +526,19 @@ def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) -> } sub_messages.append(sub_message_dict) elif isinstance(message_content, DocumentPromptMessageContent): - if message_content.mime_type != "application/pdf": + data_split = message_content.data.split(";base64,") + mime_type = data_split[0].replace("data:", "") + base64_data = data_split[1] + if mime_type != "application/pdf": raise ValueError( - f"Unsupported document type {message_content.mime_type}, " - "only support application/pdf" + f"Unsupported document type {mime_type}, " "only support application/pdf" ) sub_message_dict = { "type": "document", "source": { "type": message_content.encode_format, - "media_type": message_content.mime_type, - "data": message_content.data, + "media_type": mime_type, + "data": base64_data, }, } sub_messages.append(sub_message_dict) diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml index 43f4e4787d2e07..86bba2154a527c 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml @@ -8,6 +8,8 @@ features: - tool-call - stream-tool-call - document + - video + - audio model_properties: mode: chat context_size: 1048576 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml index 7b9add6af16ebd..9ad57a19339515 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml @@ -8,6 +8,8 @@ features: - tool-call - stream-tool-call - document + - video + - audio model_properties: mode: chat context_size: 1048576 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml index d6de82012ef2d9..72205f15a8760f 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml @@ -8,6 +8,8 @@ features: - tool-call - stream-tool-call - document + - video + - audio model_properties: mode: chat context_size: 1048576 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml index 23b8d318fc14bc..1193e60669e2e2 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml @@ -8,6 +8,8 @@ features: - tool-call - stream-tool-call - document + - video + - audio model_properties: mode: chat context_size: 1048576 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-exp-0827.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-exp-0827.yaml index 
diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml
index b9739d068e9907..b8c50241581670 100644
--- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml
+++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml
@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576
diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml
index d8ab4efc918ad5..ea0c42dda88457 100644
--- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml
+++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml
@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576
diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml
index 05184823e4ca27..16df30857c6761 100644
--- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml
+++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml
@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152
diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml
index 548fe6ddb22d80..717d9481b91953 100644
--- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml
+++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml
@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152
diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml
index defab26acf4d8d..bf9704f0d54879 100644
--- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml
+++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml
@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152
diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml
index 9cbc889f1776aa..714ff35f3443f3 100644
--- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml
+++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml
@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152
diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml
index e5aefcdb990aa7..bbca2ba3852869 100644
--- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml
+++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml
@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152
diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml
index 00bd3e8d99db50..ae127fb4e2dea0 100644
--- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml
+++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml
@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152
diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-exp-1114.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-exp-1114.yaml
index 0515e706c2c79a..bd49b476938eee 100644
--- a/api/core/model_runtime/model_providers/google/llm/gemini-exp-1114.yaml
+++ b/api/core/model_runtime/model_providers/google/llm/gemini-exp-1114.yaml
@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 32767
diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-exp-1121.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-exp-1121.yaml
index 9ca4f6e6756348..8e3f218df41971 100644
--- a/api/core/model_runtime/model_providers/google/llm/gemini-exp-1121.yaml
+++ b/api/core/model_runtime/model_providers/google/llm/gemini-exp-1121.yaml
@@ -7,6 +7,9 @@ features:
   - vision
   - tool-call
   - stream-tool-call
+  - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 32767
diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py
index c19e860d2e4b4b..9a1b13f96f49c0 100644
--- a/api/core/model_runtime/model_providers/google/llm/llm.py
+++ b/api/core/model_runtime/model_providers/google/llm/llm.py
@@ -1,29 +1,30 @@
 import base64
-import io
 import json
+import os
+import tempfile
+import time
 from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import Optional, Union
 
 import google.ai.generativelanguage as glm
 import google.generativeai as genai
 import requests
 from google.api_core import exceptions
-from google.generativeai.client import _ClientManager
-from google.generativeai.types import ContentType, GenerateContentResponse
+from google.generativeai.types import ContentType, File, GenerateContentResponse
 from google.generativeai.types.content_types import to_part
-from PIL import Image
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
-    DocumentPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
+    PromptMessageContent,
     PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
     ToolPromptMessage,
     UserPromptMessage,
+    VideoPromptMessageContent,
 )
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
@@ -35,21 +36,7 @@
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-
-GOOGLE_AVAILABLE_MIMETYPE = [
-    "application/pdf",
-    "application/x-javascript",
-    "text/javascript",
-    "application/x-python",
-    "text/x-python",
-    "text/plain",
-    "text/html",
-    "text/css",
-    "text/md",
-    "text/csv",
-    "text/xml",
-    "text/rtf",
-]
+from extensions.ext_redis import redis_client
 
 
 class GoogleLargeLanguageModel(LargeLanguageModel):
@@ -201,29 +188,17 @@ def _generate(
         if stop:
             config_kwargs["stop_sequences"] = stop
 
+        genai.configure(api_key=credentials["google_api_key"])
         google_model = genai.GenerativeModel(model_name=model)
         history = []
 
-        # hack for gemini-pro-vision, which currently does not support multi-turn chat
-        if model == "gemini-pro-vision":
-            last_msg = prompt_messages[-1]
-            content = self._format_message_to_glm_content(last_msg)
-            history.append(content)
-        else:
-            for msg in prompt_messages:  # makes message roles strictly alternating
-                content = self._format_message_to_glm_content(msg)
-                if history and history[-1]["role"] == content["role"]:
-                    history[-1]["parts"].extend(content["parts"])
-                else:
-                    history.append(content)
-
-        # Create a new ClientManager with tenant's API key
-        new_client_manager = _ClientManager()
-        new_client_manager.configure(api_key=credentials["google_api_key"])
-        new_custom_client = new_client_manager.make_client("generative")
-
-        google_model._client = new_custom_client
+        for msg in prompt_messages:  # makes message roles strictly alternating
+            content = self._format_message_to_glm_content(msg)
+            if history and history[-1]["role"] == content["role"]:
+                history[-1]["parts"].extend(content["parts"])
+            else:
+                history.append(content)
 
         response = google_model.generate_content(
             contents=history,
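The removed `_ClientManager` dance built a private per-request client around the tenant's key; the replacement calls the public `genai.configure(...)` inside `_generate`. One caveat worth noting, as an assumption about the SDK rather than something this patch states: `configure` stores the key in module-level state, so a process serving tenants with different Google keys would need to serialize configure-and-generate. A minimal guard, for illustration only:

    import threading

    import google.generativeai as genai

    _genai_lock = threading.Lock()

    def generate_for_tenant(api_key: str, model_name: str, contents):
        # configure() is global; the lock keeps another tenant's key from
        # slipping in between configure() and generate_content().
        with _genai_lock:
            genai.configure(api_key=api_key)
            model = genai.GenerativeModel(model_name=model_name)
            return model.generate_content(contents=contents)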
@@ -346,7 +321,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str:
         content = message.content
         if isinstance(content, list):
-            content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
+            content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)
 
         if isinstance(message, UserPromptMessage):
             message_text = f"{human_prompt} {content}"
@@ -359,6 +334,44 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str:
 
         return message_text
 
+    def _upload_file_content_to_google(self, message_content: PromptMessageContent) -> File:
+        key = f"{message_content.type.value}:{hash(message_content.data)}"
+        if redis_client.exists(key):
+            try:
+                return genai.get_file(redis_client.get(key).decode())
+            except Exception:
+                pass
+        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+            if message_content.data.startswith("data:"):
+                metadata, base64_data = message_content.data.split(",", 1)
+                file_content = base64.b64decode(base64_data)
+                mime_type = metadata.split(";", 1)[0].split(":")[1]
+                temp_file.write(file_content)
+            else:
+                # only ImagePromptMessageContent and VideoPromptMessageContent carry a URL
+                try:
+                    response = requests.get(message_content.data)
+                    response.raise_for_status()
+                    if message_content.type == PromptMessageContentType.IMAGE:
+                        prefix = "image/"
+                    elif message_content.type == PromptMessageContentType.VIDEO:
+                        prefix = "video/"
+                    mime_type = prefix + message_content.format
+                    temp_file.write(response.content)
+                except Exception as ex:
+                    raise ValueError(f"Failed to fetch data from url {message_content.data}, {ex}")
+            temp_file.flush()
+        try:
+            file = genai.upload_file(path=temp_file.name, mime_type=mime_type)
+            while file.state.name == "PROCESSING":
+                time.sleep(5)
+                file = genai.get_file(file.name)
+            # Google deletes uploaded files after about 48 hours, so cache the name for just under that
+            redis_client.setex(key, 47 * 60 * 60, file.name)
+            return file
+        finally:
+            os.unlink(temp_file.name)
+
     def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
         """
         Format a single message into glm.Content for Google API
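`_upload_file_content_to_google` is the heart of the patch: media is pushed through Google's File API, the upload is polled until it leaves the PROCESSING state, and the resulting file name is cached in Redis for 47 hours, just inside the roughly 48-hour retention window the comment refers to, so repeated prompts reuse a single upload. The same pattern in isolation (a sketch; `upload` and `lookup` stand in for `genai.upload_file` and `genai.get_file`):

    import time

    CACHE_TTL_SECONDS = 47 * 60 * 60

    def get_or_upload(redis_client, key, upload, lookup):
        cached = redis_client.get(key)
        if cached is not None:
            try:
                return lookup(cached.decode())
            except Exception:
                pass  # stale cache entry; the remote file may have expired
        file = upload()
        while file.state.name == "PROCESSING":  # poll until processing finishes
            time.sleep(5)
            file = lookup(file.name)
        redis_client.setex(key, CACHE_TTL_SECONDS, file.name)
        return file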
@@ -374,28 +387,8 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
             for c in message.content:
                 if c.type == PromptMessageContentType.TEXT:
                     glm_content["parts"].append(to_part(c.data))
-                elif c.type == PromptMessageContentType.IMAGE:
-                    message_content = cast(ImagePromptMessageContent, c)
-                    if message_content.data.startswith("data:"):
-                        metadata, base64_data = c.data.split(",", 1)
-                        mime_type = metadata.split(";", 1)[0].split(":")[1]
-                    else:
-                        # fetch image data from url
-                        try:
-                            image_content = requests.get(message_content.data).content
-                            with Image.open(io.BytesIO(image_content)) as img:
-                                mime_type = f"image/{img.format.lower()}"
-                            base64_data = base64.b64encode(image_content).decode("utf-8")
-                        except Exception as ex:
-                            raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
-                    blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
-                    glm_content["parts"].append(blob)
-                elif c.type == PromptMessageContentType.DOCUMENT:
-                    message_content = cast(DocumentPromptMessageContent, c)
-                    if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
-                        raise ValueError(f"Unsupported mime type {message_content.mime_type}")
-                    blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
-                    glm_content["parts"].append(blob)
+                else:
+                    glm_content["parts"].append(self._upload_file_content_to_google(c))
 
             return glm_content
         elif isinstance(message, AssistantPromptMessage):
diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py
index 07cb1e2d1018f9..b73ce8752f13f3 100644
--- a/api/core/model_runtime/model_providers/openai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai/llm/llm.py
@@ -920,10 +920,12 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
                     }
                     sub_messages.append(sub_message_dict)
                 elif isinstance(message_content, AudioPromptMessageContent):
+                    data_split = message_content.data.split(";base64,")
+                    base64_data = data_split[1]
                     sub_message_dict = {
                         "type": "input_audio",
                         "input_audio": {
-                            "data": message_content.data,
+                            "data": base64_data,
                             "format": message_content.format,
                         },
                     }
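The OpenAI change mirrors the new convention: `AudioPromptMessageContent.data` is now a data URI, so the base64 payload is split out before being sent as `input_audio`. The split assumes a well-formed `data:<mime>;base64,<payload>` string; a slightly defensive variant (hypothetical, not part of the patch) that also accepts a bare payload:

    def base64_payload(data: str) -> str:
        _, sep, payload = data.partition(";base64,")
        return payload if sep else data  # fall back to treating data as raw base64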
diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py
index 402bd9c2c21f69..5ea86baa83dd4b 100644
--- a/api/tests/integration_tests/model_runtime/__mock/google.py
+++ b/api/tests/integration_tests/model_runtime/__mock/google.py
@@ -1,4 +1,5 @@
 from collections.abc import Generator
+from unittest.mock import MagicMock
 
 import google.generativeai.types.generation_types as generation_config_types
 import pytest
@@ -6,11 +7,10 @@
 from google.ai import generativelanguage as glm
 from google.ai.generativelanguage_v1beta.types import content as gag_content
 from google.generativeai import GenerativeModel
-from google.generativeai.client import _ClientManager, configure
 from google.generativeai.types import GenerateContentResponse, content_types, safety_types
 from google.generativeai.types.generation_types import BaseGenerateContentResponse
 
-current_api_key = ""
+from extensions import ext_redis
 
 
 class MockGoogleResponseClass:
@@ -57,11 +57,6 @@ def generate_content(
         stream: bool = False,
         **kwargs,
     ) -> GenerateContentResponse:
-        global current_api_key
-
-        if len(current_api_key) < 16:
-            raise Exception("Invalid API key")
-
         if stream:
             return MockGoogleClass.generate_content_stream()
 
@@ -75,33 +70,29 @@ def generative_response_text(self) -> str:
     def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
         return [MockGoogleResponseCandidateClass()]
 
-    def make_client(self: _ClientManager, name: str):
-        global current_api_key
-        if name.endswith("_async"):
-            name = name.split("_")[0]
-            cls = getattr(glm, name.title() + "ServiceAsyncClient")
-        else:
-            cls = getattr(glm, name.title() + "ServiceClient")
+
+def mock_configure(api_key: str):
+    if len(api_key) < 16:
+        raise Exception("Invalid API key")
+
+
+class MockFileState:
+    def __init__(self):
+        self.name = "FINISHED"
 
-        # Attempt to configure using defaults.
-        if not self.client_config:
-            configure()
 
-        client_options = self.client_config.get("client_options", None)
-        if client_options:
-            current_api_key = client_options.api_key
+class MockGoogleFile:
+    def __init__(self, name: str = "mock_file_name"):
+        self.name = name
+        self.state = MockFileState()
 
-        def nop(self, *args, **kwargs):
-            pass
 
-        original_init = cls.__init__
-        cls.__init__ = nop
-        client: glm.GenerativeServiceClient = cls(**self.client_config)
-        cls.__init__ = original_init
+def mock_get_file(name: str) -> MockGoogleFile:
+    return MockGoogleFile(name)
 
-        if not self.default_metadata:
-            return client
+
+def mock_upload_file(path: str, mime_type: str) -> MockGoogleFile:
+    return MockGoogleFile()
 
 
 @pytest.fixture
@@ -109,8 +100,17 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch):
     monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
     monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates)
     monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content)
-    monkeypatch.setattr(_ClientManager, "make_client", MockGoogleClass.make_client)
+    monkeypatch.setattr("google.generativeai.configure", mock_configure)
+    monkeypatch.setattr("google.generativeai.get_file", mock_get_file)
+    monkeypatch.setattr("google.generativeai.upload_file", mock_upload_file)
 
     yield
 
     monkeypatch.undo()
+
+
+@pytest.fixture
+def setup_mock_redis() -> None:
+    ext_redis.redis_client.get = MagicMock(return_value=None)
+    ext_redis.redis_client.setex = MagicMock(return_value=None)
+    ext_redis.redis_client.exists = MagicMock(return_value=True)
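One subtlety in the fixtures above: `setup_mock_redis` makes `exists` return True while `get` returns None, so the cache lookup in `_upload_file_content_to_google` fails on `None.decode()`, the exception is swallowed, and the test exercises the full re-upload path against the mocked `genai.upload_file`. A self-contained illustration of that interplay:

    from unittest.mock import MagicMock

    redis_client = MagicMock()
    redis_client.exists.return_value = True
    redis_client.get.return_value = None

    def cached_file_name(key: str):
        if redis_client.exists(key):
            try:
                return redis_client.get(key).decode()  # None.decode() raises
            except Exception:
                pass  # treated as a cache miss; the caller re-uploads
        return None

    assert cached_file_name("image:123") is None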
diff --git a/api/tests/integration_tests/model_runtime/google/test_llm.py b/api/tests/integration_tests/model_runtime/google/test_llm.py
index 2877fa150764eb..777bbfdcb68188 100644
--- a/api/tests/integration_tests/model_runtime/google/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/google/test_llm.py
@@ -13,7 +13,7 @@
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel
-from tests.integration_tests.model_runtime.__mock.google import setup_google_mock
+from tests.integration_tests.model_runtime.__mock.google import setup_google_mock, setup_mock_redis
 
 
 @pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
@@ -95,7 +95,7 @@ def test_invoke_stream_model(setup_google_mock):
 
 
 @pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
-def test_invoke_chat_model_with_vision(setup_google_mock):
+def test_invoke_chat_model_with_vision(setup_google_mock, setup_mock_redis):
     model = GoogleLargeLanguageModel()
 
     result = model.invoke(
@@ -124,7 +124,7 @@ def test_invoke_chat_model_with_vision(setup_google_mock):
 
 
 @pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
-def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
+def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock, setup_mock_redis):
     model = GoogleLargeLanguageModel()
 
     result = model.invoke(
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index 9a24d35a1fcdae..024f9129c80fb7 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -326,6 +326,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                 tenant_id="test",
                 type=FileType.IMAGE,
                 filename="test1.jpg",
+                extension=".jpg",
                 transfer_method=FileTransferMethod.REMOTE_URL,
                 remote_url=fake_remote_url,
             )
@@ -395,6 +396,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                 tenant_id="test",
                 type=FileType.IMAGE,
                 filename="test1.jpg",
+                extension=".jpg",
                 transfer_method=FileTransferMethod.REMOTE_URL,
                 remote_url=fake_remote_url,
            )
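The test updates above follow from the `file_manager.py` change at the top of the patch: image content now derives its `format` from the file extension, so the fixtures must supply one. A minimal illustration of the derivation, mirroring `f.extension.lstrip(".")`:

    def extension_to_format(extension):
        if extension is None:
            raise ValueError("Missing file extension")
        return extension.lstrip(".")

    assert extension_to_format(".jpg") == "jpg"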