diff --git a/examples/configs/content_safety_api_keys/README.md b/examples/configs/content_safety_api_keys/README.md
new file mode 100644
index 000000000..45f7a7352
--- /dev/null
+++ b/examples/configs/content_safety_api_keys/README.md
@@ -0,0 +1,17 @@
+# NemoGuard ContentSafety Usage Example
+
+This example showcases the use of NVIDIA's [NemoGuard ContentSafety model](./../../../docs/user-guides/advanced/nemoguard-contentsafety-deployment.md) for topical and dialogue moderation.
+
+The config folder has the following structure:
+
+- `config.yml` - The config file holding all the configuration options for the model.
+- `prompts.yml` - The config file holding the topical rules used for topical and dialogue moderation by the current guardrail configuration.
+
+Please see the docs for more details about the [recommended ContentSafety deployment](./../../../docs/user-guides/advanced/nemoguard-contentsafety-deployment.md) methods, using either locally downloaded NIMs or NVIDIA AI Enterprise (NVAIE).
+
+Before running this example, please set the environment variables `NG_OPENAI_API_KEY` and `NG_NVIDIA_API_KEY` to your OpenAI API key and your NVIDIA build.nvidia.com API key, as shown below:
+
+```shell
+export NG_OPENAI_API_KEY=""
+export NG_NVIDIA_API_KEY=""
+```
diff --git a/examples/configs/content_safety_api_keys/config.yml b/examples/configs/content_safety_api_keys/config.yml
new file mode 100644
index 000000000..15f9f6b37
--- /dev/null
+++ b/examples/configs/content_safety_api_keys/config.yml
@@ -0,0 +1,18 @@
+models:
+  - type: main
+    engine: openai
+    model: gpt-3.5-turbo-instruct
+    api_key_env_var: NG_OPENAI_API_KEY
+
+  - type: content_safety
+    engine: nim
+    model: nvidia/llama-3.1-nemoguard-8b-content-safety
+    api_key_env_var: NG_NVIDIA_API_KEY
+
+rails:
+  input:
+    flows:
+      - content safety check input $model=content_safety
+  output:
+    flows:
+      - content safety check output $model=content_safety
diff --git a/examples/configs/content_safety_api_keys/prompts.yml b/examples/configs/content_safety_api_keys/prompts.yml
new file mode 100644
index 000000000..1321a6461
--- /dev/null
+++ b/examples/configs/content_safety_api_keys/prompts.yml
@@ -0,0 +1,257 @@
+# These are the default prompts released by Meta, except for policy O7, which was added to address direct insults.
+prompts:
+
+  - task: content_safety_check_input $model=content_safety
+    content: |
+      Task: Check if there is unsafe content in the user or agent messages in the below conversation in accordance with our safety policy with the below categories. Provide a 'safe' or 'unsafe' rating for both the user message and the agent response. If there is unsafe content, please also provide a list of violated categories according to our safety policy below.
+
+      <BEGIN UNSAFE CONTENT CATEGORIES>
+      S1: Violence.
+      S2: Sexual.
+      S3: Criminal Planning/Confessions.
+      S4: Guns and Illegal Weapons.
+      S5: Controlled/Regulated Substances.
+      S6: Suicide and Self Harm.
+      S7: Sexual (minor).
+      S8: Hate/Identity Hate.
+      S9: PII/Privacy.
+      S10: Harassment.
+      S11: Threat.
+      S12: Profanity.
+      S13: Needs Caution.
+      S14: Other.
+      S15: Manipulation.
+      S16: Fraud/Deception.
+      S17: Malware.
+      S18: High Risk Gov Decision Making.
+      S19: Political/Misinformation/Conspiracy.
+      S20: Copyright/Trademark/Plagiarism.
+      S21: Unauthorized Advice.
+      S22: Illegal Activity.
+      S23: Immoral/Unethical.
+      <END UNSAFE CONTENT CATEGORIES>
+
+      <BEGIN CONVERSATION>
+
+      user: {{ user_input }}
+
+      <END CONVERSATION>
+
+      Provide your safety assessment for the above conversation in the below output JSON format:
+      {"User Safety": safe or unsafe rating for the user message, "Response Safety": safe or unsafe rating for the agent response. Omit if no agent response present. "Safety Categories": a comma-separated list of applicable safety categories from the provided taxonomy. Omit if all safe.}
+
+      Do not include anything other than the output JSON in your response.
+      Output JSON:
+    output_parser: nemoguard_parse_prompt_safety
+    max_tokens: 50
+
+  - task: content_safety_check_output $model=content_safety
+    content: |
+      Task: Check if there is unsafe content in the user or agent messages in the below conversation in accordance with our safety policy with the below categories. Provide a 'safe' or 'unsafe' rating for both the user message and the agent response. If there is unsafe content, please also provide a list of violated categories according to our safety policy below.
+
+      <BEGIN UNSAFE CONTENT CATEGORIES>
+      S1: Violence.
+      S2: Sexual.
+      S3: Criminal Planning/Confessions.
+      S4: Guns and Illegal Weapons.
+      S5: Controlled/Regulated Substances.
+      S6: Suicide and Self Harm.
+      S7: Sexual (minor).
+      S8: Hate/Identity Hate.
+      S9: PII/Privacy.
+      S10: Harassment.
+      S11: Threat.
+      S12: Profanity.
+      S13: Needs Caution.
+      S14: Other.
+      S15: Manipulation.
+      S16: Fraud/Deception.
+      S17: Malware.
+      S18: High Risk Gov Decision Making.
+      S19: Political/Misinformation/Conspiracy.
+      S20: Copyright/Trademark/Plagiarism.
+      S21: Unauthorized Advice.
+      S22: Illegal Activity.
+      S23: Immoral/Unethical.
+      <END UNSAFE CONTENT CATEGORIES>
+
+      <BEGIN CONVERSATION>
+
+      user: {{ user_input }}
+
+      response: agent: {{ bot_response }}
+
+      <END CONVERSATION>
+
+      Provide your safety assessment for the above conversation in the below output JSON format:
+      {"User Safety": safe or unsafe rating for the user message, "Response Safety": safe or unsafe rating for the agent response. Omit if no agent response present. "Safety Categories": a comma-separated list of applicable safety categories from the provided taxonomy. Omit if all safe.}
+
+      Do not include anything other than the output JSON in your response.
+      Output JSON:
+    output_parser: nemoguard_parse_response_safety
+    max_tokens: 50
+
+  - task: content_safety_check_input $model=llama_guard
+    content: |
+      [INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
+
+      <BEGIN UNSAFE CONTENT CATEGORIES>
+      O1: Violence and Hate.
+      Should not
+      - Help people plan or engage in violence.
+      - Encourage people to engage in violence.
+      - Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
+      - Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
+      Can
+      - Provide information on violence and discrimination.
+      - Discuss the topics of hate and violence, and can discuss historical events involving violence.
+      O2: Sexual Content.
+      Should not
+      - Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.
+      Can
+      - Discuss topics related to sexuality and sexual education.
+      - Provide general information regarding sexual activities.
+      - Provide guidance on topics related to sexual health.
+      O3: Criminal Planning.
+      Should not
+      - Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.
+      - Encourage people to engage in criminal activities or condone criminal behavior.
+      Can
+      - Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.
+      - Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.
+      O4: Guns and Illegal Weapons.
+      Should not
+      - Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.
+      - Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.
+      Can
+      - Discuss firearms and the arguments for and against firearm ownership.
+      O5: Regulated or Controlled Substances.
+      Should not
+      - Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.
+      - Assist or encourage people to create such substances when it is illegal to do so.
+      Can
+      - Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).
+      - Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation.
+      O6: Self-Harm.
+      Should not
+      - Encourage people to harm themselves.
+      - Romanticize or condone self-harm or suicide.
+      - Provide information regarding the methods of suicide or self-harm.
+      - Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.
+      Can
+      - Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).
+      Should
+      - Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves.
+      O7: Offensive Language and Insults.
+      Should not
+      - Address direct insults
+      - Use or encourage the use of swear words, bad language, or vulgar expressions.
+      - Insult, offend, or demean individuals or groups based on any personal characteristics or choices.
+      - Encourage or support the use of offensive or derogatory language.
+      Can
+      - Discuss the impact and societal perceptions of offensive language and insults.
+      - Provide information on communication strategies that avoid the use of offensive language and promote respectful dialogue.
+      - Offer guidance on how to address or respond to offensive language or insults in a constructive manner.
+      <END UNSAFE CONTENT CATEGORIES>
+
+      <BEGIN CONVERSATION>
+
+      User: {{ user_input }}
+
+      <END CONVERSATION>
+
+      Provide your safety assessment for User in the above conversation:
+      - First line must read 'safe' or 'unsafe'.
+      - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]
+    output_parser: is_content_safe
+
+  - task: content_safety_check_output $model=llama_guard_2
+    content: |
+      <|begin_of_text|>[INST] Task: Check if there is unsafe content in Agent messages in conversations according our safety policy with the below categories.
+
+      <BEGIN UNSAFE CONTENT CATEGORIES>
+      S1: Violent Crimes.
+      S2: Non-Violent Crimes.
+      S3: Sex Crimes.
+      S4: Child Exploitation.
+      S5: Specialized Advice.
+      S6: Privacy.
+      S7: Intellectual Property.
+      S8: Indiscriminate Weapons.
+      S9: Hate.
+      S10: Self-Harm.
+      S11: Sexual Content.
+      <END UNSAFE CONTENT CATEGORIES>
+
+      <BEGIN CONVERSATION>
+
+      User: {{ user_input }}
+
+      Agent: {{ bot_response }}
+
+      <END CONVERSATION>
+
+      Provide your safety assessment for ONLY THE LAST Agent in the above conversation:
+      - First line must read 'safe' or 'unsafe'.
+      - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]
+    output_parser: is_content_safe
+    max_tokens: 3
+
+  - task: content_safety_check_input $model=shieldgemma
+    content: |
+      You are a policy expert trying to help determine whether a user prompt is in
+      violation of the defined safety policies.
+
+      <start_of_turn>
+      Human Question: {{ user_input }}
+      <end_of_turn>
+
+      Our safety principle is defined in the below:
+
+      * "No Hate Speech": The prompt shall not contain or seek generation of content
+      that expresses, incites, or promotes hate based on race, gender, ethnicity,
+      religion, nationality, sexual orientation, disability status, or caste.
+      * "No Dangerous Content": The prompt shall not contain or seek generation of content that harming oneself
+      and/or others (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide).
+      * "No Sexually Explicit Information": The prompt shall not contain or seek generation of content
+      that contains references to sexual acts or other lewd content (e.g., sexually graphic descriptions,
+      content aimed at causing arousal). Medical or scientific terms related to human anatomy or sex education are permitted.
+      * "No Harassment": The prompt shall not contain or seek generation of content that is malicious,
+      intimidating, bullying, or abusive content targeting another individual (e.g., physical threats,
+      denial of tragic events, disparaging victims of violence).
+      ===
+
+      Does the human prompt/question violate the above principle? Your answer must start with
+      'Yes' or 'No'. And then walk through step by step to be sure we answer
+      correctly.
+    output_parser: is_content_safe
+    max_tokens: 3
+
+  - task: content_safety_check_output $model=shieldgemma
+    content: |
+      You are a policy expert trying to help determine whether a chatbot response to the user is in
+      violation of the defined safety policies:
+
+      <start_of_turn>
+      Human Question: {{ user_input }}
+      Chatbot Response: {{ bot_response }}
+      <end_of_turn>
+
+      Our safety principle is defined in the below:
+
+      * "No Dangerous Content": The chatbot shall not generate content that harming oneself and/or others
+      (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide).
+      * "No Harassment": The chatbot shall not generate content that is malicious, intimidating, bullying, or
+      abusive content targeting another individual (e.g., physical threats, denial of tragic events, disparaging victims of violence).
+      * "No Hate Speech": The chatbot shall not generate content that targeting identity and/or protected
+      attributes (e.g., racial slurs, promotion of discrimination, calls to violence against protected groups) or dehumanizing/belittling/vilifying on the
+      bases of race, ethnicity, religion, disability, age, nationality, veteran status, sexual orientation, gender, gender identity, caste, or any other protected status.
+      We do allow content that includes a reference to a protected individual or group in a positive manner.
+      * "No Sexually Explicit Information": The chatbot shall not generate content that contains references to sexual acts or
+      other lewd content (e.g., sexually graphic descriptions, content aimed at causing arousal).
+      Medical or scientific terms related to human anatomy or sex education are permitted.
+
+      ===
+
+      Does the chatbot response violate the above principle? Your answer must start with
+      'Yes' or 'No'. And then walk through step by step to be sure we answer
+      correctly.
+    output_parser: is_content_safe
+    max_tokens: 3
diff --git a/nemoguardrails/llm/models/initializer.py b/nemoguardrails/llm/models/initializer.py
index 0ab2262d5..81a427b96 100644
--- a/nemoguardrails/llm/models/initializer.py
+++ b/nemoguardrails/llm/models/initializer.py
@@ -38,6 +38,7 @@ def init_llm_model(
     Args:
         model_name: Name of the model to initialize
        provider_name: Name of the provider to use
+        mode: The model interaction mode, either "chat" or "text"
         kwargs: Additional arguments to pass to the model initialization
 
     Returns:
@@ -48,7 +49,10 @@
     """
     # currently we only support LangChain models
     return init_langchain_model(
-        model_name=model_name, provider_name=provider_name, mode=mode, kwargs=kwargs
+        model_name=model_name,
+        provider_name=provider_name,
+        mode=mode,
+        kwargs=kwargs,
     )
 
 
diff --git a/nemoguardrails/llm/models/langchain_initializer.py b/nemoguardrails/llm/models/langchain_initializer.py
index ff91a5860..d78030b92 100644
--- a/nemoguardrails/llm/models/langchain_initializer.py
+++ b/nemoguardrails/llm/models/langchain_initializer.py
@@ -103,7 +103,9 @@ def try_initialization_method(
         f"Trying initializer: {initializer.init_method.__name__} for model: {model_name} and provider: {provider_name}"
     )
     result = initializer.execute(
-        model_name=model_name, provider_name=provider_name, kwargs=kwargs
+        model_name=model_name,
+        provider_name=provider_name,
+        kwargs=kwargs,
     )
     log.debug(f"Initializer {initializer.init_method.__name__} returned: {result}")
     if result is not None:
@@ -213,7 +215,7 @@
 
     # just to document the expected behavior
     # we don't support pre-0.2.7 versions of langchain-core it is in
-    # line wiht our pyproject.toml
+    # line with our pyproject.toml
 
     package_version = version("langchain-core")
     if _parse_version(package_version) < (0, 2, 7):
@@ -225,6 +227,7 @@
         return init_chat_model(
             model=model_name,
             model_provider=provider_name,
+            **kwargs,
         )
 
     except ValueError:
         raise
@@ -250,7 +253,6 @@
     if provider_cls is None:
         raise ValueError()
     kwargs = _update_model_kwargs(provider_cls, model_name, kwargs)
-
     return provider_cls(**kwargs)
 
 
diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py
index 532fb3b96..d554f80a2 100644
--- a/nemoguardrails/rails/llm/config.py
+++ b/nemoguardrails/rails/llm/config.py
@@ -28,6 +28,7 @@
     ValidationError,
     model_validator,
     root_validator,
+    validator,
 )
 from pydantic.fields import Field
 
@@ -101,7 +102,10 @@ class Model(BaseModel):
         default=None,
         description="The name of the model. If not specified, it should be specified through the parameters attribute.",
     )
-
+    api_key_env_var: Optional[str] = Field(
+        default=None,
+        description='Optional environment variable with model\'s API Key. Do not include "$".',
+    )
     reasoning_config: Optional[ReasoningModelConfig] = Field(
         default_factory=ReasoningModelConfig,
         description="Configuration parameters for reasoning LLMs.",
@@ -1352,6 +1356,17 @@ def fill_in_default_values_for_v2_x(cls, values):
 
         return values
 
+    @validator("models")
+    def validate_models_api_key_env_var(cls, models):
+        """The model API key environment variable, if configured, must be set to make LLM calls."""
+        api_keys = [m.api_key_env_var for m in models]
+        for api_key in api_keys:
+            if api_key and not os.environ.get(api_key):
+                raise ValueError(
+                    f"Model API Key environment variable '{api_key}' not set."
+                )
+        return models
+
     raw_llm_call_action: Optional[str] = Field(
         default="raw llm call",
         description="The name of the action that would execute the original raw LLM call.",
diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py
index 12fcdd352..685e5a532 100644
--- a/nemoguardrails/rails/llm/llmrails.py
+++ b/nemoguardrails/rails/llm/llmrails.py
@@ -377,6 +377,13 @@ def _init_llms(self):
             kwargs = llm_config.parameters or {}
             mode = llm_config.mode
 
+            # If the optional API key environment variable is set, store
+            # its value in the `kwargs` for the current model.
+            if llm_config.api_key_env_var:
+                api_key = os.environ.get(llm_config.api_key_env_var)
+                if api_key:
+                    kwargs["api_key"] = api_key
+
             llm_model = init_llm_model(
                 model_name=model_name,
                 provider_name=provider_name,
diff --git a/tests/llm_providers/test_langchain_initialization_methods.py b/tests/llm_providers/test_langchain_initialization_methods.py
index 76b2879cc..3869fcf70 100644
--- a/tests/llm_providers/test_langchain_initialization_methods.py
+++ b/tests/llm_providers/test_langchain_initialization_methods.py
@@ -51,6 +51,25 @@ def test_init_chat_completion_model_success(self):
             model_provider="openai",
         )
 
+    def test_init_chat_completion_model_with_api_key_success(self):
+        with patch(
+            "nemoguardrails.llm.models.langchain_initializer.init_chat_model"
+        ) as mock_init:
+            mock_init.return_value = "chat_model"
+            with patch(
+                "nemoguardrails.llm.models.langchain_initializer.version"
+            ) as mock_version:
+                mock_version.return_value = "0.2.7"
+                # Pass in an API key for use in LLM calls
+                kwargs = {"api_key": "sk-svcacct-abcdef12345"}
+                result = _init_chat_completion_model("gpt-3.5-turbo", "openai", kwargs)
+                assert result == "chat_model"
+                mock_init.assert_called_once_with(
+                    model="gpt-3.5-turbo",
+                    model_provider="openai",
+                    api_key="sk-svcacct-abcdef12345",
+                )
+
     def test_init_chat_completion_model_old_version(self):
         with patch(
             "nemoguardrails.llm.models.langchain_initializer.version"
@@ -91,6 +110,25 @@ def test_init_community_chat_models_success(self):
         mock_get_provider.assert_called_once_with("provider")
         mock_provider_cls.assert_called_once_with(model="community-model")
 
+    def test_init_community_chat_models_with_api_key_success(self):
+        with patch(
+            "nemoguardrails.llm.models.langchain_initializer._get_chat_completion_provider"
+        ) as mock_get_provider:
+            mock_provider_cls = MagicMock()
+            mock_provider_cls.model_fields = {"model": None}
+            mock_provider_cls.return_value = "community_model"
+            mock_get_provider.return_value = mock_provider_cls
+            # Pass in an API key for use in client creation
+            api_key = "abcdef12345"
+            result = _init_community_chat_models(
+                "community-model", "provider", {"api_key": api_key}
+            )
+            assert result == "community_model"
+            mock_get_provider.assert_called_once_with("provider")
+            mock_provider_cls.assert_called_once_with(
+                model="community-model", api_key=api_key
+            )
+
     def test_init_community_chat_models_no_provider(self):
         with patch(
             "nemoguardrails.llm.models.langchain_initializer._get_chat_completion_provider"
@@ -116,6 +154,25 @@ def test_init_text_completion_model_success(self):
         mock_get_provider.assert_called_once_with("provider")
         mock_provider_cls.assert_called_once_with(model="text-model")
 
+    def test_init_text_completion_model_with_api_key_success(self):
+        with patch(
+            "nemoguardrails.llm.models.langchain_initializer._get_text_completion_provider"
+        ) as mock_get_provider:
+            mock_provider_cls = MagicMock()
+            mock_provider_cls.model_fields = {"model": None}
+            mock_provider_cls.return_value = "text_model"
+            mock_get_provider.return_value = mock_provider_cls
+            # Pass in an API key for use in client creation
+            api_key = "abcdef12345"
+            result = _init_text_completion_model(
+                "text-model", "provider", {"api_key": api_key}
+            )
+            assert result == "text_model"
+            mock_get_provider.assert_called_once_with("provider")
+            mock_provider_cls.assert_called_once_with(
+                model="text-model", api_key=api_key
+            )
+
     def test_init_text_completion_model_no_provider(self):
         with patch(
             "nemoguardrails.llm.models.langchain_initializer._get_text_completion_provider"
@@ -135,6 +192,15 @@ def test_update_model_kwargs_with_model_field(self):
         updated_kwargs = _update_model_kwargs(mock_provider_cls, "test-model", kwargs)
         assert updated_kwargs == {"model": "test-model"}
 
+    def test_update_model_kwargs_with_model_field_and_api_key(self):
+        mock_provider_cls = MagicMock()
+        mock_provider_cls.model_fields = {"model": {}}
+        api_key = "abcdef12345"
+        updated_kwargs = _update_model_kwargs(
+            mock_provider_cls, "test-model", {"api_key": api_key}
+        )
+        assert updated_kwargs == {"model": "test-model", "api_key": api_key}
+
     def test_update_model_kwargs_with_model_name_field(self):
         """Test that _update_model_kwargs updates kwargs with model name when provider has model_name field."""
         mock_provider_cls = MagicMock()
@@ -143,6 +209,16 @@ def test_update_model_kwargs_with_model_name_field(self):
         updated_kwargs = _update_model_kwargs(mock_provider_cls, "test-model", kwargs)
         assert updated_kwargs == {"model_name": "test-model"}
 
+    def test_update_model_kwargs_with_model_name_and_api_key_field(self):
+        """Test that _update_model_kwargs preserves api_key when provider has a model_name field."""
+        mock_provider_cls = MagicMock()
+        mock_provider_cls.model_fields = {"model_name": {}}
+        api_key = "abcdef12345"
+        updated_kwargs = _update_model_kwargs(
+            mock_provider_cls, "test-model", {"api_key": api_key}
+        )
+        assert updated_kwargs == {"model_name": "test-model", "api_key": api_key}
+
     def test_update_model_kwargs_with_both_fields(self):
         """Test _update_model_kwargs updates kwargs with model name when provider has
         both model and model_name fields."""
@@ -152,6 +228,21 @@
         updated_kwargs = _update_model_kwargs(mock_provider_cls, "test-model", kwargs)
         assert updated_kwargs == {"model": "test-model", "model_name": "test-model"}
 
+    def test_update_model_kwargs_with_both_fields_and_api_key(self):
+        """Test that _update_model_kwargs preserves api_key when provider has
+        both model and model_name fields."""
+
+        mock_provider_cls = MagicMock()
+        mock_provider_cls.model_fields = {"model": {}, "model_name": {}}
+        api_key = "abcdef12345"
+        updated_kwargs = _update_model_kwargs(
+            mock_provider_cls, "test-model", {"api_key": api_key}
+        )
+        assert updated_kwargs == {
+            "model": "test-model",
+            "model_name": "test-model",
+            "api_key": api_key,
+        }
+
     def test_update_model_kwargs_with_existing_kwargs(self):
         """Test _update_model_kwargs preserves existing kwargs."""
 
@@ -160,3 +251,17 @@
         kwargs = {"temperature": 0.7}
         updated_kwargs = _update_model_kwargs(mock_provider_cls, "test-model", kwargs)
         assert updated_kwargs == {"model": "test-model", "temperature": 0.7}
+
+    def test_update_model_kwargs_and_api_key_with_existing_kwargs(self):
+        """Test that _update_model_kwargs preserves existing kwargs, including api_key."""
+
+        mock_provider_cls = MagicMock()
+        mock_provider_cls.model_fields = {"model": {}}
+        api_key = "abcdef12345"
+        kwargs = {"temperature": 0.7, "api_key": api_key}
+        updated_kwargs = _update_model_kwargs(mock_provider_cls, "test-model", kwargs)
+        assert updated_kwargs == {
+            "model": "test-model",
+            "temperature": 0.7,
+            "api_key": api_key,
+        }
diff --git a/tests/llm_providers/test_langchain_integration.py b/tests/llm_providers/test_langchain_integration.py
index f27ca70bb..e0f8d5677 100644
--- a/tests/llm_providers/test_langchain_integration.py
+++ b/tests/llm_providers/test_langchain_integration.py
@@ -289,3 +289,68 @@ def test_init_with_kwargs(self):
         response = model.invoke([HumanMessage(content="Hello, world!")])
         assert response is not None
         assert hasattr(response, "content")
+
+    @pytest.mark.skipif(
+        not _is_langchain_installed() or not _has_openai(),
+        reason="LangChain or the OpenAI integration is not installed",
+    )
+    def test_init_with_api_key_env_var_chat_completion_model(self):
+        """Test initializing a chat model with api_key_env_var."""
+        if not os.environ.get("OPENAI_API_KEY"):
+            pytest.skip("OpenAI API key not set")
+
+        original_api_key = os.environ["OPENAI_API_KEY"]
+        custom_env_var = "NG_OPENAI_API_KEY"
+        os.environ[custom_env_var] = original_api_key
+        del os.environ["OPENAI_API_KEY"]
+
+        try:
+            model = init_langchain_model(
+                "gpt-4o",
+                "openai",
+                "chat",
+                {"api_key": os.environ.get(custom_env_var)},
+            )
+            assert model is not None
+            assert hasattr(model, "invoke")
+            assert isinstance(model, BaseChatModel)
+
+            from langchain_core.messages import HumanMessage
+
+            response = model.invoke([HumanMessage(content="Hello, world!")])
+            assert response is not None
+            assert hasattr(response, "content")
+        finally:
+            os.environ["OPENAI_API_KEY"] = original_api_key
+            del os.environ[custom_env_var]
+
+    @pytest.mark.skipif(
+        not _is_langchain_installed() or not _has_openai(),
+        reason="LangChain or the OpenAI integration is not installed",
+    )
+    def test_init_with_api_key_env_var_text_completion_model(self):
+        """Test initializing a text model with api_key_env_var."""
+        if not os.environ.get("OPENAI_API_KEY"):
+            pytest.skip("OpenAI API key not set")
+
+        original_api_key = os.environ["OPENAI_API_KEY"]
+        custom_env_var = "NG_OPENAI_API_KEY"
+        os.environ[custom_env_var] = original_api_key
+        del os.environ["OPENAI_API_KEY"]
+
+        try:
+            model = init_langchain_model(
+                "gpt-3.5-turbo-instruct",
+                "openai",
+                "text",
+                {"api_key": os.environ.get(custom_env_var)},
+            )
+            assert model is not None
+            assert hasattr(model, "invoke")
+            assert isinstance(model, BaseLLM)
+
+            response = model.invoke("Hello, world!")
+            assert response is not None
+        finally:
+            os.environ["OPENAI_API_KEY"] = original_api_key
+            del os.environ[custom_env_var]
diff --git a/tests/test_rails_config.py b/tests/test_rails_config.py
index a00805d83..8ccbb9497 100644
--- a/tests/test_rails_config.py
+++ b/tests/test_rails_config.py
@@ -15,11 +15,16 @@
 
 import logging
 import os
+from unittest import mock
 
 import pytest
 
 from nemoguardrails import RailsConfig
 from nemoguardrails.llm.prompts import TaskPrompt
+from nemoguardrails.rails.llm.config import Model, RailsConfig
+
+TEST_API_KEY_NAME = "DUMMY_OPENAI_API_KEY"
+TEST_API_KEY_VALUE = "sk-svcacct-abcdefGHIJKlmnoPQRSTuvXYZ1234567890"
 
 
 @pytest.fixture(
@@ -135,3 +140,154 @@ def test_rails_config_parse_obj():
     assert config.sample_conversation == "Test conversation"
     assert len(config.flows) == 1
     assert config.flows[0]["id"] == "test_flow"
+
+
+def test_model_api_key_optional():
+    """Check that the Model can still be created when `api_key_env_var` is not set."""
+    config = RailsConfig(
+        models=[
+            Model(
+                type="main",
+                engine="openai",
+                model="gpt-3.5-turbo-instruct",
+                api_key_env_var=None,
+            )
+        ]
+    )
+    assert config.models[0].api_key_env_var is None
+
+
+def test_model_api_key_var_not_set():
+    """Check that referencing an environment variable that is not set raises an error."""
+    with pytest.raises(
+        ValueError,
+        match=f"Model API Key environment variable '{TEST_API_KEY_NAME}' not set.",
+    ):
+        _ = RailsConfig(
+            models=[
+                Model(
+                    type="main",
+                    engine="openai",
+                    model="gpt-3.5-turbo-instruct",
+                    api_key_env_var=TEST_API_KEY_NAME,
+                )
+            ]
+        )
+
+
+@mock.patch.dict(os.environ, {TEST_API_KEY_NAME: ""})
+def test_model_api_key_var_empty_string():
+    """Check that referencing an environment variable set to an empty string raises an error."""
+    with pytest.raises(
+        ValueError,
+        match=f"Model API Key environment variable '{TEST_API_KEY_NAME}' not set.",
+    ):
+        _ = RailsConfig(
+            models=[
+                Model(
+                    type="main",
+                    engine="openai",
+                    model="gpt-3.5-turbo-instruct",
+                    api_key_env_var=TEST_API_KEY_NAME,
+                )
+            ]
+        )
+
+
+@mock.patch.dict(os.environ, {TEST_API_KEY_NAME: TEST_API_KEY_VALUE})
+def test_model_api_key_value_valid_string():
+    """Check that the Model can be created when `api_key_env_var` references a set variable."""
+
+    config = RailsConfig(
+        models=[
+            Model(
+                type="main",
+                engine="openai",
+                model="gpt-3.5-turbo-instruct",
+                api_key_env_var=TEST_API_KEY_NAME,
+            )
+        ]
+    )
+    assert config.models[0].api_key_env_var == TEST_API_KEY_NAME
+
+
+@mock.patch.dict(
+    os.environ,
+    {
+        TEST_API_KEY_NAME: TEST_API_KEY_VALUE,
+        "DUMMY_NVIDIA_API_KEY": "nvapi-abcdef12345",
+    },
+)
+def test_model_api_key_value_multiple_strings():
+    """Check that multiple models with valid `api_key_env_var` values can be created."""
+
+    config = RailsConfig(
+        models=[
+            Model(
+                type="main",
+                engine="openai",
+                model="gpt-3.5-turbo-instruct",
+                api_key_env_var=TEST_API_KEY_NAME,
+            ),
+            Model(
+                type="content_safety",
+                engine="nim",
+                model="nvidia/llama-3.1-nemoguard-8b-content-safety",
+                api_key_env_var="DUMMY_NVIDIA_API_KEY",
+            ),
+        ]
+    )
+    assert config.models[0].api_key_env_var == TEST_API_KEY_NAME
+    assert config.models[1].api_key_env_var == "DUMMY_NVIDIA_API_KEY"
+
+
+@mock.patch.dict(os.environ, {TEST_API_KEY_NAME: TEST_API_KEY_VALUE})
+def test_model_api_key_value_multiple_strings_one_missing():
+    """Check that an error is raised when one of several models references an unset env var."""
+    with pytest.raises(
+        ValueError,
+        match="Model API Key environment variable 'DUMMY_NVIDIA_API_KEY' not set.",
+    ):
+        _ = RailsConfig(
+            models=[
+                Model(
+                    type="main",
+                    engine="openai",
+                    model="gpt-3.5-turbo-instruct",
+                    api_key_env_var=TEST_API_KEY_NAME,
+                ),
+                Model(
+                    type="content_safety",
+                    engine="nim",
+                    model="nvidia/llama-3.1-nemoguard-8b-content-safety",
+                    api_key_env_var="DUMMY_NVIDIA_API_KEY",
+                ),
+            ]
+        )
+
+
+@mock.patch.dict(
+    os.environ, {TEST_API_KEY_NAME: TEST_API_KEY_VALUE, "DUMMY_NVIDIA_API_KEY": ""}
+)
+def test_model_api_key_value_multiple_strings_one_empty():
+    """Check that an error is raised when one of several models references an empty env var."""
+    with pytest.raises(
+        ValueError,
+        match="Model API Key environment variable 'DUMMY_NVIDIA_API_KEY' not set.",
+    ):
+        _ = RailsConfig(
+            models=[
+                Model(
+                    type="main",
+                    engine="openai",
+                    model="gpt-3.5-turbo-instruct",
+                    api_key_env_var=TEST_API_KEY_NAME,
+                ),
+                Model(
+                    type="content_safety",
+                    engine="nim",
+                    model="nvidia/llama-3.1-nemoguard-8b-content-safety",
+                    api_key_env_var="DUMMY_NVIDIA_API_KEY",
+                ),
+            ]
+        )
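Taken together, this patch lets a guardrails config pin each model to its own API key environment variable: the new `validate_models_api_key_env_var` validator rejects a config whose `api_key_env_var` is unset or empty, and `_init_llms` copies the resolved value into the model's `kwargs` as `api_key`. The core resolution logic can be sketched standalone; the helper name `resolve_api_key` below is illustrative only, not part of the library:

```python
import os


def resolve_api_key(api_key_env_var):
    """Sketch of the validator + _init_llms behavior for one model.

    An unset or empty environment variable is a configuration error;
    a set variable yields the key to pass to the provider as `api_key`.
    """
    if api_key_env_var is None:
        # The field is optional; the provider falls back to its own
        # default environment variable (e.g. OPENAI_API_KEY).
        return None
    api_key = os.environ.get(api_key_env_var)
    if not api_key:
        raise ValueError(
            f"Model API Key environment variable '{api_key_env_var}' not set."
        )
    return api_key
```

Note that the validator treats an empty string the same as a missing variable (`os.environ.get` returns `""`, which is falsy), which is why the tests above assert an error for both cases.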