From b9cdc025cd8ef6963f51c5c0e539a9919670ea31 Mon Sep 17 00:00:00 2001 From: prezakhani <13303554+Pouyanpi@users.noreply.github.com> Date: Fri, 31 Jan 2025 22:01:55 +0100 Subject: [PATCH 1/5] feat: add output_mapping to action decorators add mapping for prompt_secuirty output mapping refactor: change output_mapping name --- nemoguardrails/actions/actions.py | 8 +- nemoguardrails/library/activefence/actions.py | 46 ++++++++++- nemoguardrails/library/autoalign/actions.py | 44 ++++++++++- nemoguardrails/library/cleanlab/actions.py | 13 +++- .../library/content_safety/actions.py | 18 ++++- .../factchecking/align_score/actions.py | 13 +++- .../library/gcp_moderate_text/actions.py | 76 ++++++++++++++++++- .../library/hallucination/actions.py | 2 +- nemoguardrails/library/llama_guard/actions.py | 18 ++++- nemoguardrails/library/patronusai/actions.py | 32 +++++++- nemoguardrails/library/privateai/actions.py | 12 ++- .../library/prompt_security/actions.py | 17 ++++- .../library/self_check/facts/actions.py | 13 +++- .../self_check/output_check/actions.py | 2 +- .../sensitive_data_detection/actions.py | 12 ++- 15 files changed, 307 insertions(+), 19 deletions(-) diff --git a/nemoguardrails/actions/actions.py b/nemoguardrails/actions/actions.py index 690e841ef..a8729de8a 100644 --- a/nemoguardrails/actions/actions.py +++ b/nemoguardrails/actions/actions.py @@ -14,13 +14,14 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import Any, List, Optional +from typing import Any, Callable, List, Optional def action( is_system_action: bool = False, name: Optional[str] = None, execute_async: bool = False, + output_mapping: Optional[Callable[[Any], bool]] = None, ): """Decorator to mark a function or class as an action. @@ -28,7 +29,9 @@ def action( is_system_action (bool): Flag indicating if the action is a system action. name (Optional[str]): The name to associate with the action. execute_async: Whether the function should be executed in async mode. - + output_mapping (Optional[Callable[[Any], bool]]): A function to interpret the action's result. + It should accept the return value (e.g. the first element of a tuple) and return True if the output + should be considered blocked. Returns: callable: The decorated function or class. """ @@ -53,6 +56,7 @@ def decorator(fn_or_cls): "name": name or fn_or_cls.__name__, "is_system_action": is_system_action, "execute_async": execute_async, + "output_mapping": output_mapping, } return fn_or_cls diff --git a/nemoguardrails/library/activefence/actions.py b/nemoguardrails/library/activefence/actions.py index a9ee04c4d..d3cdcc867 100644 --- a/nemoguardrails/library/activefence/actions.py +++ b/nemoguardrails/library/activefence/actions.py @@ -25,7 +25,51 @@ log = logging.getLogger(__name__) -@action(is_system_action=True) +def mapping_call_activefence_api(result: dict) -> bool: + """ + Mapping for call_activefence_api. + + Expects result to be a dict with: + - "max_risk_score": a float representing the overall risk. + - "violations": a dict mapping violation type (str) to risk score (float). + + Block (return True) if any of the following conditions are met: + 1. Overall max risk score > 0.7. + 2. 
Individual thresholds are exceeded: + - "abusive_or_harmful.harassment_or_bullying" > 0.8 + - "abusive_or_harmful.profanity" > 0.7 + - "abusive_or_harmful.hate_speech" > 0.6 + - "abusive_or_harmful.child_grooming" > 0.4 + - "abusive_or_harmful.general_violence" > 0.7 + - "self_harm.general" > 0.8 + - "adult_content.general" > 0.3 + - "privacy_violation.pii" > 0.8 + """ + max_risk = result.get("max_risk_score", 0.0) + if max_risk > 0.7: + return True + + violations = result.get("violations", {}) + + thresholds = { + "abusive_or_harmful.harassment_or_bullying": 0.8, + "abusive_or_harmful.profanity": 0.7, + "abusive_or_harmful.hate_speech": 0.6, + "abusive_or_harmful.child_grooming": 0.4, + "abusive_or_harmful.general_violence": 0.7, + "self_harm.general": 0.8, + "adult_content.general": 0.3, + "privacy_violation.pii": 0.8, + } + + for violation_type, threshold in thresholds.items(): + if violations.get(violation_type, 0) > threshold: + return True + + return False + + +@action(is_system_action=True, output_mapping=mapping_call_activefence_api) async def call_activefence_api(text: Optional[str] = None): api_key = os.environ.get("ACTIVEFENCE_API_KEY") diff --git a/nemoguardrails/library/autoalign/actions.py b/nemoguardrails/library/autoalign/actions.py index edfdb8da4..9c922f1c6 100644 --- a/nemoguardrails/library/autoalign/actions.py +++ b/nemoguardrails/library/autoalign/actions.py @@ -79,6 +79,38 @@ default_groundedness_config = {"groundedness_checker": {"verify_response": False}} +def mapping_autoalign_output_api(result: dict) -> bool: + """ + Mapping for autoalign_output_api. + + Expects result to be a dict with a key "guardrails_triggered" (a boolean). + Returns True (block) if guardrails were triggered. + """ + return result.get("guardrails_triggered", False) + + +def mapping_autoalign_groundedness_output_api(result: float) -> bool: + """ + Mapping for autoalign_groundedness_output_api. + + Expects result to be a numeric score. + Returns True (block) if the score is below the default groundedness threshold. + """ + DEFAULT_GROUNDEDNESS_THRESHOLD = 0.5 + return result < DEFAULT_GROUNDEDNESS_THRESHOLD + + +def mapping_autoalign_factcheck_output_api(result: float) -> bool: + """ + Mapping for autoalign_factcheck_output_api. + + Expects result to be a numeric score. + Returns True (block) if the score is below the default factcheck threshold. 
+ """ + DEFAULT_FACTCHECK_THRESHOLD = 0.5 + return result < DEFAULT_FACTCHECK_THRESHOLD + + def process_autoalign_output(responses: List[Any], show_toxic_phrases: bool = False): """Processes the output provided AutoAlign API""" @@ -285,7 +317,7 @@ async def autoalign_input_api( return autoalign_response -@action(name="autoalign_output_api") +@action(name="autoalign_output_api", output_mapping=mapping_autoalign_output_api) async def autoalign_output_api( llm_task_manager: LLMTaskManager, context: Optional[dict] = None, @@ -319,7 +351,10 @@ async def autoalign_output_api( return autoalign_response -@action(name="autoalign_groundedness_output_api") +@action( + name="autoalign_groundedness_output_api", + output_mapping=mapping_autoalign_groundedness_output_api, +) async def autoalign_groundedness_output_api( llm_task_manager: LLMTaskManager, context: Optional[dict] = None, @@ -355,7 +390,10 @@ async def autoalign_groundedness_output_api( return score -@action(name="autoalign_factcheck_output_api") +@action( + name="autoalign_factcheck_output_api", + output_mapping=mapping_autoalign_factcheck_output_api, +) async def autoalign_factcheck_output_api( llm_task_manager: LLMTaskManager, context: Optional[dict] = None, diff --git a/nemoguardrails/library/cleanlab/actions.py b/nemoguardrails/library/cleanlab/actions.py index 57e40c37a..87072cd5a 100644 --- a/nemoguardrails/library/cleanlab/actions.py +++ b/nemoguardrails/library/cleanlab/actions.py @@ -22,7 +22,18 @@ log = logging.getLogger(__name__) -@action(name="call cleanlab api", is_system_action=True) +def cleanlab_output_mapping(result: dict) -> bool: + """Block if the trustworthiness score is below 0.6.""" + + score = result.get("trustworthiness_score", 1) + return score < 0.6 + + +@action( + name="call cleanlab api", + is_system_action=True, + output_mapping=cleanlab_output_mapping, +) async def call_cleanlab_api( context: Optional[dict] = None, ) -> Union[ValueError, ImportError, Dict]: diff --git a/nemoguardrails/library/content_safety/actions.py b/nemoguardrails/library/content_safety/actions.py index 110b6563a..e5b774576 100644 --- a/nemoguardrails/library/content_safety/actions.py +++ b/nemoguardrails/library/content_safety/actions.py @@ -90,7 +90,23 @@ async def content_safety_check_input( return {"allowed": is_safe, "policy_violations": violated_policies} -@action() +def content_safety_check_output_mapping(result: dict) -> bool: + """ + Mapping function for content_safety_check_output. + + Assumes result is a dictionary with: + - "allowed": a boolean where True means the content is safe. + - "policy_violations": a list of policies that were violated (optional in the mapping logic). + + Returns: + True if the content should be blocked (i.e. allowed is False), + False if the content is safe. + """ + allowed = result.get("allowed", True) + return not allowed + + +@action(output_mapping=content_safety_check_output_mapping) async def content_safety_check_output( llms: Dict[str, BaseLLM], llm_task_manager: LLMTaskManager, diff --git a/nemoguardrails/library/factchecking/align_score/actions.py b/nemoguardrails/library/factchecking/align_score/actions.py index a977c25c6..cd5c7fdd8 100644 --- a/nemoguardrails/library/factchecking/align_score/actions.py +++ b/nemoguardrails/library/factchecking/align_score/actions.py @@ -27,7 +27,18 @@ log = logging.getLogger(__name__) -@action() +def alignscore_check_facts_mapping(result: float) -> bool: + """ + Mapping for alignscore_check_facts. 
+ + Expects result to be a numeric score (float) representing the factual accuracy. + Returns True (i.e. block the output) if the score is below 0.5. + """ + THRESHOLD = 0.5 + return result < THRESHOLD + + +@action(output_mapping=alignscore_check_facts_mapping) async def alignscore_check_facts( llm_task_manager: LLMTaskManager, context: Optional[dict] = None, diff --git a/nemoguardrails/library/gcp_moderate_text/actions.py b/nemoguardrails/library/gcp_moderate_text/actions.py index 404100412..efd744329 100644 --- a/nemoguardrails/library/gcp_moderate_text/actions.py +++ b/nemoguardrails/library/gcp_moderate_text/actions.py @@ -28,7 +28,81 @@ log = logging.getLogger(__name__) -@action(name="call gcpnlp api", is_system_action=True) +def gcp_text_moderation_mapping(result: dict) -> bool: + """Determines whether the output from the GCP text moderation API should be blocked. + + The mapping checks two things: + 1. If the overall 'max_risk_score' exceeds 0.8. + 2. If any individual violation score (from the 'violations' dictionary) exceeds its defined threshold. + - Toxic: > 0.8 + - Insult: > 0.7 + - Profanity: > 0.6 + - Derogatory: > 0.4 + - Violent: > 0.8 + - Sexual: > 0.7 + - Death, Harm & Tragedy: > 0.8 + - Firearms & Weapons: > 0.8 + - Illicit Drugs: > 0.8 + - Public Safety: > 0.8 + - Health: > 0.8 + - Religion & Belief: > 0.8 + - War & Conflict: > 0.8 + - Politics: > 0.8 + - Finance: > 0.8 + - Legal: > 0.8 + + Returns True (i.e. block the output) if any of the conditions are met. + """ + # Check overall risk score if available. + max_risk_score = result.get("max_risk_score", 0) + if max_risk_score > 0.8: + return True + + # Check detailed violations, if provided. + violations = result.get("violations", {}) + + if violations.get("Toxic", 0) > 0.8: + return True + if violations.get("Insult", 0) > 0.7: + return True + if violations.get("Profanity", 0) > 0.6: + return True + if violations.get("Derogatory", 0) > 0.4: + return True + if violations.get("Violent", 0) > 0.8: + return True + if violations.get("Sexual", 0) > 0.7: + return True + if violations.get("Death, Harm & Tragedy", 0) > 0.8: + return True + if violations.get("Firearms & Weapons", 0) > 0.8: + return True + if violations.get("Illicit Drugs", 0) > 0.8: + return True + if violations.get("Public Safety", 0) > 0.8: + return True + if violations.get("Health", 0) > 0.8: + return True + if violations.get("Religion & Belief", 0) > 0.8: + return True + if violations.get("War & Conflict", 0) > 0.8: + return True + if violations.get("Politics", 0) > 0.8: + return True + if violations.get("Finance", 0) > 0.8: + return True + if violations.get("Legal", 0) > 0.8: + return True + + # If none of the thresholds are exceeded, allow the output. 
+ return False + + +@action( + name="call gcpnlp api", + is_system_action=True, + output_mapping=gcp_text_moderation_mapping, +) async def call_gcp_text_moderation_api(context: Optional[dict] = None): """ Application Default Credentials (ADC) is a strategy used by the GCP authentication libraries to automatically diff --git a/nemoguardrails/library/hallucination/actions.py b/nemoguardrails/library/hallucination/actions.py index 684ad4103..e5d0f5bc7 100644 --- a/nemoguardrails/library/hallucination/actions.py +++ b/nemoguardrails/library/hallucination/actions.py @@ -39,7 +39,7 @@ HALLUCINATION_NUM_EXTRA_RESPONSES = 2 -@action() +@action(output_mapping=lambda value: value) async def self_check_hallucination( llm: BaseLLM, llm_task_manager: LLMTaskManager, diff --git a/nemoguardrails/library/llama_guard/actions.py b/nemoguardrails/library/llama_guard/actions.py index d52fd83f0..b2e75618b 100644 --- a/nemoguardrails/library/llama_guard/actions.py +++ b/nemoguardrails/library/llama_guard/actions.py @@ -82,7 +82,23 @@ async def llama_guard_check_input( return {"allowed": allowed, "policy_violations": policy_violations} -@action() +def mapping_llama_guard_check_output(result: dict) -> bool: + """ + Mapping for llama_guard_check_output. + + Expects result to be a dict with: + - "allowed": a boolean indicating if the response passed the safety check. + - "policy_violations": additional details (not used in the mapping logic). + + Returns: + True if the response should be blocked (i.e. if "allowed" is False), + False otherwise. + """ + allowed = result.get("allowed", True) + return not allowed + + +@action(output_mapping=mapping_llama_guard_check_output) async def llama_guard_check_output( llm_task_manager: LLMTaskManager, context: Optional[dict] = None, diff --git a/nemoguardrails/library/patronusai/actions.py b/nemoguardrails/library/patronusai/actions.py index 2903c6ba0..133318c7d 100644 --- a/nemoguardrails/library/patronusai/actions.py +++ b/nemoguardrails/library/patronusai/actions.py @@ -60,7 +60,19 @@ def parse_patronus_lynx_response( return hallucination, reasoning -@action() +def mapping_patronus_lynx_check_output_hallucination(result: dict) -> bool: + """ + Mapping for patronus_lynx_check_output_hallucination. + + Expects result to be a dict with: + "hallucination": a boolean where True indicates a hallucination was detected. + + Block (return True) if "hallucination" is True. + """ + return result.get("hallucination", False) + + +@action(output_mapping=mapping_patronus_lynx_check_output_hallucination) async def patronus_lynx_check_output_hallucination( llm_task_manager: LLMTaskManager, context: Optional[dict] = None, @@ -215,7 +227,23 @@ async def patronus_evaluate_request( return response_json -@action(name="patronus_api_check_output") +def mapping_patronus_api_check_output(result: dict) -> bool: + """ + Mapping for patronus_api_check_output. + + Expects result to be a dict with: + "pass": a boolean where True means the output passed the check. + + Block (return True) if "pass" is False. 
+ """ + # Default to True (pass) if the key is missing + passed = result.get("pass", True) + return not passed + + +@action( + name="patronus_api_check_output", output_mapping=mapping_patronus_api_check_output +) async def patronus_api_check_output( llm_task_manager: LLMTaskManager, context: Optional[dict] = None, diff --git a/nemoguardrails/library/privateai/actions.py b/nemoguardrails/library/privateai/actions.py index bb021357e..814920c35 100644 --- a/nemoguardrails/library/privateai/actions.py +++ b/nemoguardrails/library/privateai/actions.py @@ -27,7 +27,17 @@ log = logging.getLogger(__name__) -@action(is_system_action=True) +def mapping_detect_pii(result: bool) -> bool: + """ + Mapping for detect_pii. + + Since the function returns True when PII is detected, + we block if result is True. + """ + return result + + +@action(is_system_action=True, output_mapping=mapping_detect_pii) async def detect_pii(source: str, text: str, config: RailsConfig): """Checks whether the provided text contains any PII. diff --git a/nemoguardrails/library/prompt_security/actions.py b/nemoguardrails/library/prompt_security/actions.py index b37135eb3..eb1f6c605 100644 --- a/nemoguardrails/library/prompt_security/actions.py +++ b/nemoguardrails/library/prompt_security/actions.py @@ -87,7 +87,22 @@ async def ps_protect_api_async( } -@action(is_system_action=True) +def mapping_protect_text(result: dict) -> bool: + """ + Mapping for protect_text action. + + Expects result to be a dict with: + - "is_blocked": a boolean indicating if the response passed to prompt security should be blocked. + + Returns: + True if the response should be blocked (i.e. if "is_blocked" is True), + False otherwise. + """ + blocked = result.get("is_blocked", True) + return blocked + + +@action(is_system_action=True, output_mapping=mapping_protect_text) async def protect_text( user_prompt: Optional[str] = None, bot_response: Optional[str] = None ): diff --git a/nemoguardrails/library/self_check/facts/actions.py b/nemoguardrails/library/self_check/facts/actions.py index 976c232c9..177f6124f 100644 --- a/nemoguardrails/library/self_check/facts/actions.py +++ b/nemoguardrails/library/self_check/facts/actions.py @@ -30,7 +30,18 @@ log = logging.getLogger(__name__) -@action() +def mapping_self_check_facts(result: float) -> bool: + """ + Mapping for self_check_facts. + + Expects result to be a numeric score (float) representing the factual accuracy. + Returns True (i.e. block the output) if the score is below 0.5. 
+ """ + THRESHOLD = 0.5 + return result < THRESHOLD + + +@action(output_mapping=mapping_self_check_facts) async def self_check_facts( llm_task_manager: LLMTaskManager, context: Optional[dict] = None, diff --git a/nemoguardrails/library/self_check/output_check/actions.py b/nemoguardrails/library/self_check/output_check/actions.py index 400b0b87e..904d301f5 100644 --- a/nemoguardrails/library/self_check/output_check/actions.py +++ b/nemoguardrails/library/self_check/output_check/actions.py @@ -30,7 +30,7 @@ log = logging.getLogger(__name__) -@action(is_system_action=True) +@action(is_system_action=True, output_mapping=lambda value: not value) async def self_check_output( llm_task_manager: LLMTaskManager, context: Optional[dict] = None, diff --git a/nemoguardrails/library/sensitive_data_detection/actions.py b/nemoguardrails/library/sensitive_data_detection/actions.py index 40f6cff2a..afa9091db 100644 --- a/nemoguardrails/library/sensitive_data_detection/actions.py +++ b/nemoguardrails/library/sensitive_data_detection/actions.py @@ -85,7 +85,17 @@ def _get_ad_hoc_recognizers(sdd_config: SensitiveDataDetection): return ad_hoc_recognizers -@action(is_system_action=True) +def mapping_detect_sensitive_data(result: bool) -> bool: + """ + Mapping for detect_sensitive_data. + + Since the function returns True when sensitive data is detected, + we block if result is True. + """ + return result + + +@action(is_system_action=True, output_mapping=mapping_detect_sensitive_data) async def detect_sensitive_data(source: str, text: str, config: RailsConfig): """Checks whether the provided text contains any sensitive data. From 19d91f89f654df60f9fc295ff4fb36e0a3cf7045 Mon Sep 17 00:00:00 2001 From: prezakhani <13303554+Pouyanpi@users.noreply.github.com> Date: Sun, 2 Feb 2025 16:21:48 +0100 Subject: [PATCH 2/5] feat: add kwargs to actions for flexibility --- nemoguardrails/library/activefence/actions.py | 2 +- nemoguardrails/library/autoalign/actions.py | 3 +++ nemoguardrails/library/cleanlab/actions.py | 1 + nemoguardrails/library/content_safety/actions.py | 2 ++ nemoguardrails/library/factchecking/align_score/actions.py | 1 + nemoguardrails/library/gcp_moderate_text/actions.py | 4 +++- nemoguardrails/library/hallucination/actions.py | 1 + nemoguardrails/library/jailbreak_detection/actions.py | 4 +++- nemoguardrails/library/llama_guard/actions.py | 1 + nemoguardrails/library/patronusai/actions.py | 1 + nemoguardrails/library/privateai/actions.py | 7 ++++++- nemoguardrails/library/prompt_security/actions.py | 2 +- nemoguardrails/library/self_check/facts/actions.py | 1 + nemoguardrails/library/self_check/input_check/actions.py | 1 + nemoguardrails/library/self_check/output_check/actions.py | 1 + nemoguardrails/library/sensitive_data_detection/actions.py | 7 ++++++- nemoguardrails/library/topic_safety/actions.py | 1 + 17 files changed, 34 insertions(+), 6 deletions(-) diff --git a/nemoguardrails/library/activefence/actions.py b/nemoguardrails/library/activefence/actions.py index d3cdcc867..bf8da4eeb 100644 --- a/nemoguardrails/library/activefence/actions.py +++ b/nemoguardrails/library/activefence/actions.py @@ -70,7 +70,7 @@ def mapping_call_activefence_api(result: dict) -> bool: @action(is_system_action=True, output_mapping=mapping_call_activefence_api) -async def call_activefence_api(text: Optional[str] = None): +async def call_activefence_api(text: Optional[str] = None, **kwargs): api_key = os.environ.get("ACTIVEFENCE_API_KEY") if api_key is None: diff --git a/nemoguardrails/library/autoalign/actions.py 
b/nemoguardrails/library/autoalign/actions.py index 9c922f1c6..4578426ae 100644 --- a/nemoguardrails/library/autoalign/actions.py +++ b/nemoguardrails/library/autoalign/actions.py @@ -284,6 +284,7 @@ async def autoalign_input_api( context: Optional[dict] = None, show_autoalign_message: bool = True, show_toxic_phrases: bool = False, + **kwargs, ): """Calls AutoAlign API for the user message and guardrail configuration provided""" user_message = context.get("user_message") @@ -323,6 +324,7 @@ async def autoalign_output_api( context: Optional[dict] = None, show_autoalign_message: bool = True, show_toxic_phrases: bool = False, + **kwargs, ): """Calls AutoAlign API for the bot message and guardrail configuration provided""" bot_message = context.get("bot_message") @@ -360,6 +362,7 @@ async def autoalign_groundedness_output_api( context: Optional[dict] = None, factcheck_threshold: float = 0.0, show_autoalign_message: bool = True, + **kwargs, ): """Calls AutoAlign groundedness check API and checks whether the bot message is factually grounded according to given documents""" diff --git a/nemoguardrails/library/cleanlab/actions.py b/nemoguardrails/library/cleanlab/actions.py index 87072cd5a..a7f95cb8d 100644 --- a/nemoguardrails/library/cleanlab/actions.py +++ b/nemoguardrails/library/cleanlab/actions.py @@ -36,6 +36,7 @@ def cleanlab_output_mapping(result: dict) -> bool: ) async def call_cleanlab_api( context: Optional[dict] = None, + **kwargs, ) -> Union[ValueError, ImportError, Dict]: api_key = os.environ.get("CLEANLAB_API_KEY") diff --git a/nemoguardrails/library/content_safety/actions.py b/nemoguardrails/library/content_safety/actions.py index e5b774576..7c1cbf406 100644 --- a/nemoguardrails/library/content_safety/actions.py +++ b/nemoguardrails/library/content_safety/actions.py @@ -34,6 +34,7 @@ async def content_safety_check_input( llm_task_manager: LLMTaskManager, model_name: Optional[str] = None, context: Optional[dict] = None, + **kwargs, ) -> dict: _MAX_TOKENS = 3 user_input: str = "" @@ -112,6 +113,7 @@ async def content_safety_check_output( llm_task_manager: LLMTaskManager, model_name: Optional[str] = None, context: Optional[dict] = None, + **kwargs, ) -> dict: _MAX_TOKENS = 3 user_input: str = "" diff --git a/nemoguardrails/library/factchecking/align_score/actions.py b/nemoguardrails/library/factchecking/align_score/actions.py index cd5c7fdd8..b2650cc9a 100644 --- a/nemoguardrails/library/factchecking/align_score/actions.py +++ b/nemoguardrails/library/factchecking/align_score/actions.py @@ -44,6 +44,7 @@ async def alignscore_check_facts( context: Optional[dict] = None, llm: Optional[BaseLLM] = None, config: Optional[RailsConfig] = None, + **kwargs, ): """Checks the facts for the bot response using an information alignment score.""" fact_checking_config = llm_task_manager.config.rails.config.fact_checking diff --git a/nemoguardrails/library/gcp_moderate_text/actions.py b/nemoguardrails/library/gcp_moderate_text/actions.py index efd744329..afb7004f0 100644 --- a/nemoguardrails/library/gcp_moderate_text/actions.py +++ b/nemoguardrails/library/gcp_moderate_text/actions.py @@ -103,7 +103,9 @@ def gcp_text_moderation_mapping(result: dict) -> bool: is_system_action=True, output_mapping=gcp_text_moderation_mapping, ) -async def call_gcp_text_moderation_api(context: Optional[dict] = None): +async def call_gcp_text_moderation_api( + context: Optional[dict] = None, **kwargs +) -> dict: """ Application Default Credentials (ADC) is a strategy used by the GCP authentication libraries to 
automatically find credentials based on the application environment. ADC searches for credentials in the following locations (Search order): diff --git a/nemoguardrails/library/hallucination/actions.py b/nemoguardrails/library/hallucination/actions.py index e5d0f5bc7..778f43f51 100644 --- a/nemoguardrails/library/hallucination/actions.py +++ b/nemoguardrails/library/hallucination/actions.py @@ -46,6 +46,7 @@ async def self_check_hallucination( context: Optional[dict] = None, use_llm_checking: bool = True, config: Optional[RailsConfig] = None, + **kwargs, ): """Checks if the last bot response is a hallucination by checking multiple completions for self-consistency. diff --git a/nemoguardrails/library/jailbreak_detection/actions.py b/nemoguardrails/library/jailbreak_detection/actions.py index a90a233f8..416c70218 100644 --- a/nemoguardrails/library/jailbreak_detection/actions.py +++ b/nemoguardrails/library/jailbreak_detection/actions.py @@ -43,7 +43,9 @@ @action() async def jailbreak_detection_heuristics( - llm_task_manager: LLMTaskManager, context: Optional[dict] = None + llm_task_manager: LLMTaskManager, + context: Optional[dict] = None, + **kwargs, ) -> bool: """Checks the user's prompt to determine if it is attempt to jailbreak the model.""" jailbreak_config = llm_task_manager.config.rails.config.jailbreak_detection diff --git a/nemoguardrails/library/llama_guard/actions.py b/nemoguardrails/library/llama_guard/actions.py index b2e75618b..8e6733ec7 100644 --- a/nemoguardrails/library/llama_guard/actions.py +++ b/nemoguardrails/library/llama_guard/actions.py @@ -58,6 +58,7 @@ async def llama_guard_check_input( llm_task_manager: LLMTaskManager, context: Optional[dict] = None, llama_guard_llm: Optional[BaseLLM] = None, + **kwargs, ) -> dict: """ Checks user messages using the configured Llama Guard model diff --git a/nemoguardrails/library/patronusai/actions.py b/nemoguardrails/library/patronusai/actions.py index 133318c7d..6e0836766 100644 --- a/nemoguardrails/library/patronusai/actions.py +++ b/nemoguardrails/library/patronusai/actions.py @@ -77,6 +77,7 @@ async def patronus_lynx_check_output_hallucination( llm_task_manager: LLMTaskManager, context: Optional[dict] = None, patronus_lynx_llm: Optional[BaseLLM] = None, + **kwargs, ) -> dict: """ Check the bot response for hallucinations based on the given chunks diff --git a/nemoguardrails/library/privateai/actions.py b/nemoguardrails/library/privateai/actions.py index 814920c35..80d63c9ea 100644 --- a/nemoguardrails/library/privateai/actions.py +++ b/nemoguardrails/library/privateai/actions.py @@ -38,7 +38,12 @@ def mapping_detect_pii(result: bool) -> bool: @action(is_system_action=True, output_mapping=mapping_detect_pii) -async def detect_pii(source: str, text: str, config: RailsConfig): +async def detect_pii( + source: str, + text: str, + config: RailsConfig, + **kwargs, +): """Checks whether the provided text contains any PII. 
Args diff --git a/nemoguardrails/library/prompt_security/actions.py b/nemoguardrails/library/prompt_security/actions.py index eb1f6c605..9cbbc0a41 100644 --- a/nemoguardrails/library/prompt_security/actions.py +++ b/nemoguardrails/library/prompt_security/actions.py @@ -104,7 +104,7 @@ def mapping_protect_text(result: dict) -> bool: @action(is_system_action=True, output_mapping=mapping_protect_text) async def protect_text( - user_prompt: Optional[str] = None, bot_response: Optional[str] = None + user_prompt: Optional[str] = None, bot_response: Optional[str] = None, **kwargs ): """Protects the given user_prompt or bot_response. Args: diff --git a/nemoguardrails/library/self_check/facts/actions.py b/nemoguardrails/library/self_check/facts/actions.py index 177f6124f..3977d9bab 100644 --- a/nemoguardrails/library/self_check/facts/actions.py +++ b/nemoguardrails/library/self_check/facts/actions.py @@ -47,6 +47,7 @@ async def self_check_facts( context: Optional[dict] = None, llm: Optional[BaseLLM] = None, config: Optional[RailsConfig] = None, + **kwargs, ): """Checks the facts for the bot response by appropriately prompting the base llm.""" _MAX_TOKENS = 3 diff --git a/nemoguardrails/library/self_check/input_check/actions.py b/nemoguardrails/library/self_check/input_check/actions.py index c1ced0004..2fe9afc11 100644 --- a/nemoguardrails/library/self_check/input_check/actions.py +++ b/nemoguardrails/library/self_check/input_check/actions.py @@ -37,6 +37,7 @@ async def self_check_input( context: Optional[dict] = None, llm: Optional[BaseLLM] = None, config: Optional[RailsConfig] = None, + **kwargs, ): """Checks the input from the user. diff --git a/nemoguardrails/library/self_check/output_check/actions.py b/nemoguardrails/library/self_check/output_check/actions.py index 904d301f5..ca1254fbe 100644 --- a/nemoguardrails/library/self_check/output_check/actions.py +++ b/nemoguardrails/library/self_check/output_check/actions.py @@ -36,6 +36,7 @@ async def self_check_output( context: Optional[dict] = None, llm: Optional[BaseLLM] = None, config: Optional[RailsConfig] = None, + **kwargs, ): """Checks if the output from the bot. diff --git a/nemoguardrails/library/sensitive_data_detection/actions.py b/nemoguardrails/library/sensitive_data_detection/actions.py index afa9091db..27d5b8dd2 100644 --- a/nemoguardrails/library/sensitive_data_detection/actions.py +++ b/nemoguardrails/library/sensitive_data_detection/actions.py @@ -96,7 +96,12 @@ def mapping_detect_sensitive_data(result: bool) -> bool: @action(is_system_action=True, output_mapping=mapping_detect_sensitive_data) -async def detect_sensitive_data(source: str, text: str, config: RailsConfig): +async def detect_sensitive_data( + source: str, + text: str, + config: RailsConfig, + **kwargs, +): """Checks whether the provided text contains any sensitive data. 
Args diff --git a/nemoguardrails/library/topic_safety/actions.py b/nemoguardrails/library/topic_safety/actions.py index 358c1d65c..55021a282 100644 --- a/nemoguardrails/library/topic_safety/actions.py +++ b/nemoguardrails/library/topic_safety/actions.py @@ -36,6 +36,7 @@ async def topic_safety_check_input( model_name: Optional[str] = None, context: Optional[dict] = None, events: Optional[List[dict]] = None, + **kwargs, ) -> dict: _MAX_TOKENS = 10 user_input: str = "" From bc4fe8708532e7acdfc13ec7ff3ad55fc8bfb7fe Mon Sep 17 00:00:00 2001 From: prezakhani <13303554+Pouyanpi@users.noreply.github.com> Date: Sun, 2 Feb 2025 19:04:38 +0100 Subject: [PATCH 3/5] feat: add ActionMeta TypedDict and improve action decorator --- nemoguardrails/actions/actions.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/nemoguardrails/actions/actions.py b/nemoguardrails/actions/actions.py index a8729de8a..8b067f3c8 100644 --- a/nemoguardrails/actions/actions.py +++ b/nemoguardrails/actions/actions.py @@ -14,7 +14,14 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, TypedDict, Union + + +class ActionMeta(TypedDict, total=False): + name: str + is_system_action: bool + execute_async: bool + output_mapping: Optional[Callable[[Any], bool]] def action( @@ -22,7 +29,7 @@ def action( name: Optional[str] = None, execute_async: bool = False, output_mapping: Optional[Callable[[Any], bool]] = None, -): +) -> Callable[[Union[Callable, type]], Union[Callable, type]]: """Decorator to mark a function or class as an action. Args: @@ -36,28 +43,22 @@ def action( callable: The decorated function or class. """ - def decorator(fn_or_cls): + def decorator(fn_or_cls: Union[Callable, type]) -> Union[Callable, type]: """Inner decorator function to add metadata to the action. Args: fn_or_cls: The function or class being decorated. """ + fn_or_cls_target = getattr(fn_or_cls, "__func__", fn_or_cls) - # Detect the decorator being applied to staticmethod or classmethod. - # Will annotate the the inner function in that case as otherwise - # metaclass will be giving us the unannotated enclosed function on - # attribute lookup. - if hasattr(fn_or_cls, "__func__"): - fn_or_cls_target = fn_or_cls.__func__ - else: - fn_or_cls_target = fn_or_cls - - fn_or_cls_target.action_meta = { + action_meta: ActionMeta = { "name": name or fn_or_cls.__name__, "is_system_action": is_system_action, "execute_async": execute_async, "output_mapping": output_mapping, } + + setattr(fn_or_cls_target, "action_meta", action_meta) return fn_or_cls return decorator From 608bb45634a8ea95b91c389e4cb150afade93db0 Mon Sep 17 00:00:00 2001 From: prezakhani <13303554+Pouyanpi@users.noreply.github.com> Date: Sun, 2 Feb 2025 19:04:59 +0100 Subject: [PATCH 4/5] test: refactor and add new tests for action decorator --- tests/test_actions.py | 74 +++++++++++++++++------------------- tests/test_actions_server.py | 59 ++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 39 deletions(-) create mode 100644 tests/test_actions_server.py diff --git a/tests/test_actions.py b/tests/test_actions.py index 89c2b623c..d5a1161a5 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -14,46 +14,42 @@ # limitations under the License. 
import pytest -from fastapi.testclient import TestClient - -from nemoguardrails.actions_server import actions_server - -client = TestClient(actions_server.app) - - -@pytest.mark.skip( - reason="Should only be run locally as it fetches data from wikipedia." -) -@pytest.mark.parametrize( - "action_name, action_parameters, result_field, status", - [ - ( - "action-test", - {"content": "Hello", "parameter": "parameters"}, - [], - "failed", - ), - ("Wikipedia", {"query": "president of US?"}, ["text"], "success"), - ], -) -def test_run(action_name, action_parameters, result_field, status): - response = client.post( - "/v1/actions/run", - json={ - "action_name": action_name, - "action_parameters": action_parameters, - }, - ) - assert response.status_code == 200 - res = response.json() - assert list(res["result"].keys()) == result_field - assert res["status"] == status +from nemoguardrails.actions.actions import ActionResult, action + + +def test_action_decorator(): + @action(is_system_action=True, name="test_action", execute_async=True) + def sample_action(): + return "test" + + assert hasattr(sample_action, "action_meta") + assert sample_action.action_meta["name"] == "test_action" + assert sample_action.action_meta["is_system_action"] is True + assert sample_action.action_meta["execute_async"] is True + +def test_action_decorator_with_output_mapping(): + def sample_output_mapping(result): + return result == "blocked" -def test_get_actions(): - response = client.get("/v1/actions/list") + @action(output_mapping=sample_output_mapping) + def sample_action(): + return "blocked" + + assert hasattr(sample_action, "action_meta") + assert sample_action.action_meta["output_mapping"] is not None + assert sample_action.action_meta["output_mapping"]("blocked") is True + assert sample_action.action_meta["output_mapping"]("not_blocked") is False + + +def test_action_result(): + result = ActionResult( + return_value="test_value", + events=[{"event": "test_event"}], + context_updates={"key": "value"}, + ) - # Check that we have at least one config - result = response.json() - assert len(result) >= 1 + assert result.return_value == "test_value" + assert result.events == [{"event": "test_event"}] + assert result.context_updates == {"key": "value"} diff --git a/tests/test_actions_server.py b/tests/test_actions_server.py new file mode 100644 index 000000000..89c2b623c --- /dev/null +++ b/tests/test_actions_server.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from fastapi.testclient import TestClient + +from nemoguardrails.actions_server import actions_server + +client = TestClient(actions_server.app) + + +@pytest.mark.skip( + reason="Should only be run locally as it fetches data from wikipedia." 
+) +@pytest.mark.parametrize( + "action_name, action_parameters, result_field, status", + [ + ( + "action-test", + {"content": "Hello", "parameter": "parameters"}, + [], + "failed", + ), + ("Wikipedia", {"query": "president of US?"}, ["text"], "success"), + ], +) +def test_run(action_name, action_parameters, result_field, status): + response = client.post( + "/v1/actions/run", + json={ + "action_name": action_name, + "action_parameters": action_parameters, + }, + ) + + assert response.status_code == 200 + res = response.json() + assert list(res["result"].keys()) == result_field + assert res["status"] == status + + +def test_get_actions(): + response = client.get("/v1/actions/list") + + # Check that we have at least one config + result = response.json() + assert len(result) >= 1 From 91737538bff905c723437777ceb52dfc02f91be1 Mon Sep 17 00:00:00 2001 From: prezakhani <13303554+Pouyanpi@users.noreply.github.com> Date: Mon, 3 Feb 2025 11:25:26 +0100 Subject: [PATCH 5/5] refactor: rename mapping functions for consistency --- nemoguardrails/library/activefence/actions.py | 4 ++-- nemoguardrails/library/autoalign/actions.py | 12 ++++++------ nemoguardrails/library/llama_guard/actions.py | 4 ++-- nemoguardrails/library/patronusai/actions.py | 8 ++++---- nemoguardrails/library/privateai/actions.py | 4 ++-- nemoguardrails/library/prompt_security/actions.py | 4 ++-- .../library/sensitive_data_detection/actions.py | 4 ++-- 7 files changed, 20 insertions(+), 20 deletions(-) diff --git a/nemoguardrails/library/activefence/actions.py b/nemoguardrails/library/activefence/actions.py index bf8da4eeb..3fafda552 100644 --- a/nemoguardrails/library/activefence/actions.py +++ b/nemoguardrails/library/activefence/actions.py @@ -25,7 +25,7 @@ log = logging.getLogger(__name__) -def mapping_call_activefence_api(result: dict) -> bool: +def call_activefence_api_mapping(result: dict) -> bool: """ Mapping for call_activefence_api. @@ -69,7 +69,7 @@ def mapping_call_activefence_api(result: dict) -> bool: return False -@action(is_system_action=True, output_mapping=mapping_call_activefence_api) +@action(is_system_action=True, output_mapping=call_activefence_api_mapping) async def call_activefence_api(text: Optional[str] = None, **kwargs): api_key = os.environ.get("ACTIVEFENCE_API_KEY") diff --git a/nemoguardrails/library/autoalign/actions.py b/nemoguardrails/library/autoalign/actions.py index 4578426ae..57dd79e2a 100644 --- a/nemoguardrails/library/autoalign/actions.py +++ b/nemoguardrails/library/autoalign/actions.py @@ -79,7 +79,7 @@ default_groundedness_config = {"groundedness_checker": {"verify_response": False}} -def mapping_autoalign_output_api(result: dict) -> bool: +def autoalign_output_api_mapping(result: dict) -> bool: """ Mapping for autoalign_output_api. @@ -89,7 +89,7 @@ def mapping_autoalign_output_api(result: dict) -> bool: return result.get("guardrails_triggered", False) -def mapping_autoalign_groundedness_output_api(result: float) -> bool: +def autoalign_groundedness_output_api_mapping(result: float) -> bool: """ Mapping for autoalign_groundedness_output_api. @@ -100,7 +100,7 @@ def mapping_autoalign_groundedness_output_api(result: float) -> bool: return result < DEFAULT_GROUNDEDNESS_THRESHOLD -def mapping_autoalign_factcheck_output_api(result: float) -> bool: +def autoalign_factcheck_output_api_mapping(result: float) -> bool: """ Mapping for autoalign_factcheck_output_api. 
@@ -318,7 +318,7 @@ async def autoalign_input_api( return autoalign_response -@action(name="autoalign_output_api", output_mapping=mapping_autoalign_output_api) +@action(name="autoalign_output_api", output_mapping=autoalign_output_api_mapping) async def autoalign_output_api( llm_task_manager: LLMTaskManager, context: Optional[dict] = None, @@ -355,7 +355,7 @@ async def autoalign_output_api( @action( name="autoalign_groundedness_output_api", - output_mapping=mapping_autoalign_groundedness_output_api, + output_mapping=autoalign_groundedness_output_api_mapping, ) async def autoalign_groundedness_output_api( llm_task_manager: LLMTaskManager, @@ -395,7 +395,7 @@ async def autoalign_groundedness_output_api( @action( name="autoalign_factcheck_output_api", - output_mapping=mapping_autoalign_factcheck_output_api, + output_mapping=autoalign_factcheck_output_api_mapping, ) async def autoalign_factcheck_output_api( llm_task_manager: LLMTaskManager, diff --git a/nemoguardrails/library/llama_guard/actions.py b/nemoguardrails/library/llama_guard/actions.py index 8e6733ec7..23502be05 100644 --- a/nemoguardrails/library/llama_guard/actions.py +++ b/nemoguardrails/library/llama_guard/actions.py @@ -83,7 +83,7 @@ async def llama_guard_check_input( return {"allowed": allowed, "policy_violations": policy_violations} -def mapping_llama_guard_check_output(result: dict) -> bool: +def llama_guard_check_output_mapping(result: dict) -> bool: """ Mapping for llama_guard_check_output. @@ -99,7 +99,7 @@ def mapping_llama_guard_check_output(result: dict) -> bool: return not allowed -@action(output_mapping=mapping_llama_guard_check_output) +@action(output_mapping=llama_guard_check_output_mapping) async def llama_guard_check_output( llm_task_manager: LLMTaskManager, context: Optional[dict] = None, diff --git a/nemoguardrails/library/patronusai/actions.py b/nemoguardrails/library/patronusai/actions.py index 6e0836766..19fe128ea 100644 --- a/nemoguardrails/library/patronusai/actions.py +++ b/nemoguardrails/library/patronusai/actions.py @@ -60,7 +60,7 @@ def parse_patronus_lynx_response( return hallucination, reasoning -def mapping_patronus_lynx_check_output_hallucination(result: dict) -> bool: +def patronus_lynx_check_output_hallucination_mapping(result: dict) -> bool: """ Mapping for patronus_lynx_check_output_hallucination. @@ -72,7 +72,7 @@ def mapping_patronus_lynx_check_output_hallucination(result: dict) -> bool: return result.get("hallucination", False) -@action(output_mapping=mapping_patronus_lynx_check_output_hallucination) +@action(output_mapping=patronus_lynx_check_output_hallucination_mapping) async def patronus_lynx_check_output_hallucination( llm_task_manager: LLMTaskManager, context: Optional[dict] = None, @@ -228,7 +228,7 @@ async def patronus_evaluate_request( return response_json -def mapping_patronus_api_check_output(result: dict) -> bool: +def patronus_api_check_output_mapping(result: dict) -> bool: """ Mapping for patronus_api_check_output. 
@@ -243,7 +243,7 @@ def mapping_patronus_api_check_output(result: dict) -> bool: @action( - name="patronus_api_check_output", output_mapping=mapping_patronus_api_check_output + name="patronus_api_check_output", output_mapping=patronus_api_check_output_mapping ) async def patronus_api_check_output( llm_task_manager: LLMTaskManager, diff --git a/nemoguardrails/library/privateai/actions.py b/nemoguardrails/library/privateai/actions.py index 80d63c9ea..1fa21e286 100644 --- a/nemoguardrails/library/privateai/actions.py +++ b/nemoguardrails/library/privateai/actions.py @@ -27,7 +27,7 @@ log = logging.getLogger(__name__) -def mapping_detect_pii(result: bool) -> bool: +def detect_pii_mapping(result: bool) -> bool: """ Mapping for detect_pii. @@ -37,7 +37,7 @@ def mapping_detect_pii(result: bool) -> bool: return result -@action(is_system_action=True, output_mapping=mapping_detect_pii) +@action(is_system_action=True, output_mapping=detect_pii_mapping) async def detect_pii( source: str, text: str, diff --git a/nemoguardrails/library/prompt_security/actions.py b/nemoguardrails/library/prompt_security/actions.py index 9cbbc0a41..d5e1de240 100644 --- a/nemoguardrails/library/prompt_security/actions.py +++ b/nemoguardrails/library/prompt_security/actions.py @@ -87,7 +87,7 @@ async def ps_protect_api_async( } -def mapping_protect_text(result: dict) -> bool: +def protect_text_mapping(result: dict) -> bool: """ Mapping for protect_text action. @@ -102,7 +102,7 @@ def mapping_protect_text(result: dict) -> bool: return blocked -@action(is_system_action=True, output_mapping=mapping_protect_text) +@action(is_system_action=True, output_mapping=protect_text_mapping) async def protect_text( user_prompt: Optional[str] = None, bot_response: Optional[str] = None, **kwargs ): diff --git a/nemoguardrails/library/sensitive_data_detection/actions.py b/nemoguardrails/library/sensitive_data_detection/actions.py index 27d5b8dd2..8bd6748da 100644 --- a/nemoguardrails/library/sensitive_data_detection/actions.py +++ b/nemoguardrails/library/sensitive_data_detection/actions.py @@ -85,7 +85,7 @@ def _get_ad_hoc_recognizers(sdd_config: SensitiveDataDetection): return ad_hoc_recognizers -def mapping_detect_sensitive_data(result: bool) -> bool: +def detect_sensitive_data_mapping(result: bool) -> bool: """ Mapping for detect_sensitive_data. @@ -95,7 +95,7 @@ def mapping_detect_sensitive_data(result: bool) -> bool: return result -@action(is_system_action=True, output_mapping=mapping_detect_sensitive_data) +@action(is_system_action=True, output_mapping=detect_sensitive_data_mapping) async def detect_sensitive_data( source: str, text: str,
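
Usage note (illustrative, not part of the patch series): none of the hunks above show a standalone custom action adopting the new `output_mapping` contract, so the sketch below pulls the pieces together. It assumes only what the patches themselves establish: the decorator stores the mapping in `action_meta` (PATCH 1/5 and 3/5), a mapping receives the action's return value and returns True when the output should be blocked (PATCH 1/5), actions take a `**kwargs` catch-all (PATCH 2/5), and mapping functions carry an `_mapping` suffix (PATCH 5/5). The action `check_profanity_score`, its result shape, and the 0.7 threshold are hypothetical examples, not part of the library.

    from typing import Optional

    from nemoguardrails.actions.actions import action


    def check_profanity_score_mapping(result: dict) -> bool:
        """Block (return True) if the hypothetical profanity score exceeds 0.7."""
        return result.get("profanity_score", 0.0) > 0.7


    @action(is_system_action=True, output_mapping=check_profanity_score_mapping)
    async def check_profanity_score(text: Optional[str] = None, **kwargs):
        # A real action would call out to a moderation service here; this
        # stub returns a fixed payload purely to exercise the mapping.
        return {"profanity_score": 0.2}


    # The decorator only attaches metadata, so the mapping can be read back
    # the same way the new tests in PATCH 4/5 do:
    meta = check_profanity_score.action_meta
    assert meta["output_mapping"]({"profanity_score": 0.9}) is True
    assert meta["output_mapping"]({"profanity_score": 0.2}) is False

The `**kwargs` catch-all mirrors PATCH 2/5: it lets the runtime pass extra context to any action without breaking existing signatures, which is why that commit adds it uniformly across the library actions rather than to a single guard.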