fix: reduce memory usage from 800+ to 500+ #11796

Merged · 2 commits · Dec 20, 2024
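All four files below apply the same technique: imports of heavy packages (vertexai, google.cloud, unstructured, nltk, jieba) move from module level into the functions that use them, so importing the module itself no longer loads those dependencies. Type hints survive through an `if TYPE_CHECKING:` import plus quoted annotations, which type checkers evaluate but the runtime never does. A minimal sketch of the pattern, assuming vertexai is installed (illustration only, not code from this PR):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers; never executed at runtime.
    import vertexai.generative_models as glm


def default_config(temperature: float = 0.2) -> "glm.GenerationConfig":
    # Deferred import: vertexai is loaded on the first call, not at import time.
    import vertexai.generative_models as glm

    return glm.GenerationConfig(temperature=temperature)
```

Nothing from vertexai is loaded until `default_config` first runs; after that, Python's module cache makes the repeated import statement cheap.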
26 changes: 18 additions & 8 deletions api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
@@ -4,11 +4,10 @@
 import logging
 import time
 from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import TYPE_CHECKING, Optional, Union, cast
 
 import google.auth.transport.requests
 import requests
-import vertexai.generative_models as glm
 from anthropic import AnthropicVertex, Stream
 from anthropic.types import (
     ContentBlockDeltaEvent,
@@ -19,8 +18,6 @@
     MessageStreamEvent,
 )
 from google.api_core import exceptions
-from google.cloud import aiplatform
-from google.oauth2 import service_account
 from PIL import Image
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -47,6 +44,9 @@
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 
+if TYPE_CHECKING:
+    import vertexai.generative_models as glm
+
 logger = logging.getLogger(__name__)
 
 
@@ -102,6 +102,8 @@ def _generate_anthropic(
         :param stream: is stream response
         :return: full response or stream response chunk generator result
         """
+        from google.oauth2 import service_account
+
         # use Anthropic official SDK references
         # - https://github.com/anthropics/anthropic-sdk-python
         service_account_key = credentials.get("vertex_service_account_key", "")
@@ -406,13 +408,15 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
 
         return text.rstrip()
 
-    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
+    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> "glm.Tool":
         """
         Convert tool messages to glm tools
 
         :param tools: tool messages
         :return: glm tools
         """
+        import vertexai.generative_models as glm
+
         return glm.Tool(
             function_declarations=[
                 glm.FunctionDeclaration(
@@ -473,6 +477,10 @@ def _generate(
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
+        import vertexai.generative_models as glm
+        from google.cloud import aiplatform
+        from google.oauth2 import service_account
+
         config_kwargs = model_parameters.copy()
         config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
 
@@ -522,7 +530,7 @@ def _generate(
             return self._handle_generate_response(model, credentials, response, prompt_messages)
 
     def _handle_generate_response(
-        self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
+        self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
     ) -> LLMResult:
         """
         Handle llm response
@@ -554,7 +562,7 @@
         return result
 
     def _handle_generate_stream_response(
-        self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
+        self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
     ) -> Generator:
         """
         Handle llm stream response
@@ -638,13 +646,15 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str:
 
         return message_text
 
-    def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
+    def _format_message_to_glm_content(self, message: PromptMessage) -> "glm.Content":
        """
        Format a single message into glm.Content for Google API
 
        :param message: one PromptMessage
        :return: glm Content representation of message
        """
+        import vertexai.generative_models as glm
+
        if isinstance(message, UserPromptMessage):
            glm_content = glm.Content(role="user", parts=[])
 
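Note how the return annotations above become string literals (`"glm.Tool"`, `"glm.GenerationResponse"`, `"glm.Content"`): quoting stops the runtime from evaluating a name that only exists under `TYPE_CHECKING`. An alternative this PR does not use, sketched here for comparison, is `from __future__ import annotations` (PEP 563), which defers every annotation in the module and makes the quotes unnecessary:

```python
from __future__ import annotations  # PEP 563: annotations become lazy strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import vertexai.generative_models as glm


def make_config(temperature: float) -> glm.GenerationConfig:  # unquoted, never evaluated
    import vertexai.generative_models as glm

    return glm.GenerationConfig(temperature=temperature)
```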
api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py
@@ -2,12 +2,9 @@
 import json
 import time
 from decimal import Decimal
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 import tiktoken
-from google.cloud import aiplatform
-from google.oauth2 import service_account
-from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
 
 from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
@@ -24,6 +21,11 @@
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi
 
+if TYPE_CHECKING:
+    from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+else:
+    VertexTextEmbeddingModel = None
+
 
 class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
     """
@@ -48,6 +50,10 @@ def _invoke(
         :param input_type: input type
         :return: embeddings result
         """
+        from google.cloud import aiplatform
+        from google.oauth2 import service_account
+        from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+
         service_account_key = credentials.get("vertex_service_account_key", "")
         project_id = credentials["vertex_project_id"]
         location = credentials["vertex_location"]
@@ -100,6 +106,10 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
         :param credentials: model credentials
         :return:
         """
+        from google.cloud import aiplatform
+        from google.oauth2 import service_account
+        from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+
         try:
             service_account_key = credentials.get("vertex_service_account_key", "")
             project_id = credentials["vertex_project_id"]
api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
@@ -1,18 +1,19 @@
 import re
 from typing import Optional
 
-import jieba
-from jieba.analyse import default_tfidf
-
-from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
-
 
 class JiebaKeywordTableHandler:
     def __init__(self):
-        default_tfidf.stop_words = STOPWORDS
+        import jieba.analyse
+
+        from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
+
+        jieba.analyse.default_tfidf.stop_words = STOPWORDS
 
     def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
         """Extract keywords with JIEBA tfidf."""
+        import jieba
+
         keywords = jieba.analyse.extract_tags(
             sentence=text,
             topK=max_keywords_per_chunk,
@@ -22,6 +23,8 @@ def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10
 
     def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
         """Get subtokens from a list of tokens., filtering for stopwords."""
+        from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
+
         results = set()
         for token in tokens:
             results.add(token)
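`extract_keywords` now runs an `import jieba` on every call, which sounds expensive but is not: after the first call the module is cached in `sys.modules`, and subsequent `import` statements are little more than a dict lookup. A rough demonstration with a stdlib module as a stand-in (timings illustrative, not from this PR):

```python
import sys
import time


def timed_import() -> None:
    t0 = time.perf_counter()
    import csv  # first time: the module's code actually executes
    t1 = time.perf_counter()
    import csv  # noqa: F811 -- cached: just a sys.modules lookup
    t2 = time.perf_counter()
    assert "csv" in sys.modules
    print(f"first import: {(t1 - t0) * 1e3:.3f} ms, cached: {(t2 - t1) * 1e3:.3f} ms")


timed_import()
```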
6 changes: 4 additions & 2 deletions api/core/rag/datasource/vdb/oracle/oraclevector.py
@@ -6,10 +6,8 @@
 from typing import Any
 
 import jieba.posseg as pseg
-import nltk
 import numpy
 import oracledb
-from nltk.corpus import stopwords
 from pydantic import BaseModel, model_validator
 
 from configs import dify_config
@@ -202,6 +200,10 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
         return docs
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        # lazy import
+        import nltk
+        from nltk.corpus import stopwords
+
         top_k = kwargs.get("top_k", 5)
         # just not implement fetch by score_threshold now, may be later
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
17 changes: 11 additions & 6 deletions api/core/workflow/nodes/document_extractor/node.py
@@ -8,12 +8,6 @@
 import pandas as pd
 import pypdfium2  # type: ignore
 import yaml  # type: ignore
-from unstructured.partition.api import partition_via_api
-from unstructured.partition.email import partition_email
-from unstructured.partition.epub import partition_epub
-from unstructured.partition.msg import partition_msg
-from unstructured.partition.ppt import partition_ppt
-from unstructured.partition.pptx import partition_pptx
 
 from configs import dify_config
 from core.file import File, FileTransferMethod, file_manager
@@ -256,6 +250,8 @@ def _extract_text_from_excel(file_content: bytes) -> str:
 
 
 def _extract_text_from_ppt(file_content: bytes) -> str:
+    from unstructured.partition.ppt import partition_ppt
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_ppt(file=file)
@@ -265,6 +261,9 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
 
 
 def _extract_text_from_pptx(file_content: bytes) -> str:
+    from unstructured.partition.api import partition_via_api
+    from unstructured.partition.pptx import partition_pptx
+
     try:
         if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
             with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
@@ -287,6 +286,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str:
 
 
 def _extract_text_from_epub(file_content: bytes) -> str:
+    from unstructured.partition.epub import partition_epub
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_epub(file=file)
@@ -296,6 +297,8 @@ def _extract_text_from_epub(file_content: bytes) -> str:
 
 
 def _extract_text_from_eml(file_content: bytes) -> str:
+    from unstructured.partition.email import partition_email
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_email(file=file)
@@ -305,6 +308,8 @@ def _extract_text_from_eml(file_content: bytes) -> str:
 
 
 def _extract_text_from_msg(file_content: bytes) -> str:
+    from unstructured.partition.msg import partition_msg
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_msg(file=file)
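The title's "800+ to 500+" presumably refers to resident memory in MB after API startup; the page does not state the units. A hedged sketch of how one could verify that kind of claim with psutil (hypothetical measurement code, not part of the PR):

```python
import os

import psutil  # third-party: pip install psutil


def rss_mb() -> float:
    """Resident set size of the current process, in MB."""
    return psutil.Process(os.getpid()).memory_info().rss / 1024**2


before = rss_mb()
import unstructured.partition.pptx  # noqa: E402 -- one of the deferred heavy deps
after = rss_mb()
print(f"RSS before: {before:.0f} MB, after import: {after:.0f} MB")
```

Comparing the process RSS at startup before and after this PR, with no requests served, is the simplest way to attribute the savings to the deferred imports.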