Skip to content

Commit

Permalink
Merge pull request #121 from alipay/dev_weizj
Browse files Browse the repository at this point in the history
add sql tool
  • Loading branch information
LandJerry authored Jul 11, 2024
2 parents d19723e + 3ec6e25 commit 772909a
Show file tree
Hide file tree
Showing 14 changed files with 171 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def invoke(self, agent_model: AgentModel, planner_input: dict, input_object: Inp
history_messages_key="chat_history",
input_messages_key=self.input_key,
) | StrOutputParser()
res = asyncio.run(
chain_with_history.ainvoke(input=planner_input, config={"configurable": {"session_id": "unused"}}))

res = self.invoke_chain(agent_model, chain_with_history, planner_input, chat_history, input_object)
return {**planner_input, self.output_key: res, 'chat_history': generate_memories(chat_history)}

def handle_prompt(self, agent_model: AgentModel, planner_input: dict) -> Prompt:
Expand Down
20 changes: 15 additions & 5 deletions agentuniverse/agent/plan/planner/peer_planner/peer_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def invoke(self, agent_model: AgentModel, planner_input: dict, input_object: Inp
"""
planner_config = agent_model.plan.get('planner')
sub_agents = self.generate_sub_agents(planner_config)
return self.agents_run(sub_agents, planner_config, planner_input, input_object)
return self.agents_run(agent_model, sub_agents, planner_config, planner_input, input_object)

@staticmethod
def generate_sub_agents(planner_config: dict) -> dict:
Expand Down Expand Up @@ -79,7 +79,8 @@ def build_expert_framework(planner_config: dict, input_object: InputObject):
elif context:
input_object.add_data('expert_framework', context)

def agents_run(self, agents: dict, planner_config: dict, agent_input: dict, input_object: InputObject) -> dict:
def agents_run(self, agent_mode: AgentModel, agents: dict, planner_config: dict, agent_input: dict,
input_object: InputObject) -> dict:
"""Planner agents run.
Args:
Expand Down Expand Up @@ -125,7 +126,10 @@ def agents_run(self, agents: dict, planner_config: dict, agent_input: dict, inpu
for index, one_framework in enumerate(planning_result.get_data('framework')):
logger_info += f"[{index + 1}] {one_framework} \n"
LOGGER.info(logger_info)
self.stream_output(input_object, {"data": planning_result.to_dict(), "type": "planning"})
self.stream_output(input_object, {"data": {
'output': planning_result.to_dict(),
"agent_info": agent_mode.info
}, "type": "planning"})

if not executing_result or jump_step in ["planning", "executing"]:
if not executingAgent:
Expand All @@ -144,7 +148,10 @@ def agents_run(self, agents: dict, planner_config: dict, agent_input: dict, inpu
one_exec_log_info += f"[{index + 1}] output: {one_exec_res['output']}\n"
logger_info += one_exec_log_info
LOGGER.info(logger_info)
self.stream_output(input_object, {"data": executing_result.to_dict(), "type": "executing"})
self.stream_output(input_object, {"data": {
'output': executing_result.to_dict(),
"agent_info": agent_mode.info
}, "type": "executing"})

if not expressing_result or jump_step in ["planning", "executing", "expressing"]:
if not expressingAgent:
Expand All @@ -159,7 +166,10 @@ def agents_run(self, agents: dict, planner_config: dict, agent_input: dict, inpu
logger_info = f"\nExpressing agent execution result is :\n"
logger_info += f"{expressing_result.get_data('output')}"
LOGGER.info(logger_info)
self.stream_output(input_object, {"data": executing_result.to_dict(), "type": "expressing"})
self.stream_output(input_object, {"data": {
'output': expressing_result.get_data('output'),
"agent_info": agent_mode.info
}, "type": "expressing"})

if not reviewing_result or jump_step in ["planning", "executing", "expressing", "reviewing"]:
if not reviewingAgent:
Expand Down
40 changes: 36 additions & 4 deletions agentuniverse/agent/plan/planner/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@
# @FileName: planner.py
"""Base class for Planner."""
from abc import abstractmethod
import copy
import logging
from queue import Queue
from typing import Optional, List
from typing import Optional, List, Any

from langchain_core.runnables import RunnableSerializable

from agentuniverse.agent.action.knowledge.knowledge import Knowledge
from agentuniverse.agent.action.knowledge.knowledge_manager import KnowledgeManager
from agentuniverse.agent.action.knowledge.store.document import Document
from agentuniverse.agent.action.knowledge.store.query import Query
from agentuniverse.agent.action.tool.tool_manager import ToolManager
from agentuniverse.agent.agent_manager import AgentManager
from agentuniverse.agent.agent_model import AgentModel
from agentuniverse.agent.input_object import InputObject
from agentuniverse.agent.memory.chat_memory import ChatMemory
Expand All @@ -28,7 +30,7 @@
from agentuniverse.llm.llm import LLM
from agentuniverse.llm.llm_manager import LLMManager
from agentuniverse.prompt.prompt import Prompt
from agentuniverse.base.util.memory_util import generate_messages
from agentuniverse.base.util.memory_util import generate_messages, generate_memories

logging.getLogger().setLevel(logging.ERROR)

Expand Down Expand Up @@ -101,6 +103,7 @@ def run_all_actions(self, agent_model: AgentModel, planner_input: dict, input_ob
action: dict = agent_model.action or dict()
tools: list = action.get('tool') or list()
knowledge: list = action.get('knowledge') or list()
agents: list = action.get('agent') or list()

action_result: list = list()

Expand All @@ -120,6 +123,16 @@ def run_all_actions(self, agent_model: AgentModel, planner_input: dict, input_ob
for document in knowledge_res:
action_result.append(document.text)

for agent_name in agents:
agent = AgentManager().get_instance_obj(agent_name)
if agent is None:
continue
agent_input = {key: input_object.get_data(key) for key in agent.input_keys()}
output_object = agent.run(**agent_input)
action_result.append("\n".join([output_object.get_data(key)
for key in agent.output_keys()
if output_object.get_data(key) is not None]))

planner_input['background'] = planner_input['background'] or '' + "\n".join(action_result)

def handle_prompt(self, agent_model: AgentModel, planner_input: dict):
Expand Down Expand Up @@ -167,7 +180,26 @@ def stream_output(input_object: InputObject, data: dict):
input_object (InputObject): Agent input object.
data (dict): The data to be streamed.
"""
output_stream:Queue = input_object.get_data('output_stream', None)
output_stream: Queue = input_object.get_data('output_stream', None)
if output_stream is None:
return
output_stream.put_nowait(data)

def invoke_chain(self, agent_model: AgentModel, chain: RunnableSerializable[Any, str], planner_input: dict, chat_history,
                 input_object: InputObject):
    """Run the planner's LLM chain, streaming tokens when an output queue is attached.

    When the input object carries no 'output_stream', the chain is invoked
    synchronously and its full string result returned. Otherwise each streamed
    chunk is pushed to the output queue (tagged with the agent's info) and the
    chunks are joined into the final result.

    Args:
        agent_model (AgentModel): Agent model whose `info` tags streamed tokens.
        chain (RunnableSerializable[Any, str]): The runnable chain to execute.
        planner_input (dict): Input payload passed to the chain.
        chat_history: Chat history (unused here; kept for interface compatibility).
        input_object (InputObject): Agent input object, possibly holding 'output_stream'.

    Returns:
        str: The chain's complete output.
    """
    run_config = {"configurable": {"session_id": "unused"}}

    # No stream requested: plain synchronous invocation.
    if not input_object.get_data('output_stream'):
        return chain.invoke(input=planner_input, config=run_config)

    # Stream mode: forward every chunk to the queue, then assemble the result.
    pieces = []
    for chunk in chain.stream(input=planner_input, config=run_config):
        self.stream_output(input_object, {
            'type': 'token',
            'data': {
                'chunk': chunk,
                'agent_info': agent_model.info
            }
        })
        pieces.append(chunk)
    return "".join(pieces)

Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ def invoke(self, agent_model: AgentModel, planner_input: dict,
history_messages_key="chat_history",
input_messages_key=self.input_key,
) | StrOutputParser()
res = asyncio.run(
chain_with_history.ainvoke(input=planner_input, config={"configurable": {"session_id": "unused"}}))
res = self.invoke_chain(agent_model, chain_with_history, planner_input, chat_history, input_object)
return {**planner_input, self.output_key: res, 'chat_history': generate_memories(chat_history)}

def handle_prompt(self, agent_model: AgentModel, planner_input: dict) -> Prompt:
Expand Down
24 changes: 2 additions & 22 deletions agentuniverse/agent/plan/planner/rag_planner/rag_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
# @Email : lc299034@antgroup.com
# @FileName: rag_planner.py
"""Rag planner module."""
from typing import Any

from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableSerializable
from langchain_core.runnables.history import RunnableWithMessageHistory

from agentuniverse.agent.agent_model import AgentModel
Expand Down Expand Up @@ -56,26 +54,8 @@ def invoke(self, agent_model: AgentModel, planner_input: dict,
history_messages_key="chat_history",
input_messages_key=self.input_key,
) | StrOutputParser()

if not input_object.get_data('output_stream'):
res = chain_with_history.invoke(input=planner_input, config={"configurable": {"session_id": "unused"}})
return {**planner_input, self.output_key: res, 'chat_history': generate_memories(chat_history)}
else:
return self.stream(agent_model, chain_with_history, planner_input, chat_history, input_object)

def stream(self, agent_model: AgentModel, chain: RunnableSerializable[Any, str], planner_input: dict, chat_history,
input_object: InputObject):
result = []
for token in chain.stream(input=planner_input, config={"configurable": {"session_id": "unused"}}):
self.stream_output(input_object, {
'type': 'token',
'data': {
'token': token,
'agent_info': agent_model.info
}
})
result.append(token)
return {**planner_input, self.output_key: ''.join(result), 'chat_history': generate_memories(chat_history)}
res = self.invoke_chain(agent_model, chain_with_history, planner_input, chat_history, input_object)
return {**planner_input, self.output_key: res, 'chat_history': generate_memories(chat_history)}

def handle_prompt(self, agent_model: AgentModel, planner_input: dict) -> ChatPrompt:
"""Prompt module processing.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ def invoke(self, agent_model: AgentModel, planner_input: dict,
max_iterations=agent_model.plan.get('planner').get("max_iterations", 15))

return agent_executor.invoke(input=planner_input, memory=memory.as_langchain() if memory else None,
chat_history=chat_history, config=self.get_run_config(input_object))
chat_history=chat_history, config=self.get_run_config(agent_model, input_object))

@staticmethod
def get_run_config(input_object: InputObject) -> RunnableConfig:
def get_run_config(agent_model: AgentModel, input_object: InputObject) -> RunnableConfig:
config = RunnableConfig()
callbacks = []
output_stream = input_object.get_data('output_stream')
callbacks.append(StreamOutPutCallbackHandler(output_stream))
callbacks.append(StreamOutPutCallbackHandler(output_stream, agent_info=agent_model.info))
config.setdefault("callbacks", callbacks)
return config

Expand Down
37 changes: 32 additions & 5 deletions agentuniverse/agent/plan/planner/react_planner/stream_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
class StreamOutPutCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""

def __init__(self, queue_stream: asyncio.Queue, color: Optional[str] = None) -> None:
def __init__(self, queue_stream: asyncio.Queue, color: Optional[str] = None, agent_info: dict = None) -> None:
    """Initialize callback handler.

    Args:
        queue_stream (asyncio.Queue): Queue that streamed events are pushed onto.
        color (Optional[str]): Optional display color (stored, not used here).
        agent_info (dict): Metadata identifying the agent; defaults to {}.
    """
    # `None` default instead of a mutable `{}` default; normalize to a dict here.
    self.agent_info = {} if agent_info is None else agent_info
    self.queueStream = queue_stream
    self.color = color

def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
Expand All @@ -32,7 +35,13 @@ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
def on_agent_action(
    self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
    """Push the agent's intermediate reasoning step onto the stream queue."""
    event = {
        "type": "ReAct",
        "data": {
            "output": "\nThought:" + action.log,
            "agent_info": self.agent_info,
        },
    }
    self.queueStream.put_nowait(event)

def on_tool_end(
self,
Expand All @@ -44,9 +53,21 @@ def on_tool_end(
) -> None:
"""If not the final action, print out observation."""
if observation_prefix is not None:
self.queueStream.put_nowait(observation_prefix + output)
self.queueStream.put_nowait({
"type": "ReAct",
"data": {
"output": '\n' + observation_prefix + output,
"agent_info": self.agent_info
}
})
else:
self.queueStream.put_nowait('Observation:'+output)
self.queueStream.put_nowait({
"type": "ReAct",
"data": {
"output": '\n Observation:' + output,
"agent_info": self.agent_info
}
})

def on_text(
self,
def on_agent_finish(
    self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
    """Run on agent end: stream the final reasoning text to the queue."""
    # BUG FIX: langchain's AgentFinish has no `output` attribute (its fields are
    # `return_values` and `log`), so `finish.output` raised AttributeError.
    # Use `finish.log`, as the pre-change handler did.
    self.queueStream.put_nowait({
        "type": "ReAct",
        "data": {
            "output": '\nThought:' + finish.log,
            "agent_info": self.agent_info
        }
    })
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ def invoke(self, agent_model: AgentModel, planner_input: dict, input_object: Inp
history_messages_key="chat_history",
input_messages_key=self.input_key,
) | StrOutputParser()
res = asyncio.run(
chain_with_history.ainvoke(input=planner_input, config={"configurable": {"session_id": "unused"}}))
res = self.invoke_chain(agent_model, chain_with_history, planner_input, chat_history, input_object)
return {**planner_input, self.output_key: res, 'chat_history': generate_memories(chat_history)}

def handle_prompt(self, agent_model: AgentModel, planner_input: dict) -> Prompt:
Expand Down
1 change: 0 additions & 1 deletion agentuniverse/llm/openai_style_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,4 +207,3 @@ def max_context_length(self) -> int:
"""Return the maximum length of the context."""
if super().max_context_length():
return super().max_context_length()
return 4000
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# agentUniverse tool registration for LangChain's InfoSQLDatabaseTool
# (returns schema/sample rows for the tables named in `input`).
name: 'info_sql_database_tool'
# Left empty on purpose: SqlLangchainTool copies the description from the
# wrapped LangChain tool when none is configured here.
description: ''
tool_type: 'api'
input_keys: ['input']
langchain:
  module: langchain_community.tools
  class_name: InfoSQLDatabaseTool
  init_params:
    # Name of the registered SQLDB wrapper instance to bind at first use.
    db_wrapper: demo_sqldb_wrapper
metadata:
  type: 'TOOL'
  module: 'sample_standard_app.app.core.tool.langchain_tool.sql_langchain_tool'
  class: 'SqlLangchainTool'
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def execute(self, tool_input: ToolInput):
def initialize_by_component_configer(self, component_configer: ToolConfiger) -> 'Tool':
super().initialize_by_component_configer(component_configer)
self.tool = self.init_langchain_tool(component_configer)
if not component_configer.description:
if not component_configer.description and self.tool is not None:
self.description = self.tool.description
return self

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# agentUniverse tool registration for LangChain's ListSQLDatabaseTool
# (lists the tables available in the bound SQL database).
name: 'list_sql_database_tool'
# Left empty on purpose: SqlLangchainTool copies the description from the
# wrapped LangChain tool when none is configured here.
description: ''
tool_type: 'api'
input_keys: ['input']
langchain:
  module: langchain_community.tools
  class_name: ListSQLDatabaseTool
  init_params:
    # Name of the registered SQLDB wrapper instance to bind at first use.
    db_wrapper: demo_sqldb_wrapper
metadata:
  type: 'TOOL'
  module: 'sample_standard_app.app.core.tool.langchain_tool.sql_langchain_tool'
  class: 'SqlLangchainTool'
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# agentUniverse tool registration for LangChain's QuerySQLDataBaseTool
# (executes the SQL query passed in `input` against the bound database).
name: 'query_sql_database_tool'
# Left empty on purpose: SqlLangchainTool copies the description from the
# wrapped LangChain tool when none is configured here.
description: ''
tool_type: 'api'
input_keys: ['input']
langchain:
  module: langchain_community.tools
  class_name: QuerySQLDataBaseTool
  init_params:
    # Name of the registered SQLDB wrapper instance to bind at first use.
    db_wrapper: demo_sqldb_wrapper
metadata:
  type: 'TOOL'
  module: 'sample_standard_app.app.core.tool.langchain_tool.sql_langchain_tool'
  class: 'SqlLangchainTool'
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# !/usr/bin/env python3
# -*- coding:utf-8 -*-

# @Time : 2024/7/9 20:04
# @Author : weizjajj
# @Email : weizhongjie.wzj@antgroup.com
# @FileName: sql_langchain_tool.py

from typing import Type, Optional

from langchain_core.tools import BaseTool, Tool as LangchainTool

from agentuniverse.agent.action.tool.tool import ToolInput
from agentuniverse.database.sqldb_wrapper_manager import SQLDBWrapperManager
from sample_standard_app.app.core.tool.langchain_tool.langchain_tool import LangChainTool


class SqlLangchainTool(LangChainTool):
    """LangChain SQL tool wrapper with lazy database binding.

    Instead of instantiating the LangChain tool at configuration time,
    this subclass records the SQLDB wrapper name and the tool class, and
    builds the concrete tool (bound to the wrapper's SQLDatabase) on first
    use — in `execute` or `as_langchain`.
    """

    # Name of the registered SQLDBWrapper instance to look up lazily.
    db_wrapper_name: Optional[str] = ""
    # LangChain tool class to instantiate once the database is available.
    clz: Type[BaseTool] = BaseTool

    def execute(self, tool_input: ToolInput):
        """Run the tool, building the underlying LangChain tool on first use."""
        if self.tool is None:
            self.get_sql_database()
        return super().execute(tool_input)

    def get_sql_database(self):
        """Instantiate the LangChain tool bound to the configured SQL database.

        Raises:
            ValueError: If no SQLDB wrapper is registered under
                `db_wrapper_name`.
        """
        db_wrapper = SQLDBWrapperManager().get_instance_obj(self.db_wrapper_name)
        if db_wrapper is None:
            # Fail fast with a clear message instead of an opaque
            # AttributeError on None when the wrapper name is wrong/missing.
            raise ValueError(
                f"SQLDB wrapper instance '{self.db_wrapper_name}' was not found; "
                "check the tool's `db_wrapper` init_params.")
        self.tool = self.clz(db=db_wrapper.sql_database)
        self.description = self.tool.description

    def as_langchain(self) -> LangchainTool:
        """Return the LangChain-native tool, building it lazily if needed."""
        if self.tool is None:
            self.get_sql_database()
        return super().as_langchain()

    def get_langchain_tool(self, init_params: dict, clz: Type[BaseTool]):
        """Defer tool creation: just remember the wrapper name and tool class."""
        self.db_wrapper_name = init_params.get("db_wrapper")
        self.clz = clz

0 comments on commit 772909a

Please sign in to comment.