Skip to content

Commit

Permalink
Merge pull request #7 from langgenius/feat/agent-node
Browse files Browse the repository at this point in the history
feat: add agent strategy support
  • Loading branch information
Nov1c444 authored Jan 7, 2025
2 parents d0f12a5 + 92558f5 commit 43e58b9
Show file tree
Hide file tree
Showing 12 changed files with 676 additions and 75 deletions.
2 changes: 1 addition & 1 deletion python/dify_plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
monkey.patch_all(sys=True)

from dify_plugin.config.config import DifyPluginEnv
from dify_plugin.interfaces.agent import AgentProvider, AgentStrategy
from dify_plugin.interfaces.endpoint import Endpoint
from dify_plugin.interfaces.model import ModelProvider
from dify_plugin.interfaces.model.large_language_model import LargeLanguageModel
Expand All @@ -19,7 +20,6 @@
from dify_plugin.interfaces.model.text_embedding_model import TextEmbeddingModel
from dify_plugin.interfaces.model.tts_model import TTSModel
from dify_plugin.interfaces.tool import Tool, ToolProvider
from dify_plugin.interfaces.agent import AgentProvider, AgentStrategy
from dify_plugin.invocations.file import File
from dify_plugin.plugin import Plugin

Expand Down
4 changes: 2 additions & 2 deletions python/dify_plugin/core/entities/plugin/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class PluginInvokeType(Enum):
Tool = "tool"
Model = "model"
Endpoint = "endpoint"
Agent = "agent"
Agent = "agent_strategy"


class AgentActions(Enum):
Expand Down Expand Up @@ -70,7 +70,7 @@ class AgentInvokeRequest(PluginAccessRequest):
action: AgentActions = AgentActions.InvokeAgentStrategy
agent_strategy_provider: str
agent_strategy: str
agent_strategy_parameters: dict[str, Any]
agent_strategy_params: dict[str, Any]


class ToolValidateCredentialsRequest(PluginAccessRequest):
Expand Down
26 changes: 13 additions & 13 deletions python/dify_plugin/core/plugin_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,6 @@
from werkzeug import Response

from dify_plugin.config.config import DifyPluginEnv
from dify_plugin.core.plugin_registration import PluginRegistration
from dify_plugin.core.runtime import Session
from dify_plugin.core.utils.http_parser import parse_raw_request
from dify_plugin.entities.tool import ToolRuntime
from dify_plugin.interfaces.endpoint import Endpoint
from dify_plugin.interfaces.model.ai_model import AIModel
from dify_plugin.interfaces.model.large_language_model import LargeLanguageModel
from dify_plugin.interfaces.model.moderation_model import ModerationModel
from dify_plugin.interfaces.model.rerank_model import RerankModel
from dify_plugin.interfaces.model.speech2text_model import Speech2TextModel
from dify_plugin.interfaces.model.text_embedding_model import TextEmbeddingModel
from dify_plugin.interfaces.model.tts_model import TTSModel
from dify_plugin.core.entities.plugin.request import (
AgentInvokeRequest,
EndpointInvokeRequest,
Expand All @@ -36,6 +24,18 @@
ToolInvokeRequest,
ToolValidateCredentialsRequest,
)
from dify_plugin.core.plugin_registration import PluginRegistration
from dify_plugin.core.runtime import Session
from dify_plugin.core.utils.http_parser import parse_raw_request
from dify_plugin.entities.tool import ToolRuntime
from dify_plugin.interfaces.endpoint import Endpoint
from dify_plugin.interfaces.model.ai_model import AIModel
from dify_plugin.interfaces.model.large_language_model import LargeLanguageModel
from dify_plugin.interfaces.model.moderation_model import ModerationModel
from dify_plugin.interfaces.model.rerank_model import RerankModel
from dify_plugin.interfaces.model.speech2text_model import Speech2TextModel
from dify_plugin.interfaces.model.text_embedding_model import TextEmbeddingModel
from dify_plugin.interfaces.model.tts_model import TTSModel


class PluginExecutor:
Expand Down Expand Up @@ -83,7 +83,7 @@ def invoke_agent_strategy(self, session: Session, request: AgentInvokeRequest):
)

agent = agent_cls(session=session)
yield from agent.invoke(request.agent_strategy_parameters)
yield from agent.invoke(request.agent_strategy_params)

def get_tool_runtime_parameters(self, session: Session, data: ToolGetRuntimeParametersRequest):
tool_cls = self.registration.get_tool_cls(data.provider, data.tool)
Expand Down
3 changes: 1 addition & 2 deletions python/dify_plugin/core/plugin_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from dify_plugin.core.entities.plugin.setup import PluginAsset, PluginConfiguration
from dify_plugin.core.utils.class_loader import load_multi_subclasses_from_source, load_single_subclass_from_source
from dify_plugin.core.utils.yaml_loader import load_yaml_file
from dify_plugin.entities.agent import AgentStrategyProviderConfiguration, AgentStrategyConfiguration
from dify_plugin.entities.agent import AgentStrategyConfiguration, AgentStrategyProviderConfiguration
from dify_plugin.entities.endpoint import EndpointProviderConfiguration
from dify_plugin.entities.model import ModelType
from dify_plugin.entities.model.provider import ModelProviderConfiguration
Expand All @@ -29,7 +29,6 @@
from dify_plugin.interfaces.model.tts_model import TTSModel
from dify_plugin.interfaces.tool import Tool, ToolProvider


T = TypeVar("T")


Expand Down
4 changes: 3 additions & 1 deletion python/dify_plugin/entities/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Mapping, Optional, Union
from typing import Any, Optional, Union

from pydantic import BaseModel, Field, field_validator

from dify_plugin.core.utils.yaml_loader import load_yaml_file
Expand Down
5 changes: 3 additions & 2 deletions python/dify_plugin/entities/tool.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import base64
import contextlib
from enum import Enum, StrEnum
from typing import Any, Mapping, Optional, Union
import uuid
from collections.abc import Mapping
from enum import Enum, StrEnum
from typing import Any, Optional, Union

from pydantic import (
BaseModel,
Expand Down
230 changes: 226 additions & 4 deletions python/dify_plugin/interfaces/agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,87 @@
from abc import abstractmethod
from typing import Generator
from collections.abc import Generator
from typing import Optional, Union

from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator

from dify_plugin.core.runtime import Session
from dify_plugin.entities.agent import AgentInvokeMessage
from dify_plugin.entities.model import AIModelEntity, ModelPropertyKey
from dify_plugin.entities.model.llm import LLMModelConfig, LLMUsage
from dify_plugin.entities.model.message import PromptMessage, PromptMessageTool
from dify_plugin.entities.tool import ToolDescription, ToolIdentity, ToolParameter
from dify_plugin.interfaces.tool import ToolLike, ToolProvider


class AgentToolIdentity(ToolIdentity):
    """Tool identity extended with the provider that supplies the tool."""

    provider: str = Field(..., description="The provider of the tool")


class AgentModelConfig(LLMModelConfig):
    """LLM model config that also carries the resolved model entity.

    The entity exposes model metadata (e.g. parameter rules, properties)
    alongside the base invocation config.
    """

    entity: AIModelEntity


class AgentScratchpadUnit(BaseModel):
    """One round of an agent's reasoning scratchpad.

    Captures the model's free-form response, its thought, the raw action
    string, the parsed action, and the resulting observation.
    """

    class Action(BaseModel):
        """A parsed action (tool call or final answer) emitted by the model."""

        action_name: str
        action_input: Union[dict, str]

        def to_dict(self) -> dict:
            """Serialize as ``{"action": ..., "action_input": ...}``."""
            return dict(action=self.action_name, action_input=self.action_input)

    agent_response: Optional[str] = ""
    thought: Optional[str] = ""
    action_str: Optional[str] = ""
    observation: Optional[str] = ""
    action: Optional[Action] = None

    def is_final(self) -> bool:
        """Return True when this unit terminates the agent loop.

        A unit is final when no action was parsed, or when the action name
        contains both "final" and "answer" (case-insensitive).
        """
        if self.action is None:
            return True
        lowered = self.action.action_name.lower()
        return "final" in lowered and "answer" in lowered


class ToolEntity(BaseModel):
    """Serializable description of a tool made available to an agent strategy."""

    identity: AgentToolIdentity
    parameters: list[ToolParameter] = Field(default_factory=list)
    description: Optional[ToolDescription] = None
    output_schema: Optional[dict] = None
    has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")

    # pydantic configs
    # protected_namespaces=() disables pydantic's reserved "model_" prefix check.
    model_config = ConfigDict(protected_namespaces=())

    @field_validator("parameters", mode="before")
    @classmethod
    def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
        # Normalize a missing/None parameters payload to an empty list.
        return v or []


class AgentProvider(ToolProvider):
    """Provider for agent strategies.

    Agent strategy providers carry no credentials of their own, so both
    validation hooks are intentional no-ops that always permit the agent
    to run.
    """

    # NOTE: the scraped diff interleaved the pre- and post-change lines of
    # this class; this is the post-change version (credentials parameter
    # added, body reduced to a no-op).
    def validate_credentials(self, credentials: dict):
        """
        Always permit the agent to run
        """
        pass

    def _validate_credentials(self, credentials: dict):
        pass


Expand Down Expand Up @@ -40,3 +109,156 @@ def invoke(self, parameters: dict) -> Generator[AgentInvokeMessage, None, None]:
# convert parameters into correct types
parameters = self._convert_parameters(parameters)
return self._invoke(parameters)

def increase_usage(self, final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
    """Accumulate *usage* into ``final_llm_usage_dict["usage"]`` in place.

    The first call stores the usage object itself; later calls add the
    token counts and prices onto the stored object.
    """
    accumulated = final_llm_usage_dict["usage"]
    if not accumulated:
        final_llm_usage_dict["usage"] = usage
        return
    # Fold every counter/price field of the incoming usage into the total.
    for field_name in (
        "prompt_tokens",
        "completion_tokens",
        "prompt_price",
        "completion_price",
        "total_price",
    ):
        setattr(accumulated, field_name, getattr(accumulated, field_name) + getattr(usage, field_name))

def recalc_llm_max_tokens(
    self, model_entity: AIModelEntity, prompt_messages: list[PromptMessage], parameters: dict
):
    """Clamp ``max_tokens`` so prompt + completion fit the model context window.

    Mutates ``parameters`` in place. Returns -1 when the model declares no
    context size (nothing can be recalculated); otherwise returns None.
    """
    model_context_tokens = model_entity.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
    if model_context_tokens is None:
        # No declared context window -> nothing to clamp.
        return -1

    # Find the requested max_tokens: a rule may expose it under its own name
    # or via the "max_tokens" template. The last matching rule wins, matching
    # the write-back loop below. Missing/None values coerce to 0, so the
    # original dead `if max_tokens is None` branch has been removed.
    max_tokens = 0
    for parameter_rule in model_entity.parameter_rules:
        if parameter_rule.name == "max_tokens" or (
            parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
        ):
            max_tokens = (
                parameters.get(parameter_rule.name) or parameters.get(parameter_rule.use_template or "")
            ) or 0

    prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages)

    if prompt_tokens + max_tokens > model_context_tokens:
        # Shrink the completion budget to the remaining space, but never
        # below 16 tokens so the model can still produce a short answer.
        max_tokens = max(model_context_tokens - prompt_tokens, 16)

    # Write the (possibly clamped) value back unconditionally, as the
    # original did, so the parameter is always present after this call.
    for parameter_rule in model_entity.parameter_rules:
        if parameter_rule.name == "max_tokens" or (
            parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
        ):
            parameters[parameter_rule.name] = max_tokens

def _get_num_tokens_by_gpt2(self, prompt_messges: list[PromptMessage]) -> int:
    """
    Get number of tokens for given prompt messages by gpt2.

    Some provider models do not provide an interface for obtaining the number
    of tokens. Here, the gpt2 tokenizer is used to calculate the number of
    tokens. This method can be executed offline, and the gpt2 tokenizer has
    been cached in the project.

    :param prompt_messges: prompt messages whose plain-text contents are
        counted; non-string contents (e.g. multimodal parts) are skipped
    :return: number of tokens
    """
    # Imported lazily so tiktoken is only required when token counting is used.
    import tiktoken

    # Join string contents with spaces; the separator itself adds a token
    # boundary, so this is an approximation, not an exact per-message count.
    text = " ".join([prompt.content for prompt in prompt_messges if isinstance(prompt.content, str)])
    return len(tiktoken.encoding_for_model("gpt2").encode(text))

def _init_prompt_tools(self, tools: list[ToolEntity] | None) -> list[PromptMessageTool]:
"""
Init tools
"""

prompt_messages_tools = []
for tool in tools or []:
try:
prompt_tool = self._convert_tool_to_prompt_message_tool(tool)
except Exception:
# api tool may be deleted
continue

# save prompt tool
prompt_messages_tools.append(prompt_tool)

return prompt_messages_tools

def _convert_tool_to_prompt_message_tool(self, tool: ToolEntity) -> PromptMessageTool:
    """
    Convert a tool entity into a prompt message tool with a JSON-schema
    ``parameters`` object built from the tool's LLM-facing parameters.
    """
    message_tool = PromptMessageTool(
        name=tool.identity.name,
        description=tool.description.llm if tool.description else "",
        parameters={
            "type": "object",
            "properties": {},
            "required": [],
        },
    )

    parameters = tool.parameters
    for parameter in parameters:
        # Only parameters the LLM is supposed to fill enter the schema.
        if parameter.form != ToolParameter.ToolParameterForm.LLM:
            continue

        parameter_type = parameter.type
        # File-typed parameters cannot be expressed in this schema; skip them.
        if parameter.type in {
            ToolParameter.ToolParameterType.FILE,
            ToolParameter.ToolParameterType.FILES,
        }:
            continue
        enum = []
        if parameter.type == ToolParameter.ToolParameterType.SELECT:
            # SELECT parameters become a JSON-schema enum of option values.
            enum = [option.value for option in parameter.options] if parameter.options else []

        message_tool.parameters["properties"][parameter.name] = {
            "type": parameter_type,
            "description": parameter.llm_description or "",
        }

        if len(enum) > 0:
            message_tool.parameters["properties"][parameter.name]["enum"] = enum

        if parameter.required:
            message_tool.parameters["required"].append(parameter.name)

    return message_tool

def update_prompt_message_tool(self, tool: ToolEntity, prompt_tool: PromptMessageTool) -> PromptMessageTool:
    """Refresh *prompt_tool*'s JSON-schema parameters from *tool*'s current
    (runtime) parameters, returning the updated prompt tool."""
    # try to get tool runtime parameters
    for param in tool.parameters:
        # Only parameters the LLM is supposed to fill enter the schema.
        if param.form != ToolParameter.ToolParameterForm.LLM:
            continue

        # File-typed parameters cannot be expressed in this schema; skip them.
        if param.type in {
            ToolParameter.ToolParameterType.FILE,
            ToolParameter.ToolParameterType.FILES,
        }:
            continue

        options = []
        if param.type == ToolParameter.ToolParameterType.SELECT and param.options:
            options = [opt.value for opt in param.options]

        schema_entry = {
            "type": param.type,
            "description": param.llm_description or "",
        }
        if options:
            schema_entry["enum"] = options
        prompt_tool.parameters["properties"][param.name] = schema_entry

        if param.required and param.name not in prompt_tool.parameters["required"]:
            prompt_tool.parameters["required"].append(param.name)

    return prompt_tool
9 changes: 4 additions & 5 deletions python/dify_plugin/interfaces/tool/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Any, Generic, Mapping, Optional, Type, TypeVar

from dify_plugin.entities.agent import AgentInvokeMessage
from dify_plugin.file.entities import FileType
from collections.abc import Generator, Mapping
from typing import Any, Generic, Optional, Type, TypeVar

from dify_plugin.core.runtime import Session
from dify_plugin.entities.agent import AgentInvokeMessage
from dify_plugin.entities.tool import ToolInvokeMessage, ToolParameter, ToolRuntime, ToolSelector
from dify_plugin.file.constants import DIFY_FILE_IDENTITY, DIFY_TOOL_SELECTOR_IDENTITY
from dify_plugin.file.entities import FileType
from dify_plugin.file.file import File

T = TypeVar("T", bound=ToolInvokeMessage | AgentInvokeMessage)
Expand Down
2 changes: 1 addition & 1 deletion python/dify_plugin/invocations/model/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def invoke(
"prompt_messages": [message.model_dump() for message in prompt_messages],
"tools": [tool.model_dump() for tool in tools] if tools else None,
"stop": stop,
"stream": True,
"stream": stream,
}

if stream:
Expand Down
Loading

0 comments on commit 43e58b9

Please sign in to comment.