33 changes: 19 additions & 14 deletions nemoguardrails/actions/actions.py
@@ -14,46 +14,51 @@
# limitations under the License.

from dataclasses import dataclass, field
from typing import Any, List, Optional
from typing import Any, Callable, List, Optional, TypedDict, Union


class ActionMeta(TypedDict, total=False):
name: str
is_system_action: bool
execute_async: bool
output_mapping: Optional[Callable[[Any], bool]]


def action(
is_system_action: bool = False,
name: Optional[str] = None,
execute_async: bool = False,
):
output_mapping: Optional[Callable[[Any], bool]] = None,
) -> Callable[[Union[Callable, type]], Union[Callable, type]]:
"""Decorator to mark a function or class as an action.

Args:
is_system_action (bool): Flag indicating if the action is a system action.
name (Optional[str]): The name to associate with the action.
execute_async: Whether the function should be executed in async mode.

output_mapping (Optional[Callable[[Any], bool]]): A function to interpret the action's result.
It should accept the return value (e.g. the first element of a tuple) and return True if the output
should be considered blocked.
Returns:
callable: The decorated function or class.
"""

def decorator(fn_or_cls):
def decorator(fn_or_cls: Union[Callable, type]) -> Union[Callable, type]:
"""Inner decorator function to add metadata to the action.

Args:
fn_or_cls: The function or class being decorated.
"""
fn_or_cls_target = getattr(fn_or_cls, "__func__", fn_or_cls)

# Detect the decorator being applied to staticmethod or classmethod.
# Will annotate the inner function in that case, as otherwise the
# metaclass will give us the unannotated enclosed function on
# attribute lookup.
if hasattr(fn_or_cls, "__func__"):
fn_or_cls_target = fn_or_cls.__func__
else:
fn_or_cls_target = fn_or_cls

fn_or_cls_target.action_meta = {
action_meta: ActionMeta = {
"name": name or fn_or_cls.__name__,
"is_system_action": is_system_action,
"execute_async": execute_async,
"output_mapping": output_mapping,
}

setattr(fn_or_cls_target, "action_meta", action_meta)
return fn_or_cls

return decorator
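For context, a minimal sketch (not part of this diff) of how the new output_mapping parameter could be used when registering a custom action; the action name, the risk_output_mapping helper, and the scoring logic are hypothetical:

from nemoguardrails.actions import action

def risk_output_mapping(result: float) -> bool:
    # Hypothetical mapping: treat the action result as a risk score and block above 0.5.
    return result > 0.5

@action(name="check risk", is_system_action=True, output_mapping=risk_output_mapping)
async def check_risk(text: str) -> float:
    # Placeholder scoring logic, for illustration only.
    return 0.9 if "attack" in text.lower() else 0.1

# The decorator attaches the metadata to the function:
# check_risk.action_meta == {
#     "name": "check risk",
#     "is_system_action": True,
#     "execute_async": False,
#     "output_mapping": risk_output_mapping,
# }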
48 changes: 46 additions & 2 deletions nemoguardrails/library/activefence/actions.py
@@ -25,8 +25,52 @@
log = logging.getLogger(__name__)


@action(is_system_action=True)
async def call_activefence_api(text: Optional[str] = None):
def call_activefence_api_mapping(result: dict) -> bool:
"""
Mapping for call_activefence_api.

Expects result to be a dict with:
- "max_risk_score": a float representing the overall risk.
- "violations": a dict mapping violation type (str) to risk score (float).

Block (return True) if any of the following conditions are met:
1. Overall max risk score > 0.7.
2. Individual thresholds are exceeded:
- "abusive_or_harmful.harassment_or_bullying" > 0.8
- "abusive_or_harmful.profanity" > 0.7
- "abusive_or_harmful.hate_speech" > 0.6
- "abusive_or_harmful.child_grooming" > 0.4
- "abusive_or_harmful.general_violence" > 0.7
- "self_harm.general" > 0.8
- "adult_content.general" > 0.3
- "privacy_violation.pii" > 0.8
"""
max_risk = result.get("max_risk_score", 0.0)
if max_risk > 0.7:
return True

violations = result.get("violations", {})

thresholds = {
"abusive_or_harmful.harassment_or_bullying": 0.8,
"abusive_or_harmful.profanity": 0.7,
"abusive_or_harmful.hate_speech": 0.6,
"abusive_or_harmful.child_grooming": 0.4,
"abusive_or_harmful.general_violence": 0.7,
"self_harm.general": 0.8,
"adult_content.general": 0.3,
"privacy_violation.pii": 0.8,
}

for violation_type, threshold in thresholds.items():
if violations.get(violation_type, 0) > threshold:
return True

return False


@action(is_system_action=True, output_mapping=call_activefence_api_mapping)
async def call_activefence_api(text: Optional[str] = None, **kwargs):
api_key = os.environ.get("ACTIVEFENCE_API_KEY")

if api_key is None:
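A quick illustration (not part of the diff) of how call_activefence_api_mapping behaves on hand-written result dicts:

sample = {
    "max_risk_score": 0.5,
    "violations": {"abusive_or_harmful.hate_speech": 0.65},
}
assert call_activefence_api_mapping(sample) is True  # hate_speech exceeds its 0.6 threshold
assert call_activefence_api_mapping(
    {"max_risk_score": 0.2, "violations": {"abusive_or_harmful.profanity": 0.5}}
) is False  # below every threshold, so the output is allowed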
47 changes: 44 additions & 3 deletions nemoguardrails/library/autoalign/actions.py
@@ -79,6 +79,38 @@
default_groundedness_config = {"groundedness_checker": {"verify_response": False}}


def autoalign_output_api_mapping(result: dict) -> bool:
"""
Mapping for autoalign_output_api.

Expects result to be a dict with a key "guardrails_triggered" (a boolean).
Returns True (block) if guardrails were triggered.
"""
return result.get("guardrails_triggered", False)


def autoalign_groundedness_output_api_mapping(result: float) -> bool:
"""
Mapping for autoalign_groundedness_output_api.

Expects result to be a numeric score.
Returns True (block) if the score is below the default groundedness threshold.
"""
DEFAULT_GROUNDEDNESS_THRESHOLD = 0.5
return result < DEFAULT_GROUNDEDNESS_THRESHOLD


def autoalign_factcheck_output_api_mapping(result: float) -> bool:
"""
Mapping for autoalign_factcheck_output_api.

Expects result to be a numeric score.
Returns True (block) if the score is below the default factcheck threshold.
"""
DEFAULT_FACTCHECK_THRESHOLD = 0.5
return result < DEFAULT_FACTCHECK_THRESHOLD


def process_autoalign_output(responses: List[Any], show_toxic_phrases: bool = False):
"""Processes the output provided AutoAlign API"""

@@ -252,6 +284,7 @@ async def autoalign_input_api(
context: Optional[dict] = None,
show_autoalign_message: bool = True,
show_toxic_phrases: bool = False,
**kwargs,
):
"""Calls AutoAlign API for the user message and guardrail configuration provided"""
user_message = context.get("user_message")
@@ -285,12 +318,13 @@ async def autoalign_output_api(
return autoalign_response


@action(name="autoalign_output_api")
@action(name="autoalign_output_api", output_mapping=autoalign_output_api_mapping)
async def autoalign_output_api(
llm_task_manager: LLMTaskManager,
context: Optional[dict] = None,
show_autoalign_message: bool = True,
show_toxic_phrases: bool = False,
**kwargs,
):
"""Calls AutoAlign API for the bot message and guardrail configuration provided"""
bot_message = context.get("bot_message")
@@ -319,12 +353,16 @@ async def autoalign_output_api(
return autoalign_response


@action(name="autoalign_groundedness_output_api")
@action(
name="autoalign_groundedness_output_api",
output_mapping=autoalign_groundedness_output_api_mapping,
)
async def autoalign_groundedness_output_api(
llm_task_manager: LLMTaskManager,
context: Optional[dict] = None,
factcheck_threshold: float = 0.0,
show_autoalign_message: bool = True,
**kwargs,
):
"""Calls AutoAlign groundedness check API and checks whether the bot message is factually grounded according to given
documents"""
@@ -355,7 +393,10 @@ async def autoalign_groundedness_output_api(
return score


@action(name="autoalign_factcheck_output_api")
@action(
name="autoalign_factcheck_output_api",
output_mapping=autoalign_factcheck_output_api_mapping,
)
async def autoalign_factcheck_output_api(
llm_task_manager: LLMTaskManager,
context: Optional[dict] = None,
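For reference, the three AutoAlign mappings introduced above can be exercised directly; the values below are made up and the 0.5 thresholds are the defaults defined inside the mapping functions:

assert autoalign_output_api_mapping({"guardrails_triggered": True}) is True
assert autoalign_output_api_mapping({}) is False  # missing key defaults to allow
assert autoalign_groundedness_output_api_mapping(0.3) is True  # below the 0.5 threshold
assert autoalign_factcheck_output_api_mapping(0.9) is False  # scores of 0.5 or higher are allowed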
14 changes: 13 additions & 1 deletion nemoguardrails/library/cleanlab/actions.py
@@ -22,9 +22,21 @@
log = logging.getLogger(__name__)


@action(name="call cleanlab api", is_system_action=True)
def cleanlab_output_mapping(result: dict) -> bool:
"""Block if the trustworthiness score is below 0.6."""

score = result.get("trustworthiness_score", 1)
return score < 0.6


@action(
name="call cleanlab api",
is_system_action=True,
output_mapping=cleanlab_output_mapping,
)
async def call_cleanlab_api(
context: Optional[dict] = None,
**kwargs,
) -> Union[ValueError, ImportError, Dict]:
api_key = os.environ.get("CLEANLAB_API_KEY")

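An illustrative check of cleanlab_output_mapping with made-up scores (not part of the diff):

assert cleanlab_output_mapping({"trustworthiness_score": 0.4}) is True  # below 0.6, block
assert cleanlab_output_mapping({"trustworthiness_score": 0.95}) is False
assert cleanlab_output_mapping({}) is False  # a missing score defaults to 1, i.e. allow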
20 changes: 19 additions & 1 deletion nemoguardrails/library/content_safety/actions.py
@@ -34,6 +34,7 @@ async def content_safety_check_input(
llm_task_manager: LLMTaskManager,
model_name: Optional[str] = None,
context: Optional[dict] = None,
**kwargs,
) -> dict:
_MAX_TOKENS = 3
user_input: str = ""
@@ -90,12 +91,29 @@ async def content_safety_check_input(
return {"allowed": is_safe, "policy_violations": violated_policies}


@action()
def content_safety_check_output_mapping(result: dict) -> bool:
"""
Mapping function for content_safety_check_output.

Assumes result is a dictionary with:
- "allowed": a boolean where True means the content is safe.
- "policy_violations": a list of policies that were violated (optional in the mapping logic).

Returns:
True if the content should be blocked (i.e. allowed is False),
False if the content is safe.
"""
allowed = result.get("allowed", True)
return not allowed


@action(output_mapping=content_safety_check_output_mapping)
async def content_safety_check_output(
llms: Dict[str, BaseLLM],
llm_task_manager: LLMTaskManager,
model_name: Optional[str] = None,
context: Optional[dict] = None,
**kwargs,
) -> dict:
_MAX_TOKENS = 3
user_input: str = ""
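Illustration (not part of the diff) of content_safety_check_output_mapping, which simply inverts the "allowed" flag:

assert content_safety_check_output_mapping({"allowed": False, "policy_violations": ["S1"]}) is True
assert content_safety_check_output_mapping({"allowed": True, "policy_violations": []}) is False
assert content_safety_check_output_mapping({}) is False  # a missing key defaults to allowed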
14 changes: 13 additions & 1 deletion nemoguardrails/library/factchecking/align_score/actions.py
@@ -27,12 +27,24 @@
log = logging.getLogger(__name__)


@action()
def alignscore_check_facts_mapping(result: float) -> bool:
"""
Mapping for alignscore_check_facts.

Expects result to be a numeric score (float) representing the factual accuracy.
Returns True (i.e. block the output) if the score is below 0.5.
"""
THRESHOLD = 0.5
return result < THRESHOLD


@action(output_mapping=alignscore_check_facts_mapping)
async def alignscore_check_facts(
llm_task_manager: LLMTaskManager,
context: Optional[dict] = None,
llm: Optional[BaseLLM] = None,
config: Optional[RailsConfig] = None,
**kwargs,
):
"""Checks the facts for the bot response using an information alignment score."""
fact_checking_config = llm_task_manager.config.rails.config.fact_checking
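Illustration of alignscore_check_facts_mapping with made-up scores (not part of the diff):

assert alignscore_check_facts_mapping(0.2) is True  # below the 0.5 threshold, block
assert alignscore_check_facts_mapping(0.8) is False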
80 changes: 78 additions & 2 deletions nemoguardrails/library/gcp_moderate_text/actions.py
@@ -28,8 +28,84 @@
log = logging.getLogger(__name__)


@action(name="call gcpnlp api", is_system_action=True)
async def call_gcp_text_moderation_api(context: Optional[dict] = None):
def gcp_text_moderation_mapping(result: dict) -> bool:
"""Determines whether the output from the GCP text moderation API should be blocked.

The mapping checks two things:
1. If the overall 'max_risk_score' exceeds 0.8.
2. If any individual violation score (from the 'violations' dictionary) exceeds its defined threshold.
- Toxic: > 0.8
- Insult: > 0.7
- Profanity: > 0.6
- Derogatory: > 0.4
- Violent: > 0.8
- Sexual: > 0.7
- Death, Harm & Tragedy: > 0.8
- Firearms & Weapons: > 0.8
- Illicit Drugs: > 0.8
- Public Safety: > 0.8
- Health: > 0.8
- Religion & Belief: > 0.8
- War & Conflict: > 0.8
- Politics: > 0.8
- Finance: > 0.8
- Legal: > 0.8

Returns True (i.e. block the output) if any of the conditions are met.
"""
# Check overall risk score if available.
max_risk_score = result.get("max_risk_score", 0)
if max_risk_score > 0.8:
return True

# Check detailed violations, if provided.
violations = result.get("violations", {})

if violations.get("Toxic", 0) > 0.8:
return True
if violations.get("Insult", 0) > 0.7:
return True
if violations.get("Profanity", 0) > 0.6:
return True
if violations.get("Derogatory", 0) > 0.4:
return True
if violations.get("Violent", 0) > 0.8:
return True
if violations.get("Sexual", 0) > 0.7:
return True
if violations.get("Death, Harm & Tragedy", 0) > 0.8:
return True
if violations.get("Firearms & Weapons", 0) > 0.8:
return True
if violations.get("Illicit Drugs", 0) > 0.8:
return True
if violations.get("Public Safety", 0) > 0.8:
return True
if violations.get("Health", 0) > 0.8:
return True
if violations.get("Religion & Belief", 0) > 0.8:
return True
if violations.get("War & Conflict", 0) > 0.8:
return True
if violations.get("Politics", 0) > 0.8:
return True
if violations.get("Finance", 0) > 0.8:
return True
if violations.get("Legal", 0) > 0.8:
return True

# If none of the thresholds are exceeded, allow the output.
return False


@action(
name="call gcpnlp api",
is_system_action=True,
output_mapping=gcp_text_moderation_mapping,
)
async def call_gcp_text_moderation_api(
context: Optional[dict] = None, **kwargs
) -> dict:
"""
Application Default Credentials (ADC) is a strategy used by the GCP authentication libraries to automatically
find credentials based on the application environment. ADC searches for credentials in the following locations (Search order):
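Finally, an illustration (not part of the diff) of gcp_text_moderation_mapping with hand-written scores:

assert gcp_text_moderation_mapping({"violations": {"Derogatory": 0.55}}) is True  # above the 0.4 threshold
assert gcp_text_moderation_mapping({"max_risk_score": 0.85}) is True  # overall risk above 0.8
assert gcp_text_moderation_mapping({"max_risk_score": 0.3, "violations": {"Politics": 0.5}}) is False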