Added @tool decorator to generate function tools
rishit-singh committed Sep 5, 2024
1 parent cf309f5 commit e02329f
Showing 2 changed files with 89 additions and 10 deletions.
98 changes: 89 additions & 9 deletions examples/function_agent.py
@@ -1,17 +1,21 @@

import json
import sys
import inspect

sys.path.append("..")
sys.path.append("../src")

from examples import parser
from examples.yt import YouTubeDataAPI

import os
from typing_extensions import Any
from typing_extensions import Any, Callable
from GroqContext import WebGroqContext, WebGroqMessage
from tinytune.llmcontext import LLMContext, Message
from tinytune.pipeline import Pipeline
from tinytune.prompt import prompt_job
from examples.parser import Parse

from dotenv import load_dotenv

@@ -27,7 +31,33 @@ def Callback(chunk):

yt_api = YouTubeDataAPI(str(os.getenv("YT_KEY")))

def tool():
def wrapper(func: Callable):
spec = inspect.getfullargspec(func)

doc = Parse(str(func.__doc__))

return (func, {
"name": func.__name__,
"description": doc["title"],
"parameters": {
"type": "object",
"properties": doc["params"]["Args"] if isinstance(["params"], dict) else [ { key: str(spec.annotations[key]) } for key in spec.annotations ]
},
"repr": func.__repr__()
})

return wrapper

@tool()
def GetVideos(query: str, max: int) -> str:
"""
Gets the videos based on a query and max arguments
Args:
query - Search query
max - Max search results
"""
videos = yt_api.search_videos(query, max_results=max, order='relevance', video_duration='short')

# Extract the video IDs
@@ -59,13 +89,48 @@ def GetVideos(query: str, max: int) -> str:

return result


class ToolManager:
ToolMap: dict = {}
# @staticmethod
# def Register(func: Callable):
# tool: dict = GenerateTool(func)
# ToolManager.ToolMap[tool["name"]] = tool

@staticmethod
def Call(tool: dict):
return

# Agent 1
def Classifier(context: LLMContext, videos: str) -> Pipeline:
pipeline = Pipeline(context)

@prompt_job(id="Setup", context=context)
def Setup(id: str, context: LLMContext, prevResult: Any, *args):
return context.Prompt(WebGroqMessage("user", "You are a YT video analyzer. You take a list of youtube videos, stats, and descriptions, and group the similar ones together.")).Run(stream=True)

func, tool = GetVideos

return (context.Prompt(WebGroqMessage("user", f"""
You have access to the following functions:
Use the function '{tool["name"]}' to '{tool["description"]}':
{json.dumps(tool)}
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
<function=example_function_name>{{\"example_name\": \"example_value\"}}</function>
Reminder:
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
""")).Run(stream=True)
.Prompt(WebGroqMessage("user", "Find me a fun youtube video. Call the appropriate function"))
.Run(stream=True))

print(LLM.Messages)

@prompt_job(id="Classify", context=context)
def Classify(id: str, context: LLMContext, prevResult: Any, *args):
@@ -75,7 +140,7 @@ def Classify(id: str, context: LLMContext, prevResult: Any, *args):
def Extract(id: str, context: LLMContext, prevResult: Any, *args):
return context.Prompt(WebGroqMessage("user", f"{prevResult} extract just the json from this, respond with plain JSON text, no backticks, nothing extra.")).Run(stream=True).Messages[-1].Content

return pipeline.AddJob(Setup).AddJob(Classify).AddJob(Extract)
return pipeline.AddJob(Setup)
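
The Setup prompt above asks the model to reply with a single <function=name>{json args}</function> line, but the commit leaves ToolManager.Call empty. A minimal sketch of how such a reply could be parsed and dispatched against the (func, metadata) tuples that @tool() produces; ParseFunctionCall, Dispatch, and the registry below are illustrative assumptions, not part of this commit:

import json
import re

def ParseFunctionCall(reply: str):
    """Extract (name, kwargs) from a '<function=name>{...}</function>' reply,
    or return None when the reply contains no function call."""
    match = re.search(r"<function=(\w+)>(\{.*\})</function>", reply)
    if match is None:
        return None
    return match.group(1), json.loads(match.group(2))

def Dispatch(reply: str, registry: dict):
    # registry maps tool names to the (func, metadata) tuples returned
    # by @tool(), e.g. {"GetVideos": GetVideos}
    call = ParseFunctionCall(reply)
    if call is None:
        return None
    name, kwargs = call
    func, _meta = registry[name]
    return func(**kwargs)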

# Agent 2
def Analyzer(context: LLMContext) -> Pipeline:
@@ -92,11 +157,26 @@ def Plot(id: str, context: LLMContext, prevResult: Any, *args):
return pipeline.AddJob(Analyze).AddJob(Plot)


with open("plot.py", 'w') as fp:
pipeline = Pipeline(LLM)
Classifier(LLM, "")()
print('\n', LLM.Messages[-1].Content)
# with open("plot.py", 'w') as fp:
# pipeline = Pipeline(LLM)

# results = (pipeline.AddJob(Classifier(LLM, GetVideos("AI alignment", 155555)))
# .AddJob(Analyzer(LLM))
# .Run(stream=True))

# fp.write(str(results))

def Foo(x: int, y: str, z: float):
    """
    A random function.
    Args:
    x - some int
    y - some string
    z - some float
    """
    pass

results = (pipeline.AddJob(Classifier(LLM, GetVideos("AI alignment", 155555)))
.AddJob(Analyzer(LLM))
.Run(stream=True))

fp.write(str(results))
# print(GetVideosTool)
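
For reference, after this change a decorated function such as GetVideos is no longer a plain callable: @tool() returns a (func, metadata) tuple, which is why Classifier unpacks it with func, tool = GetVideos. A rough sketch of the expected shape on a toy function, assuming Parse maps the first docstring line to "title" and the Args entries into "params" (the exact values depend on parser.py):

@tool()
def Add(a: int, b: int) -> int:
    """
    Adds two numbers
    Args:
    a - first number
    b - second number
    """
    return a + b

func, meta = Add            # Add is now a (callable, metadata) tuple
print(func(2, 3))           # the original callable still works -> 5
print(meta["name"])         # "Add"
print(meta["description"])  # assumed: "Adds two numbers"
print(meta["parameters"])   # assumed: {"type": "object", "properties": {...}}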
1 change: 0 additions & 1 deletion examples/parser.py
@@ -18,7 +18,6 @@ def Parse(docString: str) -> dict:
if (line == '' or line == "\n"):
continue


if (colon >= 1):
key = "params"

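
The @tool decorator in function_agent.py relies on Parse turning a docstring into a dict with a "title" entry and a "params" section holding the Args. A hedged sketch of the shape it is assumed to return for the GetVideos docstring (the exact keys and nesting come from parser.py and may differ):

from examples.parser import Parse

doc = Parse(
    """
    Gets the videos based on a query and max arguments
    Args:
    query - Search query
    max - Max search results
    """
)

# Assumed shape, based on how the decorator reads doc["title"] and
# doc["params"]["Args"]:
# {
#   "title": "Gets the videos based on a query and max arguments",
#   "params": {"Args": {"query": "Search query", "max": "Max search results"}}
# }
print(doc["title"])
print(doc["params"])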
