diff --git a/src/transformers/tools/agents.py b/src/transformers/tools/agents.py index 66563c3529f1..fdbae381f4ec 100644 --- a/src/transformers/tools/agents.py +++ b/src/transformers/tools/agents.py @@ -48,6 +48,7 @@ BASE_PYTHON_TOOLS = { "print": print, + "range": range, "float": float, "int": int, "bool": bool, diff --git a/src/transformers/tools/python_interpreter.py b/src/transformers/tools/python_interpreter.py index ef366c5d2e77..960be1a2a265 100644 --- a/src/transformers/tools/python_interpreter.py +++ b/src/transformers/tools/python_interpreter.py @@ -110,6 +110,9 @@ def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Ca elif isinstance(expression, ast.Expr): # Expression -> evaluate the content return evaluate_ast(expression.value, state, tools) + elif isinstance(expression, ast.For): + # For loop -> execute the loop + return evaluate_for(expression, state, tools) elif isinstance(expression, ast.FormattedValue): # Formatted value (part of f-string) -> evaluate the content and return return evaluate_ast(expression.value, state, tools) @@ -236,3 +239,15 @@ def evaluate_if(if_statement, state, tools): if line_result is not None: result = line_result return result + + +def evaluate_for(for_loop, state, tools): + result = None + iterator = evaluate_ast(for_loop.iter, state, tools) + for counter in iterator: + state[for_loop.target.id] = counter + for expression in for_loop.body: + line_result = evaluate_ast(expression, state, tools) + if line_result is not None: + result = line_result + return result diff --git a/tests/tools/test_python_interpreter.py b/tests/tools/test_python_interpreter.py index 5dfdd8d9c37c..b9a38b4a21f1 100644 --- a/tests/tools/test_python_interpreter.py +++ b/tests/tools/test_python_interpreter.py @@ -122,3 +122,10 @@ def test_evaluate_subscript(self): result = evaluate(code, {"add_two": add_two}, state=state) assert result == 5 self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}}) + + def test_evaluate_for(self): + code = "x = 0\nfor i in range(3):\n x = i" + state = {} + result = evaluate(code, {"range": range}, state=state) + assert result == 2 + self.assertDictEqual(state, {"x": 2, "i": 2})