diff --git a/src/transformers/agents/agent_types.py b/src/transformers/agents/agent_types.py index 87255dc7dec98a..0b4999b7f76d3c 100644 --- a/src/transformers/agents/agent_types.py +++ b/src/transformers/agents/agent_types.py @@ -188,7 +188,7 @@ def __init__(self, value, samplerate=16_000): self.samplerate = samplerate if isinstance(value, (str, pathlib.Path)): self._path = value - elif isinstance(value, torch.Tensor): + elif is_torch_available() and isinstance(value, torch.Tensor): self._tensor = value elif isinstance(value, tuple): self.samplerate = value[0] @@ -232,7 +232,10 @@ def to_string(self): AGENT_TYPE_MAPPING = {"text": AgentText, "image": AgentImage, "audio": AgentAudio} -INSTANCE_TYPE_MAPPING = {str: AgentText, float: AgentText, int: AgentText, Tensor: AgentAudio, ImageType: AgentImage} +INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage} + +if is_torch_available(): + INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio def handle_agent_inputs(*args, **kwargs): @@ -251,4 +254,4 @@ def handle_agent_outputs(output, output_type=None): for _k, _v in INSTANCE_TYPE_MAPPING.items(): if isinstance(output, _k): return _v(output) - return AgentType(output) + return output diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py index d1756d2e9e3d4f..1ddfb6b4174777 100644 --- a/src/transformers/agents/agents.py +++ b/src/transformers/agents/agents.py @@ -856,6 +856,10 @@ def __init__( self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) self.system_prompt = self.system_prompt.replace("<>", str(self.authorized_imports)) + self.available_tools = { + **BASE_PYTHON_TOOLS.copy(), + **self.toolbox.tools, + } # This list can be augmented by the code agent creating some new functions def step(self): """ @@ -905,10 +909,9 @@ def step(self): # Execute self.log_code_action(code_action) try: - available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools} result = self.python_evaluator( code_action, - available_tools, + tools=self.available_tools, state=self.state, authorized_imports=self.authorized_imports, ) diff --git a/src/transformers/agents/python_interpreter.py b/src/transformers/agents/python_interpreter.py index 04f62a8acfb959..1235bb95c3ae02 100644 --- a/src/transformers/agents/python_interpreter.py +++ b/src/transformers/agents/python_interpreter.py @@ -778,7 +778,10 @@ def evaluate_ast( def evaluate_python_code( - code: str, tools: Optional[Dict[str, Callable]] = {}, state=None, authorized_imports: List[str] = LIST_SAFE_MODULES + code: str, + tools: Optional[Dict[str, Callable]] = None, + state: Optional[Dict[str, Any]] = None, + authorized_imports: List[str] = LIST_SAFE_MODULES, ): """ Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set @@ -803,6 +806,8 @@ def evaluate_python_code( raise SyntaxError(f"The code generated by the agent is not valid.\n{e}") if state is None: state = {} + if tools is None: + tools = {} result = None global PRINT_OUTPUTS PRINT_OUTPUTS = "" diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index 062b98abd47350..5bdaea1651b741 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -94,12 +94,48 @@ def fake_react_code_llm_error(messages, stop_sequences=None) -> str: """ +def fake_react_code_functiondef(messages, stop_sequences=None) -> str: + prompt = str(messages) + if "special_marker" not in prompt: + return """ +Thought: Let's define the function. special_marker +Code: +```py +import numpy as np + +def moving_average(x, w): + return np.convolve(x, np.ones(w), 'valid') / w +``` +""" + else: # We're at step 2 + return """ +Thought: I can now answer the initial question +Code: +```py +x, w = [0, 1, 2, 3, 4, 5], 2 +res = moving_average(x, w) +final_answer(res) +``` +""" + + def fake_code_llm_oneshot(messages, stop_sequences=None) -> str: return """ Thought: I should multiply 2 by 3.6452. special_marker Code: ```py result = python_interpreter(code="2*3.6452") +final_answer(result) +``` +""" + + +def fake_code_llm_no_return(messages, stop_sequences=None) -> str: + return """ +Thought: I should multiply 2 by 3.6452. special_marker +Code: +```py +result = python_interpreter(code="2*3.6452") print(result) ``` """ @@ -135,8 +171,8 @@ def test_fake_react_json_agent(self): def test_fake_react_code_agent(self): agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm) output = agent.run("What is 2 multiplied by 3.6452?") - assert isinstance(output, AgentText) - assert output == "7.2904" + assert isinstance(output, float) + assert output == 7.2904 assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?" assert float(agent.logs[1]["observation"].strip()) - 12.511648 < 1e-6 assert agent.logs[2]["tool_call"] == { @@ -157,7 +193,7 @@ def test_setup_agent_with_empty_toolbox(self): def test_react_fails_max_iterations(self): agent = ReactCodeAgent( tools=[PythonInterpreterTool()], - llm_engine=fake_code_llm_oneshot, # use this callable because it never ends + llm_engine=fake_code_llm_no_return, # use this callable because it never ends max_iterations=5, ) agent.run("What is 2 multiplied by 3.6452?") @@ -192,3 +228,10 @@ def test_init_agent_with_different_toolsets(self): # check that python_interpreter base tool does not get added to code agents agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True) assert len(agent.toolbox.tools) == 6 # added final_answer tool + 5 base tools (excluding interpreter) + + def test_function_persistence_across_steps(self): + agent = ReactCodeAgent( + tools=[], llm_engine=fake_react_code_functiondef, max_iterations=2, additional_authorized_imports=["numpy"] + ) + res = agent.run("ok") + assert res[0] == 0.5 diff --git a/tests/agents/test_python_interpreter.py b/tests/agents/test_python_interpreter.py index 6f5907e27be1f0..8843a394b35313 100644 --- a/tests/agents/test_python_interpreter.py +++ b/tests/agents/test_python_interpreter.py @@ -660,7 +660,6 @@ def add_one(n, shift): """ state = {} result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state) - print(state) assert result == 2 # test returning None @@ -672,5 +671,4 @@ def returns_none(a): """ state = {} result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state) - print(state) assert result is None