From ec9e3c3d58ca11db6a9f93bebb4f19c182fcdb1e Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 5 Feb 2024 08:44:04 -0800 Subject: [PATCH] lint fixes --- guardrails/functional/guard.py | 13 +++++++------ tests/unit_tests/functional/test_guard.py | 19 ++++++++++--------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/guardrails/functional/guard.py b/guardrails/functional/guard.py index 12c540096..0bdab6d03 100644 --- a/guardrails/functional/guard.py +++ b/guardrails/functional/guard.py @@ -1,11 +1,12 @@ from string import Template from typing import Any, Dict, List, Optional, Tuple, Type, Union, overload + from guardrails.classes.generic.stack import Stack from guardrails.classes.validation_outcome import ValidationOutcome - +from guardrails.guard import Guard as OGuard from guardrails.utils.safe_get import safe_get from guardrails.validator_base import Validator -from guardrails.guard import Guard as OGuard + class Guard: validators: List[Validator] @@ -122,10 +123,10 @@ def validate(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]: self.guard = OGuard.from_string(validators=self.validators) return self.guard.parse(llm_output=llm_output, *args, **kwargs) - - def __call__ (self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]: + + def __call__(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]: return self.validate(llm_output, *args, **kwargs) - + @property def history(self): - return self.guard.history if self.guard else Stack() \ No newline at end of file + return self.guard.history if self.guard else Stack() diff --git a/tests/unit_tests/functional/test_guard.py b/tests/unit_tests/functional/test_guard.py index ea869910c..d1a15efa2 100644 --- a/tests/unit_tests/functional/test_guard.py +++ b/tests/unit_tests/functional/test_guard.py @@ -3,9 +3,9 @@ EndsWith, LowerCase, OneLine, + ReadingTime, TwoWords, ValidLength, - ReadingTime ) @@ -102,6 +102,7 @@ def test_integrate_tuple(): assert guard.validators[4]._kwargs["max"] == 12 assert guard.validators[4].on_fail_descriptor == "refrain" # bc we set it + def test_validate(): guard: Guard = ( Guard() @@ -112,19 +113,19 @@ def test_validate(): .add(ValidLength, 0, 12, on_fail="refrain") ) - llm_output = "Oh Canada" # bc it meets our criteria + llm_output = "Oh Canada" # bc it meets our criteria response = guard.validate(llm_output) - assert response.validation_passed == True + assert response.validation_passed is True assert response.validated_output == llm_output.lower() - - llm_output_2 = "Star Spangled Banner" # to stick with the theme + + llm_output_2 = "Star Spangled Banner" # to stick with the theme response_2 = guard.validate(llm_output_2) - assert response_2.validation_passed == False - assert response_2.validated_output == None + assert response_2.validation_passed is False + assert response_2.validated_output is None def test_call(): @@ -138,5 +139,5 @@ def test_call(): (ValidLength, args(0, 12), kwargs(on_fail="refrain")), )("Oh Canada") - assert response.validation_passed == True - assert response.validated_output == "oh canada" \ No newline at end of file + assert response.validation_passed is True + assert response.validated_output == "oh canada"