Skip to content

Commit

Permalink
type "fixes"
Browse files Browse the repository at this point in the history
  • Loading branch information
CalebCourier committed Feb 5, 2024
1 parent f1c123b commit 28952cd
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions guardrails/functional/chain/guard.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import json
from copy import deepcopy
from typing import Dict, Optional, TypeVar
from typing import Dict, Optional, TypeVar, cast

from langchain_core.messages import BaseMessage
from langchain_core.runnables.base import Runnable, RunnableConfig
from langchain_core.runnables import Runnable, RunnableConfig

from guardrails.errors import ValidationError
from guardrails.functional.guard import Guard as FGuard
Expand All @@ -13,26 +13,31 @@

class Guard(FGuard, Runnable):
    """A guardrails Guard that also implements the LangChain ``Runnable``
    interface, so it can be composed into LCEL chains with ``|``.
    """

    def invoke(self, input: T, config: Optional[RunnableConfig] = None) -> T:
        """Validate the chain input/output with this guard.

        Args:
            input: Either a LangChain ``BaseMessage`` (chat pipelines) or any
                value coercible to ``str`` (string pipelines).
            config: Optional LangChain runnable config; accepted for interface
                compatibility but not used here.

        Returns:
            The same kind of value that came in: a (deep-copied) message with
            its ``content`` replaced by the validated output when the input was
            a ``BaseMessage``, otherwise the validated string itself.

        Raises:
            ValidationError: If validation produced no validated output.
        """
        # Placeholder so `output` is always bound; replaced by a copy of the
        # incoming message when the input is a chat message.
        output = BaseMessage(content="", type="")
        str_input = None
        input_is_chat_message = False
        if isinstance(input, BaseMessage):
            input_is_chat_message = True
            # BaseMessage.content is not guaranteed to be a str; normalize.
            str_input = str(input.content)
            # Deep-copy so the returned message keeps the original's metadata
            # without mutating the caller's object.
            output = deepcopy(input)
        else:
            str_input = str(input)

        response = self.validate(str_input)

        validated_output = response.validated_output
        if not validated_output:
            # Note: space added between sentences so the message reads
            # correctly when the adjacent literals are concatenated.
            raise ValidationError(
                (
                    "The response from the LLM failed validation! "
                    "See `guard.history` for more details."
                )
            )

        # Structured (dict) results are serialized so downstream chain links
        # always receive a string payload.
        if isinstance(validated_output, Dict):
            validated_output = json.dumps(validated_output)

        if input_is_chat_message:
            output.content = validated_output
            # cast: the TypeVar T cannot be narrowed statically here.
            return cast(T, output)
        return cast(T, validated_output)

0 comments on commit 28952cd

Please sign in to comment.