Agents use grammar (#31735)
* Allow optional use of grammars to constrain generation
aymeric-roucher authored Aug 7, 2024
1 parent c54a6f9 commit e0d8253
Showing 5 changed files with 61 additions and 24 deletions.
8 changes: 5 additions & 3 deletions docs/source/en/agents.md
@@ -119,10 +119,12 @@ def llm_engine(messages, stop_sequences=["Task"]) -> str:
```

You could use any `llm_engine` method as long as:
-1. it follows the [messages format](./chat_templating.md) for its input (`List[Dict[str, str]]`) and returns a `str`
-2. it stops generating outputs at the sequences passed in the argument `stop`
+1. it follows the [messages format](./chat_templating.md) (`List[Dict[str, str]]`) for its input `messages`, and it returns a `str`.
+2. it stops generating outputs at the sequences passed in the argument `stop_sequences`

-You also need a `tools` argument which accepts a list of `Tools`. You can provide an empty list for `tools`, but use the default toolbox with the optional argument `add_base_tools=True`.
+Additionally, `llm_engine` can take a `grammar` argument. If you specify a `grammar` upon agent initialization, it will be passed along to each call to `llm_engine` to allow [constrained generation](https://huggingface.co/docs/text-generation-inference/conceptual/guidance), forcing properly formatted agent outputs.
+
+You will also need a `tools` argument which accepts a list of `Tools` - it can be an empty list. You can also add the default toolbox on top of your `tools` list by defining the optional argument `add_base_tools=True`.
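
For illustration, a minimal engine satisfying both requirements plus the optional `grammar` might look like the following sketch. It mirrors the `HfEngine` implementation in this commit, forwarding `grammar` as the `response_format` constraint understood by Text Generation Inference (the model name is just the default `HfEngine` uses):

```py
from typing import Dict, List, Optional

from huggingface_hub import InferenceClient

client = InferenceClient(model="meta-llama/Meta-Llama-3.1-8B-Instruct")


def custom_llm_engine(
    messages: List[Dict[str, str]], stop_sequences: List[str] = [], grammar: Optional[dict] = None
) -> str:
    # Only pass a constraint when a grammar was actually configured.
    extra = {"response_format": grammar} if grammar is not None else {}
    response = client.chat_completion(messages, stop=stop_sequences, max_tokens=1500, **extra)
    return response.choices[0].message.content
```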

Now you can create an agent, like [`CodeAgent`], and run it. For convenience, we also provide the [`HfEngine`] class that uses `huggingface_hub.InferenceClient` under the hood.
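
For example, here is a sketch of a `CodeAgent` constrained with the code-style regex grammar added in this commit (import paths follow this commit's layout; the task string is illustrative):

```py
from transformers.agents import CodeAgent
from transformers.agents.llm_engine import DEFAULT_CODEAGENT_REGEX_GRAMMAR, HfEngine

# `grammar` is optional; when set, it is forwarded to every llm_engine call so
# that completions are constrained to the expected Thought/Code format.
agent = CodeAgent(
    tools=[],
    llm_engine=HfEngine(),
    grammar=DEFAULT_CODEAGENT_REGEX_GRAMMAR,
    add_base_tools=True,
)
agent.run("What is the result of the following operation: 5 + 3 + 1294.678?")
```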

26 changes: 22 additions & 4 deletions src/transformers/agents/agents.py
@@ -328,14 +328,15 @@ def __init__(
self,
tools: Union[List[Tool], Toolbox],
llm_engine: Callable = HfEngine(),
-        system_prompt=DEFAULT_REACT_JSON_SYSTEM_PROMPT,
+        system_prompt=DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template=None,
additional_args={},
max_iterations: int = 6,
tool_parser=parse_json_tool_call,
add_base_tools: bool = False,
verbose: int = 0,
memory_verbose: bool = False,
+        grammar: Dict[str, str] = None,
):
self.agent_name = self.__class__.__name__
self.llm_engine = llm_engine
@@ -347,6 +348,7 @@ def __init__(
self.max_iterations = max_iterations
self.logger = logger
self.tool_parser = tool_parser
+        self.grammar = grammar

if isinstance(tools, Toolbox):
self._toolbox = tools
@@ -533,6 +535,7 @@ def __init__(
llm_engine: Callable = HfEngine(),
system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
+        grammar: Dict[str, str] = None,
additional_authorized_imports: Optional[List[str]] = None,
**kwargs,
):
@@ -541,6 +544,7 @@
llm_engine=llm_engine,
system_prompt=system_prompt,
tool_description_template=tool_description_template,
+            grammar=grammar,
**kwargs,
)

@@ -599,7 +603,9 @@ def run(self, task: str, return_generated_code: bool = False, **kwargs):
self.prompt = [prompt_message, task_message]
self.logger.info("====Executing with this prompt====")
self.logger.info(self.prompt)
-        llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>"])
+
+        additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
+        llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>"], **additional_args)

if return_generated_code:
return llm_output
@@ -652,6 +658,7 @@ def __init__(
llm_engine: Callable = HfEngine(),
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
+        grammar: Dict[str, str] = None,
plan_type: Literal[tuple(SUPPORTED_PLAN_TYPES)] = SUPPORTED_PLAN_TYPES[0],
planning_interval: Optional[int] = None,
**kwargs,
@@ -662,6 +669,7 @@
llm_engine=llm_engine,
system_prompt=system_prompt,
tool_description_template=tool_description_template,
+            grammar=grammar,
**kwargs,
)
self.planning_interval = planning_interval
@@ -881,6 +889,7 @@ def __init__(
llm_engine: Callable = HfEngine(),
system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
+        grammar: Dict[str, str] = None,
planning_interval: Optional[int] = None,
**kwargs,
):
@@ -889,6 +898,7 @@
llm_engine=llm_engine,
system_prompt=system_prompt,
tool_description_template=tool_description_template,
+            grammar=grammar,
planning_interval=planning_interval,
**kwargs,
)
@@ -912,7 +922,10 @@ def step(self):
self.logger.info(self.prompt[-1])

try:
-            llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>", "Observation:"])
+            additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
+            llm_output = self.llm_engine(
+                self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
+            )
except Exception as e:
raise AgentGenerationError(f"Error in generating llm output: {e}.")
self.logger.debug("===== Output message of the LLM: =====")
@@ -982,6 +995,7 @@ def __init__(
llm_engine: Callable = HfEngine(),
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
+        grammar: Dict[str, str] = None,
additional_authorized_imports: Optional[List[str]] = None,
planning_interval: Optional[int] = None,
**kwargs,
@@ -991,6 +1005,7 @@
llm_engine=llm_engine,
system_prompt=system_prompt,
tool_description_template=tool_description_template,
+            grammar=grammar,
planning_interval=planning_interval,
**kwargs,
)
@@ -1028,7 +1043,10 @@ def step(self):
self.logger.info(self.prompt[-2:])

try:
-            llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>", "Observation:"])
+            additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
+            llm_output = self.llm_engine(
+                self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
+            )
except Exception as e:
raise AgentGenerationError(f"Error in generating llm output: {e}.")

29 changes: 24 additions & 5 deletions src/transformers/agents/llm_engine.py
@@ -16,7 +16,7 @@
# limitations under the License.
from copy import deepcopy
from enum import Enum
-from typing import Dict, List
+from typing import Dict, List, Optional

from huggingface_hub import InferenceClient

@@ -66,20 +66,39 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions:


class HfEngine:
def __init__(self, model: str = "meta-llama/Meta-Llama-3-8B-Instruct"):
def __init__(self, model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"):
self.model = model
-        self.client = InferenceClient(model=self.model, timeout=120)
+        self.client = InferenceClient(self.model, timeout=120)

-    def __call__(self, messages: List[Dict[str, str]], stop_sequences=[]) -> str:
+    def __call__(
+        self, messages: List[Dict[str, str]], stop_sequences: List[str] = [], grammar: Optional[str] = None
+    ) -> str:
# Get clean message list
messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)

# Get LLM output
-        response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=1500)
+        if grammar is not None:
+            response = self.client.chat_completion(
+                messages, stop=stop_sequences, max_tokens=1500, response_format=grammar
+            )
+        else:
+            response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=1500)

response = response.choices[0].message.content

# Remove stop sequences from LLM output
for stop_seq in stop_sequences:
if response[-len(stop_seq) :] == stop_seq:
response = response[: -len(stop_seq)]
return response


+DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
+    "type": "regex",
+    "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>',
+}
+
+DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
+    "type": "regex",
+    "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
+}
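
Since both default grammars are plain regexes, a quick local check confirms the shape they enforce — a sketch, assuming the import path from this diff:

```py
import re

from transformers.agents.llm_engine import DEFAULT_JSONAGENT_REGEX_GRAMMAR

sample = (
    "Thought: I have the final answer.\n"
    "Action:\n"
    "{\n"
    '    "action": "final_answer",\n'
    '    "action_input": "42"\n'
    "}\n"
    "<end_action>"
)
# fullmatch checks that the entire completion obeys the Thought/Action layout.
assert re.fullmatch(DEFAULT_JSONAGENT_REGEX_GRAMMAR["value"], sample) is not None
```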
14 changes: 6 additions & 8 deletions src/transformers/agents/prompts.py
@@ -63,7 +63,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
---
Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French."
-I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image.
+Thought: I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image.
Code:
```py
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
@@ -75,7 +75,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
---
Task: "Identify the oldest person in the `document` and create an image showcasing the result."
-I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
+Thought: I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
Code:
```py
answer = document_qa(document, question="What is the oldest person?")
@@ -87,7 +87,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
---
Task: "Generate an image using the text given in the variable `caption`."
-I will use the following tool: `image_generator` to generate an image.
+Thought: I will use the following tool: `image_generator` to generate an image.
Code:
```py
image = image_generator(prompt=caption)
@@ -97,7 +97,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
---
Task: "Summarize the text given in the variable `text` and read it out loud."
-I will use the following tools: `summarizer` to create a summary of the input text, then `text_reader` to read it out loud.
+Thought: I will use the following tools: `summarizer` to create a summary of the input text, then `text_reader` to read it out loud.
Code:
```py
summarized_text = summarizer(text)
@@ -109,7 +109,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
---
Task: "Answer the question in the variable `question` about the text in the variable `text`. Use the answer to generate an image."
-I will use the following tools: `text_qa` to create the answer, then `image_generator` to generate an image according to the answer.
+Thought: I will use the following tools: `text_qa` to create the answer, then `image_generator` to generate an image according to the answer.
Code:
```py
answer = text_qa(text=text, question=question)
@@ -121,7 +121,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
---
Task: "Caption the following `image`."
-I will use the following tool: `image_captioner` to generate a caption for the image.
+Thought: I will use the following tool: `image_captioner` to generate a caption for the image.
Code:
```py
caption = image_captioner(image)
@@ -292,7 +292,6 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
Thought: I will now generate an image showcasing the oldest person.
Code:
```py
image = image_generator("A portrait of John Doe, a 55-year-old man living in Canada.")
@@ -303,7 +302,6 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
Thought: I will use python code to compute the result of the operation and then return the final answer using the `final_answer` tool
Code:
```py
result = 5 + 3 + 1294.678
8 changes: 4 additions & 4 deletions tests/agents/test_agents.py
@@ -30,7 +30,7 @@ def get_new_path(suffix="") -> str:
return os.path.join(directory, str(uuid.uuid4()) + suffix)


-def fake_react_json_llm(messages, stop_sequences=None) -> str:
+def fake_react_json_llm(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages)

if "special_marker" not in prompt:
@@ -53,7 +53,7 @@ def fake_react_json_llm(messages, stop_sequences=None) -> str:
"""


-def fake_react_code_llm(messages, stop_sequences=None) -> str:
+def fake_react_code_llm(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
return """
@@ -119,7 +119,7 @@ def moving_average(x, w):
"""


-def fake_code_llm_oneshot(messages, stop_sequences=None) -> str:
+def fake_code_llm_oneshot(messages, stop_sequences=None, grammar=None) -> str:
return """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
@@ -130,7 +130,7 @@ def fake_code_llm_oneshot(messages, stop_sequences=None) -> str:
"""


-def fake_code_llm_no_return(messages, stop_sequences=None) -> str:
+def fake_code_llm_no_return(messages, stop_sequences=None, grammar=None) -> str:
return """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
