Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decorator for easier tool building #33439

Merged
merged 11 commits into from
Sep 18, 2024
76 changes: 25 additions & 51 deletions docs/source/en/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -325,62 +325,37 @@ model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
print(model.id)
```

This code can be converted into a class that inherits from the [`Tool`] superclass.
This code can quickly be converted into a tool, just by wrapping it in a function and adding the `tool` decorator:


The custom tool needs:
- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name is `model_download_counter`.
- An attribute `description` is used to populate the agent's system prompt.
- An `inputs` attribute, which is a dictionary with keys `"type"` and `"description"`. It contains information that helps the Python interpreter make educated choices about the input.
- An `output_type` attribute, which specifies the output type.
- A `forward` method which contains the inference code to be executed.


```python
from transformers import Tool
from huggingface_hub import list_models

class HFModelDownloadsTool(Tool):
name = "model_download_counter"
description = (
"This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. "
"It returns the name of the checkpoint."
)

inputs = {
"task": {
"type": "text",
"description": "the task category (such as text-classification, depth-estimation, etc)",
}
}
output_type = "text"

def forward(self, task: str):
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
return model.id
```

Now that the custom `HfModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use.


```python
from model_downloads import HFModelDownloadsTool

tool = HFModelDownloadsTool()
```

You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access.

```python
tool.push_to_hub("{your_username}/hf-model-downloads")
```py
from transformers import tool

@tool
def model_download_counter(task: str) -> str:
"""
This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
It returns the name of the checkpoint.

Args:
task: The task for which
"""
model = next(iter(list_models(filter="text-classification", sort="downloads", direction=-1)))
return model.id
```

Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
The function needs:
- A clear name. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's put `model_download_counter`.
- Type hints on both inputs and output
- A description, that includes an 'Args:' part where each argument is described (without a type indication this time, it will be pulled from the type hint).
All these will be automatically baked into the agent's system prompt upon initialization: so strive to make them as clear as possible!

```python
from transformers import load_tool, CodeAgent
> [!TIP]
> This definition format is the same as tool schemas used in `apply_chat_template`, the only difference is the added `tool` decorator: read more on our tool use API [here](https://huggingface.co/blog/unified-tool-use#passing-tools-to-a-chat-template).

model_download_tool = load_tool("m-ric/hf-model-downloads")
Then you can directly initialize your agent:
```py
from transformers import CodeAgent
agent = CodeAgent(tools=[model_download_tool], llm_engine=llm_engine)
agent.run(
"Can you give me the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?"
Expand All @@ -400,7 +375,6 @@ print(f"The most downloaded model for the 'text-to-video' task is {most_download
And the output:
`"The most downloaded model for the 'text-to-video' task is ByteDance/AnimateDiff-Lightning."`


### Manage your agent's toolbox

If you have already initialized an agent, it is inconvenient to reinitialize it from scratch with a tool you want to use. With Transformers, you can manage an agent's toolbox by adding or replacing a tool.
Expand Down
63 changes: 62 additions & 1 deletion docs/source/en/agents_advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,68 @@ manager_agent.run("Who is the CEO of Hugging Face?")
> For an in-depth example of an efficient multi-agent implementation, see [how we pushed our multi-agent system to the top of the GAIA leaderboard](https://huggingface.co/blog/beating-gaia).


## Use tools from gradio or LangChain
## Advanced tool usage

### Directly define a tool by subclassing Tool, and share it to the Hub

Let's take again the tool example from main documentation, for which we had implemented a `tool` decorator.

If you need to add variation, like custom attributes for your too, you can build your tool following the fine-grained method: building a class that inherits from the [`Tool`] superclass.

The custom tool needs:
- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name is `model_download_counter`.
- An attribute `description` is used to populate the agent's system prompt.
- An `inputs` attribute, which is a dictionary with keys `"type"` and `"description"`. It contains information that helps the Python interpreter make educated choices about the input.
- An `output_type` attribute, which specifies the output type.
- A `forward` method which contains the inference code to be executed.

The types for both `inputs` and `output_type` should be amongst [Pydantic formats](https://docs.pydantic.dev/latest/concepts/json_schema/#generating-json-schema).

```python
from transformers import Tool
from huggingface_hub import list_models

class HFModelDownloadsTool(Tool):
name = "model_download_counter"
description = """
This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
It returns the name of the checkpoint."""

inputs = {
"task": {
"type": "string",
"description": "the task category (such as text-classification, depth-estimation, etc)",
}
}
output_type = "string"

def forward(self, task: str):
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
return model.id
```

Now that the custom `HfModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use.


```python
from model_downloads import HFModelDownloadsTool

tool = HFModelDownloadsTool()
```

You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access.

```python
tool.push_to_hub("{your_username}/hf-model-downloads")
```

Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.

```python
from transformers import load_tool, CodeAgent

model_download_tool = load_tool("m-ric/hf-model-downloads")
```

### Use gradio-tools

Expand Down
4 changes: 4 additions & 0 deletions docs/source/en/main_classes/agent.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ We provide two types of agents, based on the main [`Agent`] class:

[[autodoc]] load_tool

### tool

[[autodoc]] tool

### Tool

[[autodoc]] Tool
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
"launch_gradio_demo",
"load_tool",
"stream_to_gradio",
"tool",
],
"audio_utils": [],
"benchmark": [],
Expand Down Expand Up @@ -4819,6 +4820,7 @@
launch_gradio_demo,
load_tool,
stream_to_gradio,
tool,
)
from .configuration_utils import PretrainedConfig

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
"llm_engine": ["HfApiEngine", "TransformersEngine"],
"monitoring": ["stream_to_gradio"],
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"],
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool", "tool"],
}

try:
Expand All @@ -48,7 +48,7 @@
from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
from .llm_engine import HfApiEngine, TransformersEngine
from .monitoring import stream_to_gradio
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool, tool

try:
if not is_torch_available():
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/agents/agent_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def to_string(self):
return self._path


AGENT_TYPE_MAPPING = {"text": AgentText, "image": AgentImage, "audio": AgentAudio}
AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage}

if is_torch_available():
Expand Down
18 changes: 7 additions & 11 deletions src/transformers/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .. import is_torch_available
from ..utils import logging as transformers_logging
from ..utils.import_utils import is_pygments_available
from .agent_types import AgentAudio, AgentImage, AgentText
from .agent_types import AgentAudio, AgentImage
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
from .llm_engine import HfApiEngine, MessageRole
from .prompts import (
Expand Down Expand Up @@ -626,10 +626,9 @@ def run(self, task: str, return_generated_code: bool = False, **kwargs):
Example:

```py
from transformers.agents import CodeAgent, PythonInterpreterTool
from transformers.agents import CodeAgent

python_interpreter = PythonInterpreterTool()
agent = CodeAgent(tools=[python_interpreter])
agent = CodeAgent(tools=[])
agent.run("What is the result of 2 power 3.7384?")
```
"""
Expand Down Expand Up @@ -1019,20 +1018,17 @@ def step(self):
arguments = {}
observation = self.execute_tool_call(tool_name, arguments)
observation_type = type(observation)
if observation_type == AgentText:
updated_information = str(observation).strip()
else:
# TODO: observation naming could allow for different names of same type
if observation_type in [AgentImage, AgentAudio]:
if observation_type == AgentImage:
observation_name = "image.png"
elif observation_type == AgentAudio:
observation_name = "audio.mp3"
else:
observation_name = "object.object"
# TODO: observation naming could allow for different names of same type

self.state[observation_name] = observation
updated_information = f"Stored '{observation_name}' in memory."

else:
updated_information = str(observation).strip()
self.logger.info(updated_information)
current_step_logs["observation"] = updated_information
return current_step_logs
Expand Down
9 changes: 4 additions & 5 deletions src/transformers/agents/default_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,7 @@ class PythonInterpreterTool(Tool):
name = "python_interpreter"
description = "This is a tool that evaluates python code. It can be used to perform calculations."

output_type = "text"
available_tools = BASE_PYTHON_TOOLS.copy()
output_type = "string"

def __init__(self, *args, authorized_imports=None, **kwargs):
if authorized_imports is None:
Expand All @@ -162,7 +161,7 @@ def __init__(self, *args, authorized_imports=None, **kwargs):
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports))
self.inputs = {
"code": {
"type": "text",
"type": "string",
"description": (
"The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
f"else you will get an error. This code can only import the following python libraries: {authorized_imports}."
Expand All @@ -173,15 +172,15 @@ def __init__(self, *args, authorized_imports=None, **kwargs):

def forward(self, code):
output = str(
evaluate_python_code(code, static_tools=self.available_tools, authorized_imports=self.authorized_imports)
evaluate_python_code(code, static_tools=BASE_PYTHON_TOOLS, authorized_imports=self.authorized_imports)
)
return output


class FinalAnswerTool(Tool):
name = "final_answer"
description = "Provides a final answer to the given problem."
inputs = {"answer": {"type": "text", "description": "The final answer to the problem"}}
inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
output_type = "any"

def forward(self, answer):
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/agents/document_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

class DocumentQuestionAnsweringTool(PipelineTool):
default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
description = "This is a tool that answers a question about an document (pdf). It returns a text that contains the answer to the question."
description = "This is a tool that answers a question about an document (pdf). It returns a string that contains the answer to the question."
name = "document_qa"
pre_processor_class = AutoProcessor
model_class = VisionEncoderDecoderModel
Expand All @@ -41,9 +41,9 @@ class DocumentQuestionAnsweringTool(PipelineTool):
"type": "image",
"description": "The image containing the information. Can be a PIL Image or a string path to the image.",
},
"question": {"type": "text", "description": "The question in English"},
"question": {"type": "string", "description": "The question in English"},
}
output_type = "text"
output_type = "string"

def __init__(self, *args, **kwargs):
if not is_vision_available():
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/agents/image_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ class ImageQuestionAnsweringTool(PipelineTool):
"type": "image",
"description": "The image containing the information. Can be a PIL Image or a string path to the image.",
},
"question": {"type": "text", "description": "The question in English"},
"question": {"type": "string", "description": "The question in English"},
}
output_type = "text"
output_type = "string"

def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/agents/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
Action:
{
"action": "image_generator",
"action_input": {"text": ""A portrait of John Doe, a 55-year-old man living in Canada.""}
"action_input": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}
}<end_action>
Observation: "image.png"

Expand Down
6 changes: 3 additions & 3 deletions src/transformers/agents/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class DuckDuckGoSearchTool(Tool):
name = "web_search"
description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
Each result has keys 'title', 'href' and 'body'."""
inputs = {"query": {"type": "text", "description": "The search query to perform."}}
inputs = {"query": {"type": "string", "description": "The search query to perform."}}
output_type = "any"

def forward(self, query: str) -> str:
Expand All @@ -45,11 +45,11 @@ class VisitWebpageTool(Tool):
description = "Visits a wbepage at the given url and returns its content as a markdown string."
inputs = {
"url": {
"type": "text",
"type": "string",
"description": "The url of the webpage to visit.",
}
}
output_type = "text"
output_type = "string"

def forward(self, url: str) -> str:
try:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/agents/speech_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class SpeechToTextTool(PipelineTool):
model_class = WhisperForConditionalGeneration

inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}}
output_type = "text"
output_type = "string"

def encode(self, audio):
return self.pre_processor(audio, return_tensors="pt")
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/agents/text_to_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class TextToSpeechTool(PipelineTool):
model_class = SpeechT5ForTextToSpeech
post_processor_class = SpeechT5HifiGan

inputs = {"text": {"type": "text", "description": "The text to read out loud (in English)"}}
inputs = {"text": {"type": "string", "description": "The text to read out loud (in English)"}}
output_type = "audio"

def setup(self):
Expand Down
Loading
Loading