diff --git a/docs/user_guides/community/privateai.md b/docs/user_guides/community/privateai.md
new file mode 100644
index 000000000..b305d7d53
--- /dev/null
+++ b/docs/user_guides/community/privateai.md
@@ -0,0 +1,61 @@
+# Private AI Integration
+
+[Private AI](https://docs.private-ai.com/?utm_medium=github&utm_campaign=nemo-guardrails) allows you to detect and mask Personally Identifiable Information (PII) in your data. This integration enables NeMo Guardrails to use Private AI for PII detection in input, output, and retrieval flows.
+
+## Setup
+
+1. Ensure that you have access to a Private AI API server, running either locally or in the cloud. To get started with the cloud version, you can use the [Private AI Portal](https://portal.private-ai.com/?utm_medium=github&utm_campaign=nemo-guardrails). For containerized deployments, check out this [Quickstart Guide](https://docs.private-ai.com/quickstart/?utm_medium=github&utm_campaign=nemo-guardrails).
+
+2. Update your `config.yml` file to include the Private AI settings:
+
+```yaml
+rails:
+  config:
+    privateai:
+      server_endpoint: http://your-privateai-api-endpoint/process/text  # Replace this with your Private AI process text endpoint
+      input:
+        entities:  # If no entity is specified here, all supported entities will be detected by default.
+          - NAME_FAMILY
+          - LOCATION_ADDRESS_STREET
+          - EMAIL_ADDRESS
+      output:
+        entities:
+          - NAME_FAMILY
+          - LOCATION_ADDRESS_STREET
+          - EMAIL_ADDRESS
+  input:
+    flows:
+      - detect pii on input
+  output:
+    flows:
+      - detect pii on output
+```
+
+Replace `http://your-privateai-api-endpoint/process/text` with your actual Private AI process text endpoint, and set the `PAI_API_KEY` environment variable if you're using the Private AI cloud API.
+
+3. You can customize the `entities` list under both `input` and `output` to include the PII types you want to detect. A full list of supported entities can be found [here](https://docs.private-ai.com/entities/?utm_medium=github&utm_campaign=nemo-guardrails).
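+Under the hood, each check posts the text to the configured `process/text` endpoint. The sketch below mirrors the request body built by the integration's request helper (`nemoguardrails/library/privateai/request.py`); the entity name and sample text are illustrative:
+
+```python
+# Sketch of the request body sent to the Private AI /process/text endpoint
+# (mirrors nemoguardrails/library/privateai/request.py). Entity names are
+# illustrative.
+from typing import Any, Dict, List
+
+
+def build_payload(text: str, enabled_entities: List[str]) -> Dict[str, Any]:
+    payload: Dict[str, Any] = {
+        "text": [text],
+        "link_batch": False,
+        "entity_detection": {"accuracy": "high_automatic", "return_entity": False},
+    }
+    # When no entities are configured, the entity_types filter is omitted and
+    # all supported entities are detected by default.
+    if enabled_entities:
+        payload["entity_detection"]["entity_types"] = [
+            {"type": "ENABLE", "value": enabled_entities}
+        ]
+    return payload
+
+
+payload = build_payload("My email is test@gmail.com.", ["EMAIL_ADDRESS"])
+print(payload["entity_detection"]["entity_types"])
+# → [{'type': 'ENABLE', 'value': ['EMAIL_ADDRESS']}]
+```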
+ +## Usage + +Once configured, the Private AI integration will automatically: + +1. Detect PII in user inputs before they are processed by the LLM. +2. Detect PII in LLM outputs before they are sent back to the user. +3. Detect PII in retrieved chunks before they are sent to the LLM. + +The `detect_pii` action in `nemoguardrails/library/privateai/actions.py` handles the PII detection process. + +## Customization + +You can customize the PII detection behavior by modifying the `entities` lists in the `config.yml` file. Refer to the Private AI documentation for a complete list of [supported entity types](https://docs.private-ai.com/entities/?utm_medium=github&utm_campaign=nemo-guardrails). + +## Error Handling + +If the Private AI detection API request fails, the system will assume PII is present as a precautionary measure. + +## Notes + +- Ensure that your Private AI process text endpoint is properly set up and accessible from your NeMo Guardrails environment. +- The integration currently supports PII detection only. + +For more information on Private AI and its capabilities, please refer to the [Private AI documentation](https://docs.private-ai.com/?utm_medium=github&utm_campaign=nemo-guardrails). diff --git a/docs/user_guides/guardrails-library.md b/docs/user_guides/guardrails-library.md index 17170c9a3..578c26a83 100644 --- a/docs/user_guides/guardrails-library.md +++ b/docs/user_guides/guardrails-library.md @@ -22,6 +22,7 @@ NeMo Guardrails comes with a library of built-in guardrails that you can easily - [AutoAlign](#autoalign) - [Cleanlab Trustworthiness Score](#cleanlab) - [GCP Text Moderation](#gcp-text-moderation) + - [Private AI PII detection](#private-ai-pii-detection) - OpenAI Moderation API - *[COMING SOON]* 4. Other @@ -670,6 +671,46 @@ rails: For more details, check out the [GCP Text Moderation](./community/gcp-text-moderations.md) page. 
+### Private AI PII Detection
+
+NeMo Guardrails supports using the [Private AI API](https://docs.private-ai.com/?utm_medium=github&utm_campaign=nemo-guardrails) for PII detection in input, output, and retrieval flows.
+
+To activate PII detection, you need to specify the `server_endpoint` and the entities that you want to detect. You'll also need to set the `PAI_API_KEY` environment variable if you're using the Private AI cloud API.
+
+```yaml
+rails:
+  config:
+    privateai:
+      server_endpoint: http://your-privateai-api-endpoint/process/text  # Replace this with your Private AI process text endpoint
+      input:
+        entities:  # If no entity is specified here, all supported entities will be detected by default.
+          - NAME_FAMILY
+          - EMAIL_ADDRESS
+          ...
+      output:
+        entities:
+          - NAME_FAMILY
+          - EMAIL_ADDRESS
+          ...
+```
+
+#### Example usage
+
+```yaml
+rails:
+  input:
+    flows:
+      - detect pii on input
+  output:
+    flows:
+      - detect pii on output
+  retrieval:
+    flows:
+      - detect pii on retrieval
+```
+
+For more details, check out the [Private AI Integration](./community/privateai.md) page.
+
 ## Other

 ### Jailbreak Detection Heuristics
diff --git a/examples/configs/privateai/README.md b/examples/configs/privateai/README.md
new file mode 100644
index 000000000..0b3d423f5
--- /dev/null
+++ b/examples/configs/privateai/README.md
@@ -0,0 +1,11 @@
+# Private AI Configuration Example
+
+This example contains configuration files for using Private AI in your NeMo Guardrails project.
+
+For more details on the Private AI integration, see the [Private AI Integration User Guide](../../../docs/user_guides/community/privateai.md).
+
+## Structure
+
+The Private AI configuration example is organized as follows:
+
+1. [pii_detection](./pii_detection) - Configuration for using Private AI for PII detection.
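+For reference, the integration interprets the endpoint's response as a list of per-text results and flags PII whenever any entry reports `entities_present` (this mirrors `nemoguardrails/library/privateai/request.py`; the sample response below is illustrative, not a real API reply):
+
+```python
+# Sketch: how a Private AI /process/text response is interpreted by the
+# integration (mirrors nemoguardrails/library/privateai/request.py).
+# The sample response is illustrative, not a real API reply.
+sample_response = [
+    {"entities_present": True, "processed_text": "My email is [EMAIL_ADDRESS_1]."},
+]
+
+# PII is flagged if any returned item reports entities_present.
+has_pii = any(item["entities_present"] for item in sample_response)
+print(has_pii)  # → True
+```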
diff --git a/examples/configs/privateai/pii_detection/config.yml b/examples/configs/privateai/pii_detection/config.yml new file mode 100644 index 000000000..094aa5339 --- /dev/null +++ b/examples/configs/privateai/pii_detection/config.yml @@ -0,0 +1,26 @@ +models: + - type: main + engine: openai + model: gpt-3.5-turbo-instruct + +rails: + config: + privateai: + server_endpoint: https://api.private-ai.com/cloud/v3/process/text + input: + entities: + - NAME_FAMILY + - LOCATION_ADDRESS_STREET + - EMAIL_ADDRESS + output: + entities: # If no entity is specified here, all supported entities will be detected by default. + - NAME_FAMILY + - LOCATION_ADDRESS_STREET + - EMAIL_ADDRESS + input: + flows: + - detect pii on input + + output: + flows: + - detect pii on output diff --git a/examples/notebooks/privateai_pii_detection.ipynb b/examples/notebooks/privateai_pii_detection.ipynb new file mode 100644 index 000000000..40f2692fe --- /dev/null +++ b/examples/notebooks/privateai_pii_detection.ipynb @@ -0,0 +1,186 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Private AI PII detection example\n", + "\n", + "This notebook shows how to use Private AI for PII detection in NeMo Guardrails." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import nest_asyncio\n", + "\n", + "nest_asyncio.apply()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from nemoguardrails import LLMRails, RailsConfig" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create rails with Private AI PII detection\n", + "\n", + "For this step you'll need your OpenAI API key & Private AI API key.\n", + "\n", + "You can get your Private AI API key by signing up on the [Private AI Portal](https://portal.private-ai.com). For more details on Private AI integration, check out this [user guide](../../docs/user_guides/community/privateai.md).\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "os.environ[\"PAI_API_KEY\"] = \"YOUR PRIVATE AI API KEY\" # Visit https://portal.private-ai.com to get your API key\n", + "\n", + "YAML_CONFIG = \"\"\"\n", + "models:\n", + " - type: main\n", + " engine: openai\n", + " model: gpt-3.5-turbo-instruct\n", + "\n", + "rails:\n", + " config:\n", + " privateai:\n", + " server_endpoint: https://api.private-ai.com/cloud/v3/process/text\n", + " input:\n", + " entities:\n", + " - NAME_FAMILY\n", + " - LOCATION_ADDRESS_STREET\n", + " - EMAIL_ADDRESS\n", + " output:\n", + " entities:\n", + " - NAME_FAMILY\n", + " - LOCATION_ADDRESS_STREET\n", + " - EMAIL_ADDRESS\n", + " input:\n", + " flows:\n", + " - detect pii on input\n", + "\n", + " output:\n", + " flows:\n", + " - detect pii on output\n", + "\"\"\"\n", + "\n", + "\n", + "\n", + "config = RailsConfig.from_content(yaml_content=YAML_CONFIG)\n", + "rails = LLMRails(config)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Input rails" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response = rails.generate(messages=[{\"role\": \"user\", \"content\": \"Hello! I'm John. My email id is text@gmail.com. I live in California, USA.\"}])\n", + "\n", + "info = rails.explain()\n", + "\n", + "print(\"Response\")\n", + "print(\"----------------------------------------\")\n", + "print(response[\"content\"])\n", + "\n", + "\n", + "print(\"\\n\\nColang history\")\n", + "print(\"----------------------------------------\")\n", + "print(info.colang_history)\n", + "\n", + "print(\"\\n\\nLLM calls summary\")\n", + "print(\"----------------------------------------\")\n", + "info.print_llm_calls_summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Output rails" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response = rails.generate(messages=[{\"role\": \"user\", \"content\": \"give me a sample email id\"}])\n", + "\n", + "info = rails.explain()\n", + "\n", + "print(\"Response\")\n", + "print(\"----------------------------------------\\n\\n\")\n", + "print(response[\"content\"])\n", + "\n", + "\n", + "print(\"\\n\\nColang history\")\n", + "print(\"----------------------------------------\")\n", + "print(info.colang_history)\n", + "\n", + "print(\"\\n\\nLLM calls summary\")\n", + "print(\"----------------------------------------\")\n", + "info.print_llm_calls_summary()\n", + "\n", + "\n", + "print(\"\\n\\nCompletions where PII was detected!\")\n", + "print(\"----------------------------------------\")\n", + "print(info.llm_calls[0].completion)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "nemo", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": 
"python", + "pygments_lexer": "ipython3", + "version": "3.10.15" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/nemoguardrails/library/privateai/__init__.py b/nemoguardrails/library/privateai/__init__.py new file mode 100644 index 000000000..9ba9d4310 --- /dev/null +++ b/nemoguardrails/library/privateai/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemoguardrails/library/privateai/actions.py b/nemoguardrails/library/privateai/actions.py new file mode 100644 index 000000000..ade2e3abc --- /dev/null +++ b/nemoguardrails/library/privateai/actions.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""PII detection using Private AI."""
+
+import logging
+import os
+
+from nemoguardrails import RailsConfig
+from nemoguardrails.actions import action
+from nemoguardrails.library.privateai.request import private_ai_detection_request
+from nemoguardrails.rails.llm.config import PrivateAIDetection
+
+log = logging.getLogger(__name__)
+
+
+@action(is_system_action=True)
+async def detect_pii(source: str, text: str, config: RailsConfig):
+    """Checks whether the provided text contains any PII.
+
+    Args:
+        source: The source of the text, one of "input", "output" or "retrieval".
+        text: The text to check.
+        config: The rails configuration object.
+
+    Returns:
+        True if PII is detected, False otherwise.
+    """
+
+    pai_config: PrivateAIDetection = getattr(config.rails.config, "privateai")
+    pai_api_key = os.environ.get("PAI_API_KEY")
+    server_endpoint = pai_config.server_endpoint
+    enabled_entities = getattr(pai_config, source).entities
+
+    if "api.private-ai.com" in server_endpoint and not pai_api_key:
+        raise ValueError(
+            "PAI_API_KEY environment variable required for Private AI cloud API."
+        )
+
+    valid_sources = ["input", "output", "retrieval"]
+    if source not in valid_sources:
+        raise ValueError(
+            f"Private AI can only be defined in the following flows: {valid_sources}. "
+            f"The current flow, '{source}', is not allowed."
+ ) + + entity_detected = await private_ai_detection_request( + text, + enabled_entities, + server_endpoint, + pai_api_key, + ) + + return entity_detected diff --git a/nemoguardrails/library/privateai/flows.co b/nemoguardrails/library/privateai/flows.co new file mode 100644 index 000000000..04465deba --- /dev/null +++ b/nemoguardrails/library/privateai/flows.co @@ -0,0 +1,34 @@ +# INPUT RAILS + +@active +flow detect pii on input + """Check if the user input has PII.""" + $has_pii = await DetectPiiAction(source="input", text=$user_message) + + if $has_pii + bot inform answer unknown + abort + + +# INPUT RAILS + +@active +flow detect pii on output + """Check if the bot output has PII.""" + $has_pii = await DetectPiiAction(source="output", text=$bot_message) + + if $has_pii + bot inform answer unknown + abort + + +# RETRIVAL RAILS + +@active +flow detect pii on retrieval + """Check if the relevant chunks from the knowledge base have any PII.""" + $has_pii = await DetectPiiAction(source="retrieval", text=$relevant_chunks) + + if $has_pii + bot inform answer unknown + abort diff --git a/nemoguardrails/library/privateai/flows.v1.co b/nemoguardrails/library/privateai/flows.v1.co new file mode 100644 index 000000000..a7e4fca55 --- /dev/null +++ b/nemoguardrails/library/privateai/flows.v1.co @@ -0,0 +1,31 @@ +# INPUT RAILS + +define subflow detect pii on input + """Check if the user input has PII.""" + $has_pii = execute detect_pii(source="input", text=$user_message) + + if $has_pii + bot inform answer unknown + stop + + +# INPUT RAILS + +define subflow detect pii on output + """Check if the bot output has PII.""" + $has_pii = execute detect_pii(source="output", text=$bot_message) + + if $has_pii + bot inform answer unknown + stop + + +# RETRIVAL RAILS + +define subflow detect pii on retrieval + """Check if the relevant chunks from the knowledge base have any PII.""" + $has_pii = execute detect_pii(source="retrieval", text=$relevant_chunks) + + if $has_pii + bot inform 
answer unknown + stop diff --git a/nemoguardrails/library/privateai/request.py b/nemoguardrails/library/privateai/request.py new file mode 100644 index 000000000..9662c8856 --- /dev/null +++ b/nemoguardrails/library/privateai/request.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for handling Private AI detection requests.""" + +import json +import logging +from typing import Any, Dict, List, Optional + +import aiohttp + +log = logging.getLogger(__name__) + + +async def private_ai_detection_request( + text: str, + enabled_entities: List[str], + server_endpoint: str, + api_key: Optional[str] = None, +): + """ + Send a detection request to the Private AI API. + + Args: + text: The text to analyze. + enabled_entities: List of entity types to detect. + server_endpoint: The API endpoint URL. + api_key: The API key for the Private AI service. + + Returns: + True if PII is detected, False otherwise. 
+    """
+    if "api.private-ai.com" in server_endpoint and not api_key:
+        raise ValueError("'api_key' is required for Private AI cloud API.")
+
+    payload: Dict[str, Any] = {
+        "text": [text],
+        "link_batch": False,
+        "entity_detection": {"accuracy": "high_automatic", "return_entity": False},
+    }
+
+    headers: Dict[str, str] = {
+        "Content-Type": "application/json",
+    }
+
+    if api_key:
+        headers["x-api-key"] = api_key
+
+    if enabled_entities:
+        payload["entity_detection"]["entity_types"] = [
+            {"type": "ENABLE", "value": enabled_entities}
+        ]
+
+    async with aiohttp.ClientSession() as session:
+        async with session.post(server_endpoint, json=payload, headers=headers) as resp:
+            if resp.status != 200:
+                raise ValueError(
+                    f"Private AI call failed with status code {resp.status}.\n"
+                    f"Details: {await resp.text()}"
+                )
+
+            result = await resp.json()
+
+    return any(res["entities_present"] for res in result)
diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py
index c6294cedc..e58b0ec33 100644
--- a/nemoguardrails/rails/llm/config.py
+++ b/nemoguardrails/rails/llm/config.py
@@ -123,6 +123,36 @@ class SensitiveDataDetection(BaseModel):
     )


+class PrivateAIDetectionOptions(BaseModel):
+    """Configuration options for Private AI."""
+
+    entities: List[str] = Field(
+        default_factory=list,
+        description="The list of entities that should be detected.",
+    )
+
+
+class PrivateAIDetection(BaseModel):
+    """Configuration for Private AI."""
+
+    server_endpoint: str = Field(
+        default="http://localhost:8080/process/text",
+        description="The endpoint for the Private AI detection server.",
+    )
+    input: PrivateAIDetectionOptions = Field(
+        default_factory=PrivateAIDetectionOptions,
+        description="Configuration of the entities to be detected on the user input.",
+    )
+    output: PrivateAIDetectionOptions = Field(
+        default_factory=PrivateAIDetectionOptions,
+        description="Configuration of the entities to be detected on the bot output.",
+    )
+    retrieval:
PrivateAIDetectionOptions = Field( + default_factory=PrivateAIDetectionOptions, + description="Configuration of the entities to be detected on retrieved relevant chunks.", + ) + + class MessageTemplate(BaseModel): """Template for a message structure.""" @@ -395,6 +425,11 @@ class RailsConfigData(BaseModel): description="Configuration for jailbreak detection.", ) + privateai: Optional[PrivateAIDetection] = Field( + default_factory=PrivateAIDetection, + description="Configuration for Private AI.", + ) + class Rails(BaseModel): """Configuration of specific rails.""" diff --git a/tests/test_privateai.py b/tests/test_privateai.py new file mode 100644 index 000000000..4d127d6b3 --- /dev/null +++ b/tests/test_privateai.py @@ -0,0 +1,259 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from nemoguardrails import RailsConfig +from nemoguardrails.actions.actions import ActionResult, action +from tests.utils import TestChat + + +@action() +def retrieve_relevant_chunks(): + context_updates = {"relevant_chunks": "Mock retrieved context."} + + return ActionResult( + return_value=context_updates["relevant_chunks"], + context_updates=context_updates, + ) + + +def mock_detect_pii(return_value=True): + def mock_request(*args, **kwargs): + return return_value + + return mock_request + + +@pytest.mark.unit +def test_privateai_pii_detection_no_active_pii_detection(): + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + privateai: + server_endpoint: https://api.private-ai.com/cloud/v3/process/text + """, + colang_content=""" + define user express greeting + "hi" + + define flow + user express greeting + bot express greeting + + define bot inform answer unknown + "I can't answer that." + """, + ) + + chat = TestChat( + config, + llm_completions=[ + " express greeting", + ' "Hi! My name is John as well."', + ], + ) + + chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") + chat.app.register_action(mock_detect_pii(True), "detect_pii") + chat >> "Hi! I am Mr. John! And my email is test@gmail.com" + chat << "Hi! My name is John as well." + + +@pytest.mark.unit +def test_privateai_pii_detection_input(): + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + privateai: + server_endpoint: https://api.private-ai.com/cloud/v3/process/text + input: + entities: + - EMAIL_ADDRESS + - NAME + input: + flows: + - detect pii on input + """, + colang_content=""" + define user express greeting + "hi" + + define flow + user express greeting + bot express greeting + + define bot inform answer unknown + "I can't answer that." + """, + ) + + chat = TestChat( + config, + llm_completions=[ + " express greeting", + ' "Hi! 
My name is John as well."', + ], + ) + + chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") + chat.app.register_action(mock_detect_pii(True), "detect_pii") + chat >> "Hi! I am Mr. John! And my email is test@gmail.com" + chat << "I can't answer that." + + +@pytest.mark.unit +def test_privateai_pii_detection_output(): + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + privateai: + server_endpoint: https://api.private-ai.com/cloud/v3/process/text + output: + entities: + - EMAIL_ADDRESS + - NAME + output: + flows: + - detect pii on output + """, + colang_content=""" + define user express greeting + "hi" + + define flow + user express greeting + bot express greeting + + define bot inform answer unknown + "I can't answer that." + """, + ) + + chat = TestChat( + config, + llm_completions=[ + " express greeting", + ' "Hi! My name is John as well."', + ], + ) + + chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") + chat.app.register_action(mock_detect_pii(True), "detect_pii") + chat >> "Hi!" + chat << "I can't answer that." + + +@pytest.mark.skip(reason="This test needs refinement.") +@pytest.mark.unit +def test_privateai_pii_detection_retrieval_with_pii(): + # TODO: @pouyanpi and @letmerecall: Find an alternative approach to test this functionality. + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + privateai: + server_endpoint: https://api.private-ai.com/cloud/v3/process/text + retrieval: + entities: + - EMAIL_ADDRESS + - NAME + retrieval: + flows: + - detect pii on retrieval + """, + colang_content=""" + define user express greeting + "hi" + + define flow + user express greeting + bot express greeting + + define bot inform answer unknown + "I can't answer that." + """, + ) + + chat = TestChat( + config, + llm_completions=[ + " express greeting", + ' "Hi! 
My name is John as well."',
+        ],
+    )
+
+    chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks")
+    chat.app.register_action(mock_detect_pii(True), "detect_pii")
+
+    # When the relevant chunks contain PII, the `bot inform answer unknown` intent is invoked, which in
+    # turn invokes the retrieve_relevant_chunks action again.
+    # With a mocked retrieve_relevant_chunks always returning something and a mocked detect_pii always
+    # returning True, the process goes into an infinite loop and raises an Exception: Too many events.
+    with pytest.raises(Exception, match="Too many events."):
+        chat >> "Hi!"
+        chat << "I can't answer that."
+
+
+@pytest.mark.unit
+def test_privateai_pii_detection_retrieval_with_no_pii():
+    config = RailsConfig.from_content(
+        yaml_content="""
+            models: []
+            rails:
+              config:
+                privateai:
+                  server_endpoint: https://api.private-ai.com/cloud/v3/process/text
+                  retrieval:
+                    entities:
+                      - EMAIL_ADDRESS
+                      - NAME
+              retrieval:
+                flows:
+                  - detect pii on retrieval
+        """,
+        colang_content="""
+            define user express greeting
+              "hi"
+
+            define flow
+              user express greeting
+              bot express greeting
+
+            define bot inform answer unknown
+              "I can't answer that."
+        """,
+    )
+
+    chat = TestChat(
+        config,
+        llm_completions=[
+            " express greeting",
+            ' "Hi! My name is John as well."',
+        ],
+    )
+
+    chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks")
+    chat.app.register_action(mock_detect_pii(False), "detect_pii")
+
+    chat >> "Hi!"
+    chat << "Hi! My name is John as well."