Decorator for easier tool building (huggingface#33439)
* Decorator for tool building
aymeric-roucher authored and amyeroberts committed Oct 2, 2024
1 parent 0ac6dee commit 52e32fc
Showing 21 changed files with 294 additions and 113 deletions.
76 changes: 25 additions & 51 deletions docs/source/en/agents.md
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
print(model.id)
```

This code can quickly be converted into a tool, just by wrapping it in a function and adding the `tool` decorator:

```py
from huggingface_hub import list_models

from transformers import tool


@tool
def model_download_counter(task: str) -> str:
    """
    This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
    It returns the name of the checkpoint.

    Args:
        task: The task for which to get the download count.
    """
    model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
    return model.id
```

The function needs:
- A clear name. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's call it `model_download_counter`.
- Type hints on both inputs and output.
- A description that includes an 'Args:' part where each argument is described (without a type indication this time; it will be pulled from the type hint).

All these will be automatically baked into the agent's system prompt upon initialization, so strive to make them as clear as possible!

> [!TIP]
> This definition format is the same as tool schemas used in `apply_chat_template`; the only difference is the added `tool` decorator. Read more on our tool use API [here](https://huggingface.co/blog/unified-tool-use#passing-tools-to-a-chat-template).

Then you can directly initialize your agent:

```py
from transformers import CodeAgent

agent = CodeAgent(tools=[model_download_counter], llm_engine=llm_engine)
agent.run(
    "Can you give me the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?"
)
```

And the output:
`"The most downloaded model for the 'text-to-video' task is ByteDance/AnimateDiff-Lightning."`
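Under the hood, a decorator of this kind can derive the whole tool schema from the function itself. The following is a minimal, hypothetical sketch (not the actual `transformers` implementation) of how a name, description, and input schema could be pulled from the signature and docstring:

```python
import inspect

# Map Python type hints to JSON-schema-style type names used by tool schemas.
TYPE_MAP = {str: "string", int: "integer", float: "number", bool: "boolean"}


def simple_tool(func):
    """Hypothetical stand-in for the `tool` decorator: derive a schema from hints."""
    hints = func.__annotations__
    doc = inspect.cleandoc(func.__doc__ or "")
    description, _, args_part = doc.partition("Args:")
    # One "name: description" line per argument inside the Args: section.
    arg_docs = {}
    for line in args_part.strip().splitlines():
        name, _, text = line.strip().partition(":")
        if text:
            arg_docs[name.strip()] = text.strip()
    func.name = func.__name__
    func.description = description.strip()
    func.inputs = {
        param: {"type": TYPE_MAP.get(hint, "any"), "description": arg_docs.get(param, "")}
        for param, hint in hints.items()
        if param != "return"
    }
    func.output_type = TYPE_MAP.get(hints.get("return"), "any")
    return func


@simple_tool
def model_download_counter(task: str) -> str:
    """
    Returns the most downloaded model of a given task on the Hugging Face Hub.

    Args:
        task: The task for which to get the download count.
    """
    return f"dummy-checkpoint-for-{task}"  # placeholder instead of a real Hub call


print(model_download_counter.name)                    # model_download_counter
print(model_download_counter.inputs["task"]["type"])  # string
print(model_download_counter.output_type)             # string
```

This also makes clear why the name, type hints, and 'Args:' section matter: they are the only raw material the decorator has to build the schema from.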


### Manage your agent's toolbox

If you have already initialized an agent, it is inconvenient to reinitialize it from scratch with a tool you want to use. With Transformers, you can manage an agent's toolbox by adding or replacing a tool.
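The add-or-replace semantics can be pictured with a toy, self-contained toolbox sketch (hypothetical; not the actual `Toolbox` API, whose method names may differ):

```python
# A minimal sketch of add/replace toolbox semantics, assuming tools are
# looked up by a unique `name` attribute.
class SimpleToolbox:
    def __init__(self, tools=None):
        self._tools = {t.name: t for t in (tools or [])}

    def add_tool(self, tool):
        # Adding refuses to silently overwrite an existing tool.
        if tool.name in self._tools:
            raise ValueError(f"Tool '{tool.name}' already exists; replace it instead.")
        self._tools[tool.name] = tool

    def update_tool(self, tool):
        # Replacing requires a tool of that name to already exist.
        if tool.name not in self._tools:
            raise ValueError(f"No tool named '{tool.name}' to replace.")
        self._tools[tool.name] = tool

    @property
    def tools(self):
        return dict(self._tools)


class DummyTool:
    def __init__(self, name, version):
        self.name = name
        self.version = version


box = SimpleToolbox([DummyTool("web_search", 1)])
box.add_tool(DummyTool("final_answer", 1))
box.update_tool(DummyTool("web_search", 2))
print(sorted(box.tools))                 # ['final_answer', 'web_search']
print(box.tools["web_search"].version)   # 2
```

The point of the split between adding and replacing is safety: an agent's behavior depends on its toolbox, so overwriting a tool should be an explicit choice rather than an accident.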
63 changes: 62 additions & 1 deletion docs/source/en/agents_advanced.md
> For an in-depth example of an efficient multi-agent implementation, see [how we pushed our multi-agent system to the top of the GAIA leaderboard](https://huggingface.co/blog/beating-gaia).

## Advanced tool usage

### Directly define a tool by subclassing Tool, and share it to the Hub

Let's revisit the tool example from the main documentation, which we implemented with the `tool` decorator.

If you need more flexibility, like custom attributes for your tool, you can build your tool the fine-grained way: by writing a class that inherits from the [`Tool`] superclass.

The custom tool needs:
- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name it `model_download_counter`.
- An attribute `description`, which is used to populate the agent's system prompt.
- An `inputs` attribute, which is a dictionary with keys `"type"` and `"description"`. It contains information that helps the Python interpreter make educated choices about the input.
- An `output_type` attribute, which specifies the output type.
- A `forward` method which contains the inference code to be executed.

The types for both `inputs` and `output_type` should be amongst [Pydantic formats](https://docs.pydantic.dev/latest/concepts/json_schema/#generating-json-schema).
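The kind of check this constraint enables can be sketched as follows. This is a hypothetical validator, not the library's actual one, assuming an allowed-type list of JSON-schema-style names:

```python
# Hypothetical validator sketch: the real library's checks may differ.
ALLOWED_TYPES = {"string", "boolean", "integer", "number", "image", "audio", "any"}


def validate_tool_schema(name, inputs, output_type):
    """Check that a tool's declared input/output types use allowed type names."""
    if output_type not in ALLOWED_TYPES:
        raise ValueError(f"Tool '{name}': invalid output_type '{output_type}'")
    for arg, spec in inputs.items():
        if "type" not in spec or "description" not in spec:
            raise ValueError(f"Tool '{name}': input '{arg}' needs 'type' and 'description'")
        if spec["type"] not in ALLOWED_TYPES:
            raise ValueError(f"Tool '{name}': input '{arg}' has invalid type '{spec['type']}'")


# A schema like the one below passes; "text" (the old type name) would not.
validate_tool_schema(
    "model_download_counter",
    {"task": {"type": "string", "description": "the task category"}},
    "string",
)
```

This is why the commit renames `"text"` to `"string"` throughout: the type names must match the schema vocabulary exactly.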

```python
from transformers import Tool
from huggingface_hub import list_models


class HFModelDownloadsTool(Tool):
    name = "model_download_counter"
    description = """
    This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
    It returns the name of the checkpoint."""

    inputs = {
        "task": {
            "type": "string",
            "description": "the task category (such as text-classification, depth-estimation, etc)",
        }
    }
    output_type = "string"

    def forward(self, task: str):
        model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
        return model.id
```

Now that the custom `HFModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use.


```python
from model_downloads import HFModelDownloadsTool

tool = HFModelDownloadsTool()
```

You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with write access.

```python
tool.push_to_hub("{your_username}/hf-model-downloads")
```

Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.

```python
from transformers import load_tool, CodeAgent

model_download_tool = load_tool("m-ric/hf-model-downloads")
```
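The share-then-load round trip boils down to serializing the tool's schema (plus its code) and reconstructing it on the other side. Here is a toy, in-memory illustration of the schema part; this is a hypothetical sketch, not what `push_to_hub`/`load_tool` actually do with Hub repositories:

```python
import json

# Toy in-memory stand-in for the Hub: real sharing uploads the tool's
# code and schema to a repository like "{your_username}/hf-model-downloads".
REGISTRY = {}


def push_tool(repo_id, name, description, inputs, output_type):
    """Serialize a tool's schema and store it under a repo id."""
    REGISTRY[repo_id] = json.dumps(
        {"name": name, "description": description, "inputs": inputs, "output_type": output_type}
    )


def load_tool_schema(repo_id):
    """Fetch and deserialize a previously shared tool schema."""
    return json.loads(REGISTRY[repo_id])


push_tool(
    "m-ric/hf-model-downloads",
    name="model_download_counter",
    description="Returns the most downloaded model of a given task.",
    inputs={"task": {"type": "string", "description": "the task category"}},
    output_type="string",
)
schema = load_tool_schema("m-ric/hf-model-downloads")
print(schema["name"])  # model_download_counter
```

Because everything needed to use the tool travels in this serialized form, any agent that loads it gets the same name, description, and typed inputs that were baked in at definition time.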

### Use gradio-tools

4 changes: 4 additions & 0 deletions docs/source/en/main_classes/agent.md
We provide two types of agents, based on the main [`Agent`] class:

[[autodoc]] load_tool

### tool

[[autodoc]] tool

### Tool

[[autodoc]] Tool
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -70,6 +70,7 @@
         "launch_gradio_demo",
         "load_tool",
         "stream_to_gradio",
+        "tool",
     ],
     "audio_utils": [],
     "benchmark": [],
@@ -4819,6 +4820,7 @@
         launch_gradio_demo,
         load_tool,
         stream_to_gradio,
+        tool,
     )
     from .configuration_utils import PretrainedConfig
4 changes: 2 additions & 2 deletions src/transformers/agents/__init__.py
@@ -27,7 +27,7 @@
     "agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
     "llm_engine": ["HfApiEngine", "TransformersEngine"],
     "monitoring": ["stream_to_gradio"],
-    "tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"],
+    "tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool", "tool"],
 }

 try:
@@ -48,7 +48,7 @@
     from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
     from .llm_engine import HfApiEngine, TransformersEngine
     from .monitoring import stream_to_gradio
-    from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool
+    from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool, tool

 try:
     if not is_torch_available():
2 changes: 1 addition & 1 deletion src/transformers/agents/agent_types.py
@@ -234,7 +234,7 @@ def to_string(self):
         return self._path


-AGENT_TYPE_MAPPING = {"text": AgentText, "image": AgentImage, "audio": AgentAudio}
+AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
 INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage}

 if is_torch_available():
18 changes: 7 additions & 11 deletions src/transformers/agents/agents.py
@@ -22,7 +22,7 @@
 from .. import is_torch_available
 from ..utils import logging as transformers_logging
 from ..utils.import_utils import is_pygments_available
-from .agent_types import AgentAudio, AgentImage, AgentText
+from .agent_types import AgentAudio, AgentImage
 from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
 from .llm_engine import HfApiEngine, MessageRole
 from .prompts import (
@@ -626,10 +626,9 @@ def run(self, task: str, return_generated_code: bool = False, **kwargs):
         Example:
         ```py
-        from transformers.agents import CodeAgent, PythonInterpreterTool
+        from transformers.agents import CodeAgent

-        python_interpreter = PythonInterpreterTool()
-        agent = CodeAgent(tools=[python_interpreter])
+        agent = CodeAgent(tools=[])
         agent.run("What is the result of 2 power 3.7384?")
         ```
         """
@@ -1019,20 +1018,17 @@ def step(self):
                 arguments = {}
             observation = self.execute_tool_call(tool_name, arguments)
             observation_type = type(observation)
-            if observation_type == AgentText:
-                updated_information = str(observation).strip()
-            else:
-                # TODO: observation naming could allow for different names of same type
+            if observation_type in [AgentImage, AgentAudio]:
                 if observation_type == AgentImage:
                     observation_name = "image.png"
                 elif observation_type == AgentAudio:
                     observation_name = "audio.mp3"
-                else:
-                    observation_name = "object.object"
+                # TODO: observation naming could allow for different names of same type

                 self.state[observation_name] = observation
                 updated_information = f"Stored '{observation_name}' in memory."
+            else:
+                updated_information = str(observation).strip()
             self.logger.info(updated_information)
             current_step_logs["observation"] = updated_information
             return current_step_logs
9 changes: 4 additions & 5 deletions src/transformers/agents/default_tools.py
@@ -152,8 +152,7 @@ class PythonInterpreterTool(Tool):
     name = "python_interpreter"
     description = "This is a tool that evaluates python code. It can be used to perform calculations."

-    output_type = "text"
-    available_tools = BASE_PYTHON_TOOLS.copy()
+    output_type = "string"

     def __init__(self, *args, authorized_imports=None, **kwargs):
         if authorized_imports is None:
@@ -162,7 +161,7 @@ def __init__(self, *args, authorized_imports=None, **kwargs):
         self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports))
         self.inputs = {
             "code": {
-                "type": "text",
+                "type": "string",
                 "description": (
                     "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
                     f"else you will get an error. This code can only import the following python libraries: {authorized_imports}."
@@ -173,15 +172,15 @@

     def forward(self, code):
         output = str(
-            evaluate_python_code(code, static_tools=self.available_tools, authorized_imports=self.authorized_imports)
+            evaluate_python_code(code, static_tools=BASE_PYTHON_TOOLS, authorized_imports=self.authorized_imports)
         )
         return output


 class FinalAnswerTool(Tool):
     name = "final_answer"
     description = "Provides a final answer to the given problem."
-    inputs = {"answer": {"type": "text", "description": "The final answer to the problem"}}
+    inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
     output_type = "any"

     def forward(self, answer):
6 changes: 3 additions & 3 deletions src/transformers/agents/document_question_answering.py
@@ -31,7 +31,7 @@

 class DocumentQuestionAnsweringTool(PipelineTool):
     default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
-    description = "This is a tool that answers a question about an document (pdf). It returns a text that contains the answer to the question."
+    description = "This is a tool that answers a question about a document (pdf). It returns a string that contains the answer to the question."
     name = "document_qa"
     pre_processor_class = AutoProcessor
     model_class = VisionEncoderDecoderModel
@@ -41,9 +41,9 @@ class DocumentQuestionAnsweringTool(PipelineTool):
         "image": {
             "type": "image",
             "description": "The image containing the information. Can be a PIL Image or a string path to the image.",
         },
-        "question": {"type": "text", "description": "The question in English"},
+        "question": {"type": "string", "description": "The question in English"},
     }
-    output_type = "text"
+    output_type = "string"

     def __init__(self, *args, **kwargs):
         if not is_vision_available():
4 changes: 2 additions & 2 deletions src/transformers/agents/image_question_answering.py
@@ -38,9 +38,9 @@ class ImageQuestionAnsweringTool(PipelineTool):
         "image": {
             "type": "image",
             "description": "The image containing the information. Can be a PIL Image or a string path to the image.",
         },
-        "question": {"type": "text", "description": "The question in English"},
+        "question": {"type": "string", "description": "The question in English"},
     }
-    output_type = "text"
+    output_type = "string"

     def __init__(self, *args, **kwargs):
         requires_backends(self, ["vision"])
2 changes: 1 addition & 1 deletion src/transformers/agents/prompts.py
@@ -199,7 +199,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
 Action:
 {
   "action": "image_generator",
-  "action_input": {"text": ""A portrait of John Doe, a 55-year-old man living in Canada.""}
+  "action_input": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}
 }<end_action>
 Observation: "image.png"
6 changes: 3 additions & 3 deletions src/transformers/agents/search.py
@@ -26,7 +26,7 @@ class DuckDuckGoSearchTool(Tool):
     name = "web_search"
     description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
     Each result has keys 'title', 'href' and 'body'."""
-    inputs = {"query": {"type": "text", "description": "The search query to perform."}}
+    inputs = {"query": {"type": "string", "description": "The search query to perform."}}
     output_type = "any"

     def forward(self, query: str) -> str:
@@ -45,11 +45,11 @@ class VisitWebpageTool(Tool):
     description = "Visits a webpage at the given url and returns its content as a markdown string."
     inputs = {
         "url": {
-            "type": "text",
+            "type": "string",
             "description": "The url of the webpage to visit.",
         }
     }
-    output_type = "text"
+    output_type = "string"

     def forward(self, url: str) -> str:
         try:
2 changes: 1 addition & 1 deletion src/transformers/agents/speech_to_text.py
@@ -27,7 +27,7 @@ class SpeechToTextTool(PipelineTool):
     model_class = WhisperForConditionalGeneration

     inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}}
-    output_type = "text"
+    output_type = "string"

     def encode(self, audio):
         return self.pre_processor(audio, return_tensors="pt")
2 changes: 1 addition & 1 deletion src/transformers/agents/text_to_speech.py
@@ -36,7 +36,7 @@ class TextToSpeechTool(PipelineTool):
     model_class = SpeechT5ForTextToSpeech
     post_processor_class = SpeechT5HifiGan

-    inputs = {"text": {"type": "text", "description": "The text to read out loud (in English)"}}
+    inputs = {"text": {"type": "string", "description": "The text to read out loud (in English)"}}
     output_type = "audio"

     def setup(self):