From b835fd66fafae908056aa4430e0b51305c4950a7 Mon Sep 17 00:00:00 2001
From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com>
Date: Mon, 2 Sep 2024 20:19:00 -0400
Subject: [PATCH] Add support for custom validators

---
 src/controlflow/agents/names.py |  2 +
 src/controlflow/tasks/task.py   | 80 ++++++++++++++++++---------------
 tests/tasks/test_tasks.py       | 79 +++++++++++++++++++++++++++++++-
 3 files changed, 124 insertions(+), 37 deletions(-)

diff --git a/src/controlflow/agents/names.py b/src/controlflow/agents/names.py
index fc8ea12e..5ca039a6 100644
--- a/src/controlflow/agents/names.py
+++ b/src/controlflow/agents/names.py
@@ -17,6 +17,8 @@
     "Deckard",
     "HK-47",
     "Bender",
+    "Norbert",
+    "Norby",
 ]
 
 TEAMS = [
diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py
index eeb5fe2b..ce3df90a 100644
--- a/src/controlflow/tasks/task.py
+++ b/src/controlflow/tasks/task.py
@@ -106,6 +106,14 @@ class Task(ControlFlowModel):
         ", generic alias, BaseModel subclass, or list of choices. "
         "Can be None if no result is expected or the agent should communicate internally.",
     )
+    result_validator: Optional[Callable] = Field(
+        None,
+        description="A function that validates the result. This should be a "
+        "function that takes the raw result and either returns a validated "
+        "result or raises an informative error if the result is not valid. The "
+        "result validator function is called *after* the `result_type` is "
+        "processed.",
+    )
     tools: list[Callable] = Field(
         default_factory=list,
         description="Tools available to every agent working on this task.",
@@ -518,7 +526,7 @@ def mark_successful(self, result: T = None, validate_upstreams: bool = True):
                 f"are: {', '.join(t.friendly_name() for t in self._subtasks if t.is_incomplete())}"
             )
 
-        self.result = validate_result(result, self.result_type)
+        self.result = self.validate_result(result)
         self.set_status(TaskStatus.SUCCESSFUL)
 
     def mark_failed(self, reason: Optional[str] = None):
@@ -623,35 +631,40 @@ def fail(reason: str) -> str:
 
         return fail
 
-
-def validate_result(result: Any, result_type: type[T]) -> T:
-    if result_type is None and result is not None:
-        raise ValueError("Task has result_type=None, but a result was provided.")
-    elif isinstance(result_type, tuple):
-        if result not in result_type:
-            raise ValueError(
-                f"Result {result} is not in the list of valid result types: {result_type}"
-            )
-    elif result_type is not None:
-        try:
-            result = TypeAdapter(result_type).validate_python(result)
-        except PydanticSchemaGenerationError:
-            if isinstance(result, dict):
-                result = result_type(**result)
+    def validate_result(self, raw_result: Any) -> T:
+        if self.result_type is None and raw_result is not None:
+            raise ValueError("Task has result_type=None, but a result was provided.")
+        elif isinstance(self.result_type, tuple):
+            if raw_result not in self.result_type:
+                raise ValueError(
+                    f"Result {raw_result} is not in the list of valid result types: {self.result_type}"
+                )
             else:
-                result = result_type(result)
+                result = raw_result
+        elif self.result_type is not None:
+            try:
+                result = TypeAdapter(self.result_type).validate_python(raw_result)
+            except PydanticSchemaGenerationError:
+                if isinstance(raw_result, dict):
+                    result = self.result_type(**raw_result)
+                else:
+                    result = self.result_type(raw_result)
+
+        # Convert DataFrame schema back into pd.DataFrame object
+        # if result_type == PandasDataFrame:
+        #     import pandas as pd
 
-    # Convert DataFrame schema back into pd.DataFrame object
-    # if result_type == PandasDataFrame:
-    #     import pandas as pd
+        #     result = pd.DataFrame(**result)
+        # elif result_type == PandasSeries:
+        #     import pandas as pd
 
-    #     result = pd.DataFrame(**result)
-    # elif result_type == PandasSeries:
-    #     import pandas as pd
+        #     result = pd.Series(**result)
 
-    #     result = pd.Series(**result)
+        # apply custom validation
+        if self.result_validator is not None:
+            result = self.result_validator(result)
 
-    return result
+        return result
 
 
 def _generate_result_schema(result_type: type[T]) -> type[T]:
@@ -681,15 +694,13 @@ def run(
     max_turns: int = None,
     **task_kwargs,
 ):
-    task = controlflow.Task(
-        objective=objective,
+    return controlflow.run(
+        objective,
         *task_args,
-        **task_kwargs,
-    )
-    return task.run(
         turn_strategy=turn_strategy,
         max_calls_per_turn=max_calls_per_turn,
         max_turns=max_turns,
+        **task_kwargs,
     )
 
 
@@ -701,14 +712,11 @@ async def run_async(
     max_turns: int = None,
    **task_kwargs,
 ):
-    task = controlflow.Task(
-        objective=objective,
+    return await controlflow.run_async(
+        objective,
         *task_args,
-        **task_kwargs,
-    )
-
-    return await task.run_async(
         turn_strategy=turn_strategy,
         max_calls_per_turn=max_calls_per_turn,
         max_turns=max_turns,
+        **task_kwargs,
     )
diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py
index 6ed6786d..773bb668 100644
--- a/tests/tasks/test_tasks.py
+++ b/tests/tasks/test_tasks.py
@@ -1,4 +1,4 @@
-from typing import Annotated
+from typing import Annotated, Any
 
 import pytest
 from pydantic import BaseModel
@@ -311,6 +311,83 @@ def test_annotated_result(self):
         assert int(task.result)
 
 
+class TestResultValidator:
+    def test_result_validator(self):
+        def validate_even(value: int) -> int:
+            if value % 2 != 0:
+                raise ValueError("Value must be even")
+            return value
+
+        task = Task(
+            "choose an even number", result_type=int, result_validator=validate_even
+        )
+        task.mark_successful(result=4)
+        assert task.result == 4
+
+        with pytest.raises(ValueError, match="Value must be even"):
+            task.mark_successful(result=5)
+
+    def test_result_validator_with_constraints(self):
+        def validate_range(value: int) -> int:
+            if not 10 <= value <= 20:
+                raise ValueError("Value must be between 10 and 20")
+            return value
+
+        task = Task("choose a number", result_type=int, result_validator=validate_range)
+        task.mark_successful(result=15)
+        assert task.result == 15
+
+        with pytest.raises(ValueError, match="Value must be between 10 and 20"):
+            task.mark_successful(result=5)
+
+    def test_result_validator_with_modification(self):
+        def round_to_nearest_ten(value: int) -> int:
+            return round(value, -1)
+
+        task = Task(
+            "choose a number", result_type=int, result_validator=round_to_nearest_ten
+        )
+        task.mark_successful(result=44)
+        assert task.result == 40
+
+        task.mark_successful(result=46)
+        assert task.result == 50
+
+    def test_result_validator_with_pydantic_model(self):
+        class User(BaseModel):
+            name: str
+            age: int
+
+        def validate_adult(user: User) -> User:
+            if user.age < 18:
+                raise ValueError("User must be an adult")
+            return user
+
+        task = Task(
+            "create an adult user", result_type=User, result_validator=validate_adult
+        )
+        task.mark_successful(result={"name": "John", "age": 25})
+        assert task.result == User(name="John", age=25)
+
+        with pytest.raises(ValueError, match="User must be an adult"):
+            task.mark_successful(result={"name": "Jane", "age": 16})
+
+    def test_result_validator_applied_after_type_coercion(self):
+        def always_return_none(value: Any) -> None:
+            return None
+
+        task = Task(
+            "do something with no result",
+            result_type=None,
+            result_validator=always_return_none,
+        )
+
+        with pytest.raises(
+            ValueError, match="Task has result_type=None, but a result was provided"
+        ):
+            task.mark_successful(result="anything")
+
+
 class TestSuccessTool:
     def test_success_tool(self):
         task = Task("choose 5", result_type=int)
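
Usage note (not part of the patch): a minimal sketch of the new `result_validator` hook, adapted from the tests above. The validator runs after `result_type` coercion, so it receives the already-typed result and may either return a (possibly transformed) value or raise an informative error. The objective string and validator function here are illustrative.

    import controlflow

    def validate_even(value: int) -> int:
        # Reject an invalid result by raising; accept (or transform) it by returning.
        if value % 2 != 0:
            raise ValueError("Value must be even")
        return value

    task = controlflow.Task(
        "choose an even number",
        result_type=int,
        result_validator=validate_even,
    )

    # The validator only sees the result after it has been coerced to `result_type`.
    task.mark_successful(result=4)
    assert task.result == 4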