Skip to content

Commit

Permalink
Add support for custom validators
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Sep 3, 2024
1 parent fdf950a commit b835fd6
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 37 deletions.
2 changes: 2 additions & 0 deletions src/controlflow/agents/names.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
"Deckard",
"HK-47",
"Bender",
"Norbert",
"Norby",
]

TEAMS = [
Expand Down
80 changes: 44 additions & 36 deletions src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,14 @@ class Task(ControlFlowModel):
", generic alias, BaseModel subclass, or list of choices. "
"Can be None if no result is expected or the agent should communicate internally.",
)
result_validator: Optional[Callable] = Field(
None,
description="A function that validates the result. This should be a "
"function that takes the raw result and either returns a validated "
"result or raises an informative error if the result is not valid. The "
"result validator function is called *after* the `result_type` is "
"processed.",
)
tools: list[Callable] = Field(
default_factory=list,
description="Tools available to every agent working on this task.",
Expand Down Expand Up @@ -518,7 +526,7 @@ def mark_successful(self, result: T = None, validate_upstreams: bool = True):
f"are: {', '.join(t.friendly_name() for t in self._subtasks if t.is_incomplete())}"
)

self.result = validate_result(result, self.result_type)
self.result = self.validate_result(result)
self.set_status(TaskStatus.SUCCESSFUL)

def mark_failed(self, reason: Optional[str] = None):
Expand Down Expand Up @@ -623,35 +631,40 @@ def fail(reason: str) -> str:

return fail


def validate_result(result: Any, result_type: type[T]) -> T:
if result_type is None and result is not None:
raise ValueError("Task has result_type=None, but a result was provided.")
elif isinstance(result_type, tuple):
if result not in result_type:
raise ValueError(
f"Result {result} is not in the list of valid result types: {result_type}"
)
elif result_type is not None:
try:
result = TypeAdapter(result_type).validate_python(result)
except PydanticSchemaGenerationError:
if isinstance(result, dict):
result = result_type(**result)
def validate_result(self, raw_result: Any) -> T:
if self.result_type is None and raw_result is not None:
raise ValueError("Task has result_type=None, but a result was provided.")
elif isinstance(self.result_type, tuple):
if raw_result not in self.result_type:
raise ValueError(
f"Result {raw_result} is not in the list of valid result types: {self.result_type}"
)
else:
result = result_type(result)
result = raw_result
elif self.result_type is not None:
try:
result = TypeAdapter(self.result_type).validate_python(raw_result)
except PydanticSchemaGenerationError:
if isinstance(raw_result, dict):
result = self.result_type(**raw_result)
else:
result = self.result_type(raw_result)

# Convert DataFrame schema back into pd.DataFrame object
# if result_type == PandasDataFrame:
# import pandas as pd

# Convert DataFrame schema back into pd.DataFrame object
# if result_type == PandasDataFrame:
# import pandas as pd
# result = pd.DataFrame(**result)
# elif result_type == PandasSeries:
# import pandas as pd

# result = pd.DataFrame(**result)
# elif result_type == PandasSeries:
# import pandas as pd
# result = pd.Series(**result)

# result = pd.Series(**result)
# apply custom validation
if self.result_validator is not None:
result = self.result_validator(result)

return result
return result


def _generate_result_schema(result_type: type[T]) -> type[T]:
Expand Down Expand Up @@ -681,15 +694,13 @@ def run(
max_turns: int = None,
**task_kwargs,
):
task = controlflow.Task(
objective=objective,
return controlflow.run(
objective,
*task_args,
**task_kwargs,
)
return task.run(
turn_strategy=turn_strategy,
max_calls_per_turn=max_calls_per_turn,
max_turns=max_turns,
**task_kwargs,
)


Expand All @@ -701,14 +712,11 @@ async def run_async(
max_turns: int = None,
**task_kwargs,
):
task = controlflow.Task(
objective=objective,
return await controlflow.run_async(
objective,
*task_args,
**task_kwargs,
)

return await task.run_async(
turn_strategy=turn_strategy,
max_calls_per_turn=max_calls_per_turn,
max_turns=max_turns,
**task_kwargs,
)
79 changes: 78 additions & 1 deletion tests/tasks/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated
from typing import Annotated, Any

import pytest
from pydantic import BaseModel
Expand Down Expand Up @@ -311,6 +311,83 @@ def test_annotated_result(self):
assert int(task.result)


class TestResultValidator:
def test_result_validator(self):
def validate_even(value: int) -> int:
if value % 2 != 0:
raise ValueError("Value must be even")
return value

task = Task(
"choose an even number", result_type=int, result_validator=validate_even
)
task.mark_successful(result=4)
assert task.result == 4

with pytest.raises(ValueError, match="Value must be even"):
task.mark_successful(result=5)

def test_result_validator_with_constraints(self):
def validate_range(value: int) -> int:
if not 10 <= value <= 20:
raise ValueError("Value must be between 10 and 20")
return value

task = Task("choose a number", result_type=int, result_validator=validate_range)
task.mark_successful(result=15)
assert task.result == 15

with pytest.raises(ValueError, match="Value must be between 10 and 20"):
task.mark_successful(result=5)

def test_result_validator_with_modification(self):
def round_to_nearest_ten(value: int) -> int:
return round(value, -1)

task = Task(
"choose a number", result_type=int, result_validator=round_to_nearest_ten
)
task.mark_successful(result=44)
assert task.result == 40

task.mark_successful(result=46)
assert task.result == 50

def test_result_validator_with_pydantic_model(self):
class User(BaseModel):
name: str
age: int

def validate_adult(user: User) -> User:
if user.age < 18:
raise ValueError("User must be an adult")
return user

task = Task(
"create an adult user", result_type=User, result_validator=validate_adult
)
task.mark_successful(result={"name": "John", "age": 25})
assert task.result == User(name="John", age=25)

with pytest.raises(ValueError, match="User must be an adult"):
task.mark_successful(result={"name": "Jane", "age": 16})

def test_result_validator_applied_after_type_coercion(self):
def always_return_none(value: Any) -> None:
return None

task = Task(
"do something with no result",
result_type=None,
result_validator=always_return_none,
)

with pytest.raises(
ValueError, match="Task has result_type=None, but a result was provided"
):
task.mark_successful(result="anything")


class TestSuccessTool:
def test_success_tool(self):
task = Task("choose 5", result_type=int)
Expand Down

0 comments on commit b835fd6

Please sign in to comment.