
Commit c58f262

Feat: Add unified output mapping for actions (#965)
* feat: add output_mapping to action decorators
* feat: add kwargs to actions for flexibility
* feat: add ActionMeta TypedDict and improve action decorator
* test: refactor and add new tests for action decorator
1 parent 08569c8 commit c58f262


20 files changed: +448 −76 lines changed


nemoguardrails/actions/actions.py

Lines changed: 19 additions & 14 deletions
```diff
@@ -14,46 +14,51 @@
 # limitations under the License.
 
 from dataclasses import dataclass, field
-from typing import Any, List, Optional
+from typing import Any, Callable, List, Optional, TypedDict, Union
+
+
+class ActionMeta(TypedDict, total=False):
+    name: str
+    is_system_action: bool
+    execute_async: bool
+    output_mapping: Optional[Callable[[Any], bool]]
 
 
 def action(
     is_system_action: bool = False,
     name: Optional[str] = None,
     execute_async: bool = False,
-):
+    output_mapping: Optional[Callable[[Any], bool]] = None,
+) -> Callable[[Union[Callable, type]], Union[Callable, type]]:
     """Decorator to mark a function or class as an action.
 
     Args:
         is_system_action (bool): Flag indicating if the action is a system action.
         name (Optional[str]): The name to associate with the action.
         execute_async: Whether the function should be executed in async mode.
-
+        output_mapping (Optional[Callable[[Any], bool]]): A function to interpret the action's result.
+            It should accept the return value (e.g. the first element of a tuple) and return True if
+            the output should be considered blocked.
     Returns:
         callable: The decorated function or class.
     """
 
-    def decorator(fn_or_cls):
+    def decorator(fn_or_cls: Union[Callable, type]) -> Union[Callable, type]:
         """Inner decorator function to add metadata to the action.
 
         Args:
             fn_or_cls: The function or class being decorated.
         """
+        fn_or_cls_target = getattr(fn_or_cls, "__func__", fn_or_cls)
 
-        # Detect the decorator being applied to staticmethod or classmethod.
-        # Will annotate the inner function in that case as otherwise
-        # metaclass will be giving us the unannotated enclosed function on
-        # attribute lookup.
-        if hasattr(fn_or_cls, "__func__"):
-            fn_or_cls_target = fn_or_cls.__func__
-        else:
-            fn_or_cls_target = fn_or_cls
-
-        fn_or_cls_target.action_meta = {
+        action_meta: ActionMeta = {
             "name": name or fn_or_cls.__name__,
             "is_system_action": is_system_action,
             "execute_async": execute_async,
+            "output_mapping": output_mapping,
         }
+
+        setattr(fn_or_cls_target, "action_meta", action_meta)
         return fn_or_cls
 
     return decorator
```
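
The decorator stores everything, including the new `output_mapping`, on the function's `action_meta` attribute. A minimal sketch of how a custom action might opt in (the action name and mapping below are hypothetical, not part of this commit):

```python
from nemoguardrails.actions.actions import action


def check_blocked(result: dict) -> bool:
    # Hypothetical mapping: treat any result whose score is below 0.5 as blocked.
    return result.get("score", 1.0) < 0.5


@action(name="my_safety_check", output_mapping=check_blocked)
async def my_safety_check(text: str, **kwargs):
    # Toy implementation; a real action would call a moderation backend.
    return {"score": 0.1 if "attack" in text else 0.9}


# The decorator attaches its configuration to the function itself.
meta = my_safety_check.action_meta
assert meta["name"] == "my_safety_check"
assert meta["output_mapping"]({"score": 0.2}) is True   # low score -> block
assert meta["output_mapping"]({"score": 0.9}) is False  # high score -> allow
```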

nemoguardrails/library/activefence/actions.py

Lines changed: 46 additions & 2 deletions
```diff
@@ -25,8 +25,52 @@
 log = logging.getLogger(__name__)
 
 
-@action(is_system_action=True)
-async def call_activefence_api(text: Optional[str] = None):
+def call_activefence_api_mapping(result: dict) -> bool:
+    """
+    Mapping for call_activefence_api.
+
+    Expects result to be a dict with:
+      - "max_risk_score": a float representing the overall risk.
+      - "violations": a dict mapping violation type (str) to risk score (float).
+
+    Block (return True) if any of the following conditions are met:
+      1. Overall max risk score > 0.7.
+      2. Individual thresholds are exceeded:
+         - "abusive_or_harmful.harassment_or_bullying" > 0.8
+         - "abusive_or_harmful.profanity" > 0.7
+         - "abusive_or_harmful.hate_speech" > 0.6
+         - "abusive_or_harmful.child_grooming" > 0.4
+         - "abusive_or_harmful.general_violence" > 0.7
+         - "self_harm.general" > 0.8
+         - "adult_content.general" > 0.3
+         - "privacy_violation.pii" > 0.8
+    """
+    max_risk = result.get("max_risk_score", 0.0)
+    if max_risk > 0.7:
+        return True
+
+    violations = result.get("violations", {})
+
+    thresholds = {
+        "abusive_or_harmful.harassment_or_bullying": 0.8,
+        "abusive_or_harmful.profanity": 0.7,
+        "abusive_or_harmful.hate_speech": 0.6,
+        "abusive_or_harmful.child_grooming": 0.4,
+        "abusive_or_harmful.general_violence": 0.7,
+        "self_harm.general": 0.8,
+        "adult_content.general": 0.3,
+        "privacy_violation.pii": 0.8,
+    }
+
+    for violation_type, threshold in thresholds.items():
+        if violations.get(violation_type, 0) > threshold:
+            return True
+
+    return False
+
+
+@action(is_system_action=True, output_mapping=call_activefence_api_mapping)
+async def call_activefence_api(text: Optional[str] = None, **kwargs):
     api_key = os.environ.get("ACTIVEFENCE_API_KEY")
 
     if api_key is None:
```
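
To make the thresholds above concrete, here is a quick, hypothetical sanity check of the mapping against sample result dicts (shapes per the docstring):

```python
from nemoguardrails.library.activefence.actions import call_activefence_api_mapping

# Blocked: overall risk exceeds the 0.7 ceiling.
assert call_activefence_api_mapping({"max_risk_score": 0.9, "violations": {}}) is True

# Blocked: one category crosses its own threshold (hate_speech > 0.6).
assert call_activefence_api_mapping(
    {"max_risk_score": 0.2, "violations": {"abusive_or_harmful.hate_speech": 0.65}}
) is True

# Allowed: every score stays under its threshold.
assert call_activefence_api_mapping(
    {"max_risk_score": 0.2, "violations": {"abusive_or_harmful.profanity": 0.5}}
) is False
```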

nemoguardrails/library/autoalign/actions.py

Lines changed: 44 additions & 3 deletions
```diff
@@ -79,6 +79,38 @@
 default_groundedness_config = {"groundedness_checker": {"verify_response": False}}
 
 
+def autoalign_output_api_mapping(result: dict) -> bool:
+    """
+    Mapping for autoalign_output_api.
+
+    Expects result to be a dict with a key "guardrails_triggered" (a boolean).
+    Returns True (block) if guardrails were triggered.
+    """
+    return result.get("guardrails_triggered", False)
+
+
+def autoalign_groundedness_output_api_mapping(result: float) -> bool:
+    """
+    Mapping for autoalign_groundedness_output_api.
+
+    Expects result to be a numeric score.
+    Returns True (block) if the score is below the default groundedness threshold.
+    """
+    DEFAULT_GROUNDEDNESS_THRESHOLD = 0.5
+    return result < DEFAULT_GROUNDEDNESS_THRESHOLD
+
+
+def autoalign_factcheck_output_api_mapping(result: float) -> bool:
+    """
+    Mapping for autoalign_factcheck_output_api.
+
+    Expects result to be a numeric score.
+    Returns True (block) if the score is below the default factcheck threshold.
+    """
+    DEFAULT_FACTCHECK_THRESHOLD = 0.5
+    return result < DEFAULT_FACTCHECK_THRESHOLD
+
+
 def process_autoalign_output(responses: List[Any], show_toxic_phrases: bool = False):
     """Processes the output provided AutoAlign API"""
 
@@ -252,6 +284,7 @@ async def autoalign_input_api(
     context: Optional[dict] = None,
     show_autoalign_message: bool = True,
     show_toxic_phrases: bool = False,
+    **kwargs,
 ):
     """Calls AutoAlign API for the user message and guardrail configuration provided"""
     user_message = context.get("user_message")
@@ -285,12 +318,13 @@
     return autoalign_response
 
 
-@action(name="autoalign_output_api")
+@action(name="autoalign_output_api", output_mapping=autoalign_output_api_mapping)
 async def autoalign_output_api(
     llm_task_manager: LLMTaskManager,
     context: Optional[dict] = None,
     show_autoalign_message: bool = True,
     show_toxic_phrases: bool = False,
+    **kwargs,
 ):
     """Calls AutoAlign API for the bot message and guardrail configuration provided"""
     bot_message = context.get("bot_message")
@@ -319,12 +353,16 @@
     return autoalign_response
 
 
-@action(name="autoalign_groundedness_output_api")
+@action(
+    name="autoalign_groundedness_output_api",
+    output_mapping=autoalign_groundedness_output_api_mapping,
+)
 async def autoalign_groundedness_output_api(
     llm_task_manager: LLMTaskManager,
     context: Optional[dict] = None,
     factcheck_threshold: float = 0.0,
     show_autoalign_message: bool = True,
+    **kwargs,
 ):
     """Calls AutoAlign groundedness check API and checks whether the bot message is factually grounded according to given
     documents"""
@@ -355,7 +393,10 @@
     return score
 
 
-@action(name="autoalign_factcheck_output_api")
+@action(
+    name="autoalign_factcheck_output_api",
+    output_mapping=autoalign_factcheck_output_api_mapping,
+)
 async def autoalign_factcheck_output_api(
     llm_task_manager: LLMTaskManager,
     context: Optional[dict] = None,
```
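
Since every mapping ends up in `action_meta`, a caller can also retrieve it from the decorated action itself. A small sketch, assuming the module imports resolve:

```python
from nemoguardrails.library.autoalign.actions import (
    autoalign_groundedness_output_api_mapping,
    autoalign_output_api,
)

# The mapping registered via the decorator is reachable through action_meta.
mapping = autoalign_output_api.action_meta["output_mapping"]
assert mapping({"guardrails_triggered": True}) is True
assert mapping({"guardrails_triggered": False}) is False

# Groundedness scores below the 0.5 default are blocked.
assert autoalign_groundedness_output_api_mapping(0.3) is True
assert autoalign_groundedness_output_api_mapping(0.8) is False
```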

nemoguardrails/library/cleanlab/actions.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -22,9 +22,21 @@
 log = logging.getLogger(__name__)
 
 
-@action(name="call cleanlab api", is_system_action=True)
+def cleanlab_output_mapping(result: dict) -> bool:
+    """Block if the trustworthiness score is below 0.6."""
+
+    score = result.get("trustworthiness_score", 1)
+    return score < 0.6
+
+
+@action(
+    name="call cleanlab api",
+    is_system_action=True,
+    output_mapping=cleanlab_output_mapping,
+)
 async def call_cleanlab_api(
     context: Optional[dict] = None,
+    **kwargs,
 ) -> Union[ValueError, ImportError, Dict]:
     api_key = os.environ.get("CLEANLAB_API_KEY")
```

nemoguardrails/library/content_safety/actions.py

Lines changed: 19 additions & 1 deletion
```diff
@@ -34,6 +34,7 @@ async def content_safety_check_input(
     llm_task_manager: LLMTaskManager,
     model_name: Optional[str] = None,
     context: Optional[dict] = None,
+    **kwargs,
 ) -> dict:
     _MAX_TOKENS = 3
     user_input: str = ""
@@ -90,12 +91,29 @@
     return {"allowed": is_safe, "policy_violations": violated_policies}
 
 
-@action()
+def content_safety_check_output_mapping(result: dict) -> bool:
+    """
+    Mapping function for content_safety_check_output.
+
+    Assumes result is a dictionary with:
+      - "allowed": a boolean where True means the content is safe.
+      - "policy_violations": a list of policies that were violated (optional in the mapping logic).
+
+    Returns:
+        True if the content should be blocked (i.e. allowed is False),
+        False if the content is safe.
+    """
+    allowed = result.get("allowed", True)
+    return not allowed
+
+
+@action(output_mapping=content_safety_check_output_mapping)
 async def content_safety_check_output(
     llms: Dict[str, BaseLLM],
     llm_task_manager: LLMTaskManager,
     model_name: Optional[str] = None,
     context: Optional[dict] = None,
+    **kwargs,
 ) -> dict:
     _MAX_TOKENS = 3
     user_input: str = ""
```
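
The mapping simply inverts the action's `allowed` flag; a couple of hypothetical checks (the violation name is illustrative):

```python
from nemoguardrails.library.content_safety.actions import (
    content_safety_check_output_mapping,
)

assert content_safety_check_output_mapping(
    {"allowed": False, "policy_violations": ["violence"]}
) is True   # unsafe content -> block
assert content_safety_check_output_mapping(
    {"allowed": True, "policy_violations": []}
) is False  # safe content -> allow
# With "allowed" absent, the mapping errs on the side of allowing.
assert content_safety_check_output_mapping({}) is False
```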

nemoguardrails/library/factchecking/align_score/actions.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -27,12 +27,24 @@
 log = logging.getLogger(__name__)
 
 
-@action()
+def alignscore_check_facts_mapping(result: float) -> bool:
+    """
+    Mapping for alignscore_check_facts.
+
+    Expects result to be a numeric score (float) representing the factual accuracy.
+    Returns True (i.e. block the output) if the score is below 0.5.
+    """
+    THRESHOLD = 0.5
+    return result < THRESHOLD
+
+
+@action(output_mapping=alignscore_check_facts_mapping)
 async def alignscore_check_facts(
     llm_task_manager: LLMTaskManager,
     context: Optional[dict] = None,
     llm: Optional[BaseLLM] = None,
     config: Optional[RailsConfig] = None,
+    **kwargs,
 ):
     """Checks the facts for the bot response using an information alignment score."""
     fact_checking_config = llm_task_manager.config.rails.config.fact_checking
```
fact_checking_config = llm_task_manager.config.rails.config.fact_checking

nemoguardrails/library/gcp_moderate_text/actions.py

Lines changed: 78 additions & 2 deletions
```diff
@@ -28,8 +28,84 @@
 log = logging.getLogger(__name__)
 
 
-@action(name="call gcpnlp api", is_system_action=True)
-async def call_gcp_text_moderation_api(context: Optional[dict] = None):
+def gcp_text_moderation_mapping(result: dict) -> bool:
+    """Determines whether the output from the GCP text moderation API should be blocked.
+
+    The mapping checks two things:
+    1. If the overall 'max_risk_score' exceeds 0.8.
+    2. If any individual violation score (from the 'violations' dictionary) exceeds its defined threshold.
+       - Toxic: > 0.8
+       - Insult: > 0.7
+       - Profanity: > 0.6
+       - Derogatory: > 0.4
+       - Violent: > 0.8
+       - Sexual: > 0.7
+       - Death, Harm & Tragedy: > 0.8
+       - Firearms & Weapons: > 0.8
+       - Illicit Drugs: > 0.8
+       - Public Safety: > 0.8
+       - Health: > 0.8
+       - Religion & Belief: > 0.8
+       - War & Conflict: > 0.8
+       - Politics: > 0.8
+       - Finance: > 0.8
+       - Legal: > 0.8
+
+    Returns True (i.e. block the output) if any of the conditions are met.
+    """
+    # Check overall risk score if available.
+    max_risk_score = result.get("max_risk_score", 0)
+    if max_risk_score > 0.8:
+        return True
+
+    # Check detailed violations, if provided.
+    violations = result.get("violations", {})
+
+    if violations.get("Toxic", 0) > 0.8:
+        return True
+    if violations.get("Insult", 0) > 0.7:
+        return True
+    if violations.get("Profanity", 0) > 0.6:
+        return True
+    if violations.get("Derogatory", 0) > 0.4:
+        return True
+    if violations.get("Violent", 0) > 0.8:
+        return True
+    if violations.get("Sexual", 0) > 0.7:
+        return True
+    if violations.get("Death, Harm & Tragedy", 0) > 0.8:
+        return True
+    if violations.get("Firearms & Weapons", 0) > 0.8:
+        return True
+    if violations.get("Illicit Drugs", 0) > 0.8:
+        return True
+    if violations.get("Public Safety", 0) > 0.8:
+        return True
+    if violations.get("Health", 0) > 0.8:
+        return True
+    if violations.get("Religion & Belief", 0) > 0.8:
+        return True
+    if violations.get("War & Conflict", 0) > 0.8:
+        return True
+    if violations.get("Politics", 0) > 0.8:
+        return True
+    if violations.get("Finance", 0) > 0.8:
+        return True
+    if violations.get("Legal", 0) > 0.8:
+        return True
+
+    # If none of the thresholds are exceeded, allow the output.
+    return False
+
+
+@action(
+    name="call gcpnlp api",
+    is_system_action=True,
+    output_mapping=gcp_text_moderation_mapping,
+)
+async def call_gcp_text_moderation_api(
+    context: Optional[dict] = None, **kwargs
+) -> dict:
     """
     Application Default Credentials (ADC) is a strategy used by the GCP authentication libraries to automatically
     find credentials based on the application environment. ADC searches for credentials in the following locations (Search order):
```
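
The `if` chain above spells out each category, but the same rules fit naturally in a lookup table, mirroring the dict-of-thresholds style used in the ActiveFence mapping earlier in this commit. A hypothetical, behavior-preserving alternative (not part of the commit):

```python
# Hypothetical refactor: same thresholds and behavior as gcp_text_moderation_mapping.
GCP_CATEGORY_THRESHOLDS = {
    "Toxic": 0.8, "Insult": 0.7, "Profanity": 0.6, "Derogatory": 0.4,
    "Violent": 0.8, "Sexual": 0.7, "Death, Harm & Tragedy": 0.8,
    "Firearms & Weapons": 0.8, "Illicit Drugs": 0.8, "Public Safety": 0.8,
    "Health": 0.8, "Religion & Belief": 0.8, "War & Conflict": 0.8,
    "Politics": 0.8, "Finance": 0.8, "Legal": 0.8,
}


def gcp_text_moderation_mapping_table_driven(result: dict) -> bool:
    # Block on overall risk first, then on any per-category threshold breach.
    if result.get("max_risk_score", 0) > 0.8:
        return True
    violations = result.get("violations", {})
    return any(
        violations.get(category, 0) > threshold
        for category, threshold in GCP_CATEGORY_THRESHOLDS.items()
    )
```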
