Skip to content

Commit 6798290

Browse files
committed
feat: add GuardrailsAI community integration (#1298)
1 parent 51459d0 commit 6798290

File tree

12 files changed

+1809
-0
lines changed

12 files changed

+1809
-0
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
models:
2+
- type: main
3+
engine: openai
4+
model: gpt-4
5+
6+
rails:
7+
config:
8+
guardrails_ai:
9+
validators:
10+
- name: toxic_language
11+
parameters:
12+
threshold: 0.5
13+
validation_method: "sentence"
14+
metadata: {}
15+
- name: guardrails_pii
16+
parameters:
17+
entities: ["phone_number", "email", "ssn"]
18+
metadata: {}
19+
- name: competitor_check
20+
parameters:
21+
competitors: ["Apple", "Google", "Microsoft"]
22+
metadata: {}
23+
- name: restricttotopic
24+
parameters:
25+
valid_topics: ["technology", "science", "education"]
26+
metadata: {}
27+
input:
28+
flows:
29+
- guardrailsai check input $validator="guardrails_pii"
30+
- guardrailsai check input $validator="competitor_check"
31+
output:
32+
flows:
33+
- guardrailsai check output $validator="restricttotopic"
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Dynamic validator loading for Guardrails AI integration."""
17+
18+
import importlib
19+
import logging
20+
from functools import lru_cache
21+
from typing import Any, Dict, Optional, Type
22+
23+
try:
24+
from guardrails import Guard
25+
except ImportError:
26+
# Mock Guard class for when guardrails is not available
27+
class Guard:
28+
def __init__(self):
29+
pass
30+
31+
def use(self, validator):
32+
return self
33+
34+
def validate(self, text, metadata=None):
35+
return None
36+
37+
38+
from nemoguardrails.actions import action
39+
from nemoguardrails.library.guardrails_ai.errors import GuardrailsAIValidationError
40+
from nemoguardrails.library.guardrails_ai.registry import get_validator_info
41+
from nemoguardrails.rails.llm.config import RailsConfig
42+
43+
log = logging.getLogger(__name__)
44+
45+
46+
# cache for loaded validator classes and guard instances
47+
_validator_class_cache: Dict[str, Type] = {}
48+
_guard_cache: Dict[tuple, Guard] = {}
49+
50+
51+
def guardrails_ai_validation_mapping(result: Dict[str, Any]) -> bool:
52+
"""Map Guardrails AI validation result to NeMo Guardrails format."""
53+
# The Guardrails AI `validate` method returns a ValidationResult object.
54+
# On failure (PII found, Jailbreak detected, etc.), it's often a FailResult.
55+
# Both PassResult and FailResult have a `validation_passed` boolean attribute
56+
# which indicates if the validation criteria were met.
57+
# FailResult also often contains `fixed_value` if a fix like anonymization was applied.
58+
# We map `validation_passed=False` to `True` (block) and `validation_passed=True` to `False` (don't block).
59+
validation_result = result.get("validation_result", {})
60+
61+
# Handle both dict and object formats
62+
if hasattr(validation_result, "validation_passed"):
63+
valid = validation_result.validation_passed
64+
else:
65+
valid = validation_result.get("validation_passed", False)
66+
67+
return valid # {"valid": valid, "validation_result": validation_result}
68+
69+
70+
# TODO: we need to do this
71+
# from guardrails.hub import RegexMatch, ValidLength
72+
# from guardrails import Guard
73+
#
74+
# guard = Guard().use_many(
75+
# RegexMatch(regex="^[A-Z][a-z]*$"),
76+
# ValidLength(min=1, max=12)
77+
# )
78+
#
79+
# print(guard.parse("Caesar").validation_passed) # Guardrail Passes
80+
# print(
81+
# guard.parse("Caesar Salad")
82+
# .validation_passed
83+
# ) # Guardrail Fails due to regex match
84+
# print(
85+
# guard.parse("Caesarisagreatleader")
86+
# .validation_passed
87+
# ) # Guardrail Fails due to length
88+
89+
90+
@action(
91+
name="validate_guardrails_ai_input",
92+
output_mapping=guardrails_ai_validation_mapping,
93+
is_system_action=False,
94+
)
95+
def validate_guardrails_ai_input(
96+
validator: str,
97+
config: RailsConfig,
98+
context: Optional[dict] = None,
99+
text: Optional[str] = None,
100+
**kwargs,
101+
) -> Dict[str, Any]:
102+
"""Unified action for all Guardrails AI validators.
103+
104+
Args:
105+
validator: Name of the validator to use (from VALIDATOR_REGISTRY)
106+
text: Text to validate
107+
context: Optional context dictionary
108+
109+
Returns:
110+
Dict with validation_result
111+
"""
112+
113+
text = text or context.get("user_message", "")
114+
if not text:
115+
raise ValueError("Either 'text' or 'context' must be provided.")
116+
117+
validator_config = config.rails.config.guardrails_ai.get_validator_config(validator)
118+
parameters = validator_config.parameters or {}
119+
metadata = validator_config.metadata or {}
120+
121+
joined_parameters = {**parameters, **metadata}
122+
123+
validation_result = validate_guardrails_ai(validator, text, **joined_parameters)
124+
125+
# Transform to the expected format for Colang flows
126+
return validation_result
127+
128+
129+
@action(
130+
name="validate_guardrails_ai_output",
131+
output_mapping=guardrails_ai_validation_mapping,
132+
is_system_action=False,
133+
)
134+
def validate_guardrails_ai_output(
135+
validator: str,
136+
context: Optional[dict] = None,
137+
text: Optional[str] = None,
138+
config: Optional[RailsConfig] = None,
139+
**kwargs,
140+
) -> Dict[str, Any]:
141+
"""Unified action for all Guardrails AI validators.
142+
143+
Args:
144+
validator: Name of the validator to use (from VALIDATOR_REGISTRY)
145+
text: Text to validate
146+
context: Optional context dictionary
147+
148+
Returns:
149+
Dict with validation_result
150+
"""
151+
152+
text = text or context.get("bot_message", "")
153+
if not text:
154+
raise ValueError("Either 'text' or 'context' must be provided.")
155+
156+
validator_config = config.rails.config.guardrails_ai.get_validator_config(validator)
157+
parameters = validator_config.parameters or {}
158+
metadata = validator_config.metadata or {}
159+
160+
# join parameters and metadata into a single dict
161+
joined_parameters = {**parameters, **metadata}
162+
163+
validation_result = validate_guardrails_ai(validator, text, **joined_parameters)
164+
165+
return validation_result
166+
167+
168+
def validate_guardrails_ai(validator_name: str, text: str, **kwargs) -> Dict[str, Any]:
169+
"""Unified action for all Guardrails AI validators.
170+
171+
Args:
172+
validator: Name of the validator to use (from VALIDATOR_REGISTRY)
173+
text: Text to validate
174+
175+
176+
Returns:
177+
Dict with validation_result
178+
"""
179+
180+
try:
181+
# extract metadata if provided as a dict
182+
183+
metadata = kwargs.pop("metadata", {})
184+
validator_params = kwargs
185+
186+
validator_params = {k: v for k, v in validator_params.items() if v is not None}
187+
188+
# get or create the guard with all non-metadata params
189+
guard = _get_guard(validator_name, **validator_params)
190+
191+
try:
192+
validation_result = guard.validate(text, metadata=metadata)
193+
return {"validation_result": validation_result}
194+
except GuardrailsAIValidationError as e:
195+
# handle Guardrails validation errors (when on_fail="exception")
196+
# return a failed validation result instead of raising
197+
log.warning(f"Guardrails validation failed for {validator_name}: {str(e)}")
198+
199+
# create a mock validation result for failed validations
200+
class FailedValidation:
201+
validation_passed = False
202+
error = str(e)
203+
204+
return {"validation_result": FailedValidation()}
205+
206+
except Exception as e:
207+
log.error(f"Error validating with {validator_name}: {str(e)}")
208+
raise GuardrailsAIValidationError(f"Validation failed: {str(e)}")
209+
210+
211+
@lru_cache(maxsize=None)
212+
def _load_validator_class(validator_name: str) -> Type:
213+
"""Dynamically load a validator class."""
214+
cache_key = f"class_{validator_name}"
215+
216+
if cache_key in _validator_class_cache:
217+
return _validator_class_cache[cache_key]
218+
219+
try:
220+
validator_info = get_validator_info(validator_name)
221+
222+
module_name = validator_info["module"]
223+
class_name = validator_info["class"]
224+
225+
try:
226+
module = importlib.import_module(module_name)
227+
validator_class = getattr(module, class_name)
228+
_validator_class_cache[cache_key] = validator_class
229+
return validator_class
230+
except (ImportError, AttributeError):
231+
log.warning(
232+
f"Could not import {class_name} from {module_name}. "
233+
f"Make sure to install it first: guardrails hub install {validator_info['hub_path']}"
234+
)
235+
raise ImportError(
236+
f"Validator {validator_name} not installed. "
237+
f"Install with: guardrails hub install {validator_info['hub_path']}"
238+
)
239+
240+
except Exception as e:
241+
raise ImportError(f"Failed to load validator {validator_name}: {str(e)}")
242+
243+
244+
def _get_guard(validator_name: str, **validator_params) -> Guard:
245+
"""Get or create a Guard instance for a validator."""
246+
247+
# create a hashable cache key
248+
def make_hashable(obj):
249+
if isinstance(obj, list):
250+
return tuple(obj)
251+
elif isinstance(obj, dict):
252+
return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
253+
return obj
254+
255+
cache_items = [(k, make_hashable(v)) for k, v in validator_params.items()]
256+
cache_key = (validator_name, tuple(sorted(cache_items)))
257+
258+
if cache_key not in _guard_cache:
259+
validator_class = _load_validator_class(validator_name)
260+
261+
# TODO(@zayd): is this needed?
262+
# default handling for all validators
263+
if "on_fail" not in validator_params:
264+
validator_params["on_fail"] = "noop"
265+
266+
try:
267+
validator_instance = validator_class(**validator_params)
268+
except TypeError as e:
269+
log.error(
270+
f"Failed to instantiate {validator_name} with params {validator_params}: {str(e)}"
271+
)
272+
raise
273+
274+
guard = Guard().use(validator_instance)
275+
_guard_cache[cache_key] = guard
276+
277+
return _guard_cache[cache_key]
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
try:
17+
from guardrails.errors import ValidationError
18+
19+
GuardrailsAIValidationError = ValidationError
20+
except ImportError:
21+
# create a fallback error class when guardrails is not installed
22+
class GuardrailsAIValidationError(Exception):
23+
"""Fallback validation error when guardrails package is not available."""
24+
25+
pass
26+
27+
28+
class GuardrailsAIError(Exception):
29+
"""Base exception for Guardrails AI integration."""
30+
31+
pass
32+
33+
34+
class GuardrailsAIConfigError(GuardrailsAIError):
35+
"""Raised when configuration is invalid."""
36+
37+
pass
38+
39+
40+
__all__ = [
41+
"GuardrailsAIError",
42+
"GuardrailsAIValidationError",
43+
"GuardrailsAIConfigError",
44+
]
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
flow guardrailsai check input $validator
2+
"""Check input text using relevant Guardrails AI validators."""
3+
$result = await ValidateGuardrailsAiInputAction(validator=$validator, text=$user_message)
4+
if not $result["valid"]
5+
if $system.config.enable_rails_exceptions
6+
send GuardrailsAIException(message="Guardrails AI {$validator} validation failed")
7+
else
8+
bot refuse to respond
9+
abort
10+
11+
12+
flow guardrailsai check output $validator
13+
"""Check output text using relevant Guardrails AI validators."""
14+
$result = await ValidateGuardrailsAiOutputAction(validator=$validator, text=$bot_message)
15+
if not $result["valid"]
16+
if $system.config.enable_rails_exceptions
17+
send GuardrailsAIException(message="Guardrails AI {$validator} validation failed")
18+
else
19+
bot refuse to respond
20+
abort

0 commit comments

Comments
 (0)