Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 Fix functions #819

Merged
merged 13 commits into from
Jun 17, 2023
4 changes: 2 additions & 2 deletions next/src/components/AppTitle.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const AppTitle = () => {

return (
<div id="title" className="relative flex flex-col items-center font-mono">
<div className="flex flex-row items-start shadow-2xl">
<div className="flex flex-row items-start">
<span className="text-4xl font-bold text-[#C0C0C0] xs:text-5xl sm:text-6xl">Agent</span>
<span className="text-4xl font-bold text-white xs:text-5xl sm:text-6xl">GPT</span>
<PopIn delay={0.5}>
Expand All @@ -31,4 +31,4 @@ const AppTitle = () => {
);
};

export default AppTitle;
export default AppTitle;
3 changes: 3 additions & 0 deletions next/src/stores/agentStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ export const useAgentStore = createSelectors(
{
name: "agent-storage-v2",
storage: createJSONStorage(() => localStorage),
partialize: (state) => ({
tools: state.tools,
}),
}
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from loguru import logger

from reworkd_platform.web.api.agent.agent_service.agent_service import AgentService
from reworkd_platform.web.api.agent.analysis import Analysis
from reworkd_platform.web.api.agent.analysis import Analysis, AnalysisArguments
from reworkd_platform.web.api.agent.helpers import (
call_model_with_handling,
openai_error_handler,
Expand All @@ -19,9 +19,11 @@
start_goal_prompt,
)
from reworkd_platform.web.api.agent.task_output_parser import TaskOutputParser
from reworkd_platform.web.api.agent.tools.open_ai_function import analysis_function
from reworkd_platform.web.api.agent.tools.open_ai_function import get_tool_function
from reworkd_platform.web.api.agent.tools.tools import (
get_default_tool,
get_tool_from_name,
get_tool_name,
get_user_tools,
)
from reworkd_platform.web.api.errors import OpenAIError
Expand Down Expand Up @@ -67,15 +69,19 @@ async def analyze_task_agent(
task=task,
language=self.language,
).to_messages(),
functions=[analysis_function(get_user_tools(tool_names))],
functions=list(map(get_tool_function, get_user_tools(tool_names))),
)

function_call = message.additional_kwargs.get("function_call", {})
completion = function_call.get("arguments", "")

try:
pydantic_parser = PydanticOutputParser(pydantic_object=Analysis)
return parse_with_handling(pydantic_parser, completion)
pydantic_parser = PydanticOutputParser(pydantic_object=AnalysisArguments)
analysis_arguments = parse_with_handling(pydantic_parser, completion)
return Analysis(
action=function_call.get("name", get_tool_name(get_default_tool())),
**analysis_arguments.dict(),
)
except OpenAIError:
return Analysis.get_default_analysis()

Expand Down
11 changes: 9 additions & 2 deletions platform/reworkd_platform/web/api/agent/analysis.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from pydantic import BaseModel, validator


class Analysis(BaseModel):
class AnalysisArguments(BaseModel):
"""
Arguments for the analysis function of a tool. OpenAI functions will resolve these values but leave out the action.
"""

reasoning: str
action: str
arg: str


class Analysis(AnalysisArguments):
action: str

@validator("action")
def action_must_be_valid_tool(cls, v: str) -> str:
# TODO: Remove circular import
Expand Down
2 changes: 1 addition & 1 deletion platform/reworkd_platform/web/api/agent/model_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def create_model(
max_tokens=model_settings.max_tokens,
streaming=streaming,
max_retries=5,
user=user.email,
model_kwargs={"user": user.email},
)


Expand Down
13 changes: 7 additions & 6 deletions platform/reworkd_platform/web/api/agent/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
High level objective: "{goal}"
Current task: "{task}"

Based on this information, use the "analyze" function to return an object for what specific 'tool' to use.
Select the correct tool by being smart and efficient. Provide concrete reasoning for the tool choice detailing
your overall plan and any concerns you may have. Your reasoning should be no more than three sentences.
Ensure "reasoning" and only "reasoning" is in the {language} language.
Based on this information, use the best function to make progress or accomplish the task entirely.
Select the correct function by being smart and efficient. Ensure "reasoning" and only "reasoning" is in the
{language} language.

Note you MUST select a function.
""",
input_variables=["goal", "task", "language"],
)
Expand Down Expand Up @@ -57,8 +58,8 @@
the following overall objective `{goal}` and the following sub-task, `{task}`.

Perform the task by understanding the problem, extracting variables, and being smart
and efficient. Provide a descriptive response, make decisions yourself when
confronted with choices and provide reasoning for ideas / decisions.
and efficient. Write a detailed response that address the task.
When confronted with choices, make a decision yourself with reasoning.
""",
input_variables=["goal", "language", "task"],
)
Expand Down
5 changes: 1 addition & 4 deletions platform/reworkd_platform/web/api/agent/tools/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@


class Code(Tool):
description = (
"Useful for writing, reviewing, and refactoring code. Can also fix bugs, "
"and explain programming concepts."
)
description = "Should only be used to write code, refactor code, fix code bugs, and explain programming concepts."
public_description = "Write and review code."

async def call(
Expand Down
11 changes: 6 additions & 5 deletions platform/reworkd_platform/web/api/agent/tools/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@ async def get_open_ai_image(input_str: str) -> str:


class Image(Tool):
description = (
"Used to sketch, draw, or generate an image. The input string "
"should be a detailed description of the image touching on image "
"style, image focus, color, etc"
)
description = "Used to sketch, draw, or generate an image."
public_description = "Generate AI images."
arg_description = (
"The input prompt to the image generator. "
"This should be a detailed description of the image touching on image "
"style, image focus, color, etc."
)

async def call(
self, goal: str, task: str, input_str: str
Expand Down
26 changes: 9 additions & 17 deletions platform/reworkd_platform/web/api/agent/tools/open_ai_function.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Type, TypedDict
from typing import Type, TypedDict

from reworkd_platform.web.api.agent.tools.tool import Tool
from reworkd_platform.web.api.agent.tools.tools import get_tool_name
Expand All @@ -15,37 +15,29 @@ class FunctionDescription(TypedDict):
"""The parameters of the function."""


def analysis_function(tools: List[Type[Tool]]) -> FunctionDescription:
"""A function that will return the tool specifications from OpenAI"""
tool_names = [get_tool_name(tool) for tool in tools]
tool_name_to_description = [
f"{get_tool_name(tool)}: {tool.description}" for tool in tools
]
def get_tool_function(tool: Type[Tool]) -> FunctionDescription:
"""A function that will return the tool's function specification"""
name = get_tool_name(tool)

return {
"name": "analysis",
"description": (
"Return an object for what specific 'action'/'tool' to call based on their descriptions:\n"
f"{tool_name_to_description}"
),
"name": name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": {
"reasoning": {
"type": "string",
"description": (
f"You must use one of the tools available to you: {tool_names}"
"This reasoning should be how you will accomplish the task with the provided action."
f"Reasoning is how the task will be accomplished with the current function. "
"Detail your overall plan along with any concerns you have."
"Ensure this reasoning value is in the user defined langauge "
),
},
"action": {"type": "string", "enum": tool_names},
"arg": {
"type": "string",
"description": "The appropriate action argument based on the action type",
"description": tool.arg_description,
},
},
"required": ["reasoning", "action", "arg"],
"required": ["reasoning", "arg"],
},
}
2 changes: 1 addition & 1 deletion platform/reworkd_platform/web/api/agent/tools/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ class Search(Tool):
description = (
"Search Google for short up to date searches for simple questions "
"news and people.\n"
"The input arg should be the search query. Ensure this value is NOT empty."
)
public_description = "Search google for information about current events."
arg_description = "The search query. Ensure this value is NOT empty."

@staticmethod
def available() -> bool:
Expand Down
2 changes: 2 additions & 0 deletions platform/reworkd_platform/web/api/agent/tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
class Tool(ABC):
description: str = ""
public_description: str = ""
arg_description: str = "Always leave as an empty string"

model: BaseChatModel
language: str

Expand Down
11 changes: 11 additions & 0 deletions platform/reworkd_platform/web/api/agent/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ class CitedSnippet:
text: str
url: str

def __repr__(self):
"""
The string representation the AI model will see
"""
return f"{{i: {self.index}, text: {self.text}, url: {self.url}}}"


def summarize(
model: BaseChatModel,
Expand All @@ -25,6 +31,11 @@ def summarize(

chain = LLMChain(llm=model, prompt=summarize_prompt)

print(
summarize_prompt.format_prompt(
goal=goal, query=query, snippets=snippets, language=language
)
)
return StreamingResponse.from_chain(
chain,
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@ class Wikipedia(Tool):
description = (
"Search Wikipedia for information about historical people, companies, events, "
"places or research. This should be used over search for broad overviews of "
"specific nouns.\n The argument should be a simple query of just the noun."
"specific nouns."
)
public_description = "Search Wikipedia for historical information."
arg_description = "A simple query string of just the noun in question."

async def call(self, goal: str, task: str, input_str: str) -> StreamingResponse:
wikipedia_client = WikipediaAPIWrapper(
wiki_client=None, # Meta private value but mypy will complain its missing
)

# TODO: Make the below async
wikipedia_search = wikipedia_client.run(input_str)
# return summarize(self.model, self.language, goal, task, [wikipedia_search])
Expand Down