Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/concepts/tasks.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -322,9 +322,9 @@ blog_task = Task(
- On success: it returns a tuple of `(bool, Any)`. For example: `(True, validated_result)`
- On Failure: it returns a tuple of `(bool, str)`. For example: `(False, "Error message explain the failure")`

### TaskGuardrail
### LLMGuardrail

The `TaskGuardrail` class offers a robust mechanism for validating task outputs
The `LLMGuardrail` class offers a robust mechanism for validating task outputs.

### Error Handling Best Practices

Expand Down Expand Up @@ -819,7 +819,7 @@ from crewai.llm import LLM
task = Task(
description="Generate JSON data",
expected_output="Valid JSON object",
guardrail=TaskGuardrail(
guardrail=LLMGuardrail(
description="Ensure the response is a valid JSON object",
llm=LLM(model="gpt-4o-mini"),
)
Expand Down
12 changes: 6 additions & 6 deletions src/crewai/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,10 @@ def ensure_guardrail_is_callable(self) -> "Task":
if callable(self.guardrail):
self._guardrail = self.guardrail
elif isinstance(self.guardrail, str):
from crewai.tasks.task_guardrail import TaskGuardrail
from crewai.tasks.llm_guardrail import LLMGuardrail

assert self.agent is not None
self._guardrail = TaskGuardrail(
self._guardrail = LLMGuardrail(
description=self.guardrail, llm=self.agent.llm
)

Expand Down Expand Up @@ -494,16 +494,16 @@ def _process_guardrail(self, task_output: TaskOutput) -> GuardrailResult:
assert self._guardrail is not None

from crewai.utilities.events import (
TaskGuardrailCompletedEvent,
TaskGuardrailStartedEvent,
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
from crewai.utilities.events.crewai_event_bus import crewai_event_bus

result = self._guardrail(task_output)

crewai_event_bus.emit(
self,
TaskGuardrailStartedEvent(
LLMGuardrailStartedEvent(
guardrail=self._guardrail, retry_count=self.retry_count
),
)
Expand All @@ -512,7 +512,7 @@ def _process_guardrail(self, task_output: TaskOutput) -> GuardrailResult:

crewai_event_bus.emit(
self,
TaskGuardrailCompletedEvent(
LLMGuardrailCompletedEvent(
success=guardrail_result.success,
result=guardrail_result.result,
error=guardrail_result.error,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from crewai.tasks.task_output import TaskOutput


class TaskGuardrailResult(BaseModel):
class LLMGuardrailResult(BaseModel):
valid: bool = Field(
description="Whether the task output complies with the guardrail"
)
Expand All @@ -18,7 +18,7 @@ class TaskGuardrailResult(BaseModel):
)


class TaskGuardrail:
class LLMGuardrail:
"""It validates the output of another task using an LLM.

This class is used to validate the output from a Task based on specified criteria.
Expand Down Expand Up @@ -62,7 +62,7 @@ def _validate_output(self, task_output: TaskOutput) -> LiteAgentOutput:
- If the Task result complies with the guardrail, saying that is valid
"""

result = agent.kickoff(query, response_format=TaskGuardrailResult)
result = agent.kickoff(query, response_format=LLMGuardrailResult)

return result

Expand All @@ -81,7 +81,7 @@ def __call__(self, task_output: TaskOutput) -> Tuple[bool, Any]:
try:
result = self._validate_output(task_output)
assert isinstance(
result.pydantic, TaskGuardrailResult
result.pydantic, LLMGuardrailResult
), "The guardrail result is not a valid pydantic model"

if result.pydantic.valid:
Expand Down
6 changes: 3 additions & 3 deletions src/crewai/utilities/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
CrewTestCompletedEvent,
CrewTestFailedEvent,
)
from .task_guardrail_events import (
TaskGuardrailCompletedEvent,
TaskGuardrailStartedEvent,
from .llm_guardrail_events import (
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
from .agent_events import (
AgentExecutionStartedEvent,
Expand Down
12 changes: 6 additions & 6 deletions src/crewai/utilities/events/event_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@
LLMCallStartedEvent,
LLMStreamChunkEvent,
)
from .llm_guardrail_events import (
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
from .task_events import (
TaskCompletedEvent,
TaskFailedEvent,
TaskStartedEvent,
)
from .task_guardrail_events import (
TaskGuardrailCompletedEvent,
TaskGuardrailStartedEvent,
)
from .tool_usage_events import (
ToolUsageErrorEvent,
ToolUsageFinishedEvent,
Expand Down Expand Up @@ -72,6 +72,6 @@
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMStreamChunkEvent,
TaskGuardrailStartedEvent,
TaskGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
LLMGuardrailCompletedEvent,
]
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,35 @@
from crewai.utilities.events.base_events import BaseEvent


class TaskGuardrailStartedEvent(BaseEvent):
class LLMGuardrailStartedEvent(BaseEvent):
"""Event emitted when a guardrail task starts

Attributes:
guardrail: The guardrail callable or TaskGuardrail instance
guardrail: The guardrail callable or LLMGuardrail instance
retry_count: The number of times the guardrail has been retried
"""

type: str = "task_guardrail_started"
type: str = "llm_guardrail_started"
guardrail: Union[str, Callable]
retry_count: int

def __init__(self, **data):
from inspect import getsource

from crewai.tasks.task_guardrail import TaskGuardrail
from crewai.tasks.llm_guardrail import LLMGuardrail

super().__init__(**data)

if isinstance(self.guardrail, TaskGuardrail):
if isinstance(self.guardrail, LLMGuardrail):
self.guardrail = self.guardrail.description.strip()
elif isinstance(self.guardrail, Callable):
self.guardrail = getsource(self.guardrail).strip()


class TaskGuardrailCompletedEvent(BaseEvent):
class LLMGuardrailCompletedEvent(BaseEvent):
"""Event emitted when a guardrail task completes"""

type: str = "task_guardrail_completed"
type: str = "llm_guardrail_completed"
success: bool
result: Any
error: Optional[str] = None
Expand Down
14 changes: 7 additions & 7 deletions tests/test_task_guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

from crewai import Agent, Task
from crewai.llm import LLM
from crewai.tasks.task_guardrail import TaskGuardrail
from crewai.tasks.llm_guardrail import LLMGuardrail
from crewai.tasks.task_output import TaskOutput
from crewai.utilities.events import (
TaskGuardrailCompletedEvent,
TaskGuardrailStartedEvent,
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
from crewai.utilities.events.crewai_event_bus import crewai_event_bus

Expand Down Expand Up @@ -153,7 +153,7 @@ def task_output():

@pytest.mark.vcr(filter_headers=["authorization"])
def test_task_guardrail_process_output(task_output):
guardrail = TaskGuardrail(
guardrail = LLMGuardrail(
description="Ensure the result has less than 10 words", llm=LLM(model="gpt-4o")
)

Expand All @@ -162,7 +162,7 @@ def test_task_guardrail_process_output(task_output):

assert "exceeding the guardrail limit of fewer than" in result[1].lower()

guardrail = TaskGuardrail(
guardrail = LLMGuardrail(
description="Ensure the result has less than 500 words", llm=LLM(model="gpt-4o")
)

Expand All @@ -178,13 +178,13 @@ def test_guardrail_emits_events(sample_agent):

with crewai_event_bus.scoped_handlers():

@crewai_event_bus.on(TaskGuardrailStartedEvent)
@crewai_event_bus.on(LLMGuardrailStartedEvent)
def handle_guardrail_started(source, event):
started_guardrail.append(
{"guardrail": event.guardrail, "retry_count": event.retry_count}
)

@crewai_event_bus.on(TaskGuardrailCompletedEvent)
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
def handle_guardrail_completed(source, event):
completed_guardrail.append(
{
Expand Down
Loading