Skip to content

Commit

Permalink
lint fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
CalebCourier committed Feb 5, 2024
1 parent 5e5010f commit ec9e3c3
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 15 deletions.
13 changes: 7 additions & 6 deletions guardrails/functional/guard.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from string import Template
from typing import Any, Dict, List, Optional, Tuple, Type, Union, overload

from guardrails.classes.generic.stack import Stack
from guardrails.classes.validation_outcome import ValidationOutcome

from guardrails.guard import Guard as OGuard
from guardrails.utils.safe_get import safe_get
from guardrails.validator_base import Validator
from guardrails.guard import Guard as OGuard


class Guard:
validators: List[Validator]
Expand Down Expand Up @@ -122,10 +123,10 @@ def validate(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]:
self.guard = OGuard.from_string(validators=self.validators)

return self.guard.parse(llm_output=llm_output, *args, **kwargs)
def __call__ (self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]:

def __call__(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]:
return self.validate(llm_output, *args, **kwargs)

@property
def history(self):
return self.guard.history if self.guard else Stack()
return self.guard.history if self.guard else Stack()
19 changes: 10 additions & 9 deletions tests/unit_tests/functional/test_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
EndsWith,
LowerCase,
OneLine,
ReadingTime,
TwoWords,
ValidLength,
ReadingTime
)


Expand Down Expand Up @@ -102,6 +102,7 @@ def test_integrate_tuple():
assert guard.validators[4]._kwargs["max"] == 12
assert guard.validators[4].on_fail_descriptor == "refrain" # bc we set it


def test_validate():
guard: Guard = (
Guard()
Expand All @@ -112,19 +113,19 @@ def test_validate():
.add(ValidLength, 0, 12, on_fail="refrain")
)

llm_output = "Oh Canada" # bc it meets our criteria
llm_output = "Oh Canada" # bc it meets our criteria

response = guard.validate(llm_output)

assert response.validation_passed == True
assert response.validation_passed is True
assert response.validated_output == llm_output.lower()
llm_output_2 = "Star Spangled Banner" # to stick with the theme

llm_output_2 = "Star Spangled Banner" # to stick with the theme

response_2 = guard.validate(llm_output_2)

assert response_2.validation_passed == False
assert response_2.validated_output == None
assert response_2.validation_passed is False
assert response_2.validated_output is None


def test_call():
Expand All @@ -138,5 +139,5 @@ def test_call():
(ValidLength, args(0, 12), kwargs(on_fail="refrain")),
)("Oh Canada")

assert response.validation_passed == True
assert response.validated_output == "oh canada"
assert response.validation_passed is True
assert response.validated_output == "oh canada"

0 comments on commit ec9e3c3

Please sign in to comment.