From 4b207bf9484c5beecac40f2a88309b799396ce19 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Mon, 30 Sep 2024 18:08:19 +1000 Subject: [PATCH 1/7] created custom error when env var for providers does not exist --- src/exchange/providers/anthropic.py | 7 ++----- src/exchange/providers/azure.py | 31 +++++++++++----------------- src/exchange/providers/base.py | 12 ++++++++++- src/exchange/providers/bedrock.py | 15 +++++++------- src/exchange/providers/databricks.py | 21 +++++++------------ src/exchange/providers/openai.py | 9 +++----- src/exchange/providers/utils.py | 7 +++++++ tests/providers/test_anthropic.py | 8 +++++++ tests/providers/test_azure.py | 24 +++++++++++++++++++++ tests/providers/test_bedrock.py | 21 +++++++++++++++++++ tests/providers/test_databricks.py | 20 ++++++++++++++++++ tests/providers/test_openai.py | 11 ++++++++++ 12 files changed, 135 insertions(+), 51 deletions(-) diff --git a/src/exchange/providers/anthropic.py b/src/exchange/providers/anthropic.py index 154ec5f..05a980c 100644 --- a/src/exchange/providers/anthropic.py +++ b/src/exchange/providers/anthropic.py @@ -7,7 +7,7 @@ from exchange.content import Text, ToolResult, ToolUse from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import retry_if_status +from exchange.providers.utils import get_provider_env_value, retry_if_status from exchange.providers.utils import raise_for_status ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages" @@ -27,10 +27,7 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider": url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST) - try: - key = os.environ["ANTHROPIC_API_KEY"] - except KeyError: - raise RuntimeError("Failed to get ANTHROPIC_API_KEY from the environment") + key = get_provider_env_value("ANTHROPIC_API_KEY", "anthropic") client = httpx.Client( base_url=url, headers={ diff --git a/src/exchange/providers/azure.py b/src/exchange/providers/azure.py index 7bacb9d..0ff761c 100644 --- a/src/exchange/providers/azure.py +++ b/src/exchange/providers/azure.py @@ -4,6 +4,8 @@ import httpx from exchange.providers import OpenAiProvider +from exchange.providers.base import MissingProviderEnvVariableError +from exchange.providers.utils import get_provider_env_value class AzureProvider(OpenAiProvider): @@ -11,28 +13,14 @@ class AzureProvider(OpenAiProvider): def __init__(self, client: httpx.Client) -> None: super().__init__(client) - + @classmethod def from_env(cls: Type["AzureProvider"]) -> "AzureProvider": - try: - url = os.environ["AZURE_CHAT_COMPLETIONS_HOST_NAME"] - except KeyError: - raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_HOST_NAME from the environment.") - - try: - deployment_name = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"] - except KeyError: - raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME from the environment.") + url = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_HOST_NAME") + deployment_name = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME") - try: - api_version = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"] - except KeyError: - raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION from the environment.") - - try: - key = os.environ["AZURE_CHAT_COMPLETIONS_KEY"] - except KeyError: - raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_KEY from the environment.") + api_version = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION") + key = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_KEY") # format the url host/"openai/deployments/" + deployment_name + "/?api-version=" + api_version url = f"{url}/openai/deployments/{deployment_name}/" @@ -43,3 +31,8 @@ def from_env(cls: Type["AzureProvider"]) -> "AzureProvider": timeout=httpx.Timeout(60 * 10), ) return cls(client) + + @classmethod + def _get_env_variable(cls:Type["AzureProvider"], key: str) -> str: + return get_provider_env_value(key, "azure") + \ No newline at end of file diff --git a/src/exchange/providers/base.py b/src/exchange/providers/base.py index 7b7ff88..056f93c 100644 --- a/src/exchange/providers/base.py +++ b/src/exchange/providers/base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from attrs import define, field -from typing import List, Tuple, Type +from typing import List, Optional, Tuple, Type from exchange.message import Message from exchange.tool import Tool @@ -28,3 +28,13 @@ def complete( ) -> Tuple[Message, Usage]: """Generate the next message using the specified model""" pass + +class MissingProviderEnvVariableError(Exception): + def __init__(self, env_variable: str, provider: str, instructions: Optional[str] = None) -> None: + self.env_variable = env_variable + self.provider = provider + self.instructions = instructions + self.message = f"Missing environment variable: {env_variable} for provider {provider}" + if instructions: + self.message += f". {instructions}" + super().__init__(self.message) \ No newline at end of file diff --git a/src/exchange/providers/bedrock.py b/src/exchange/providers/bedrock.py index 2a5f53d..04c1ffe 100644 --- a/src/exchange/providers/bedrock.py +++ b/src/exchange/providers/bedrock.py @@ -13,7 +13,7 @@ from exchange.message import Message from exchange.providers import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import retry_if_status +from exchange.providers.utils import get_provider_env_value, retry_if_status from exchange.providers.utils import raise_for_status from exchange.tool import Tool @@ -154,12 +154,9 @@ def __init__(self, client: AwsClient) -> None: @classmethod def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider": aws_region = os.environ.get("AWS_REGION", "us-east-1") - try: - aws_access_key = os.environ["AWS_ACCESS_KEY_ID"] - aws_secret_key = os.environ["AWS_SECRET_ACCESS_KEY"] - aws_session_token = os.environ.get("AWS_SESSION_TOKEN") - except KeyError: - raise RuntimeError("Failed to get AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY from the environment") + aws_access_key = cls._get_env_variable("AWS_ACCESS_KEY_ID") + aws_secret_key = cls._get_env_variable("AWS_SECRET_ACCESS_KEY") + aws_session_token = cls._get_env_variable("AWS_SESSION_TOKEN") client = AwsClient( aws_region=aws_region, @@ -326,3 +323,7 @@ def tools_to_bedrock_spec(tools: Tuple[Tool]) -> Optional[dict]: tools_added.add(tool.name) tool_config = {"tools": tool_config_list} return tool_config + + @classmethod + def _get_env_variable(cls:Type["BedrockProvider"], key: str) -> str: + return get_provider_env_value(key, "bedrock") diff --git a/src/exchange/providers/databricks.py b/src/exchange/providers/databricks.py index 84dc751..e9d45b4 100644 --- a/src/exchange/providers/databricks.py +++ b/src/exchange/providers/databricks.py @@ -6,7 +6,7 @@ from exchange.message import Message from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import raise_for_status, retry_if_status +from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status from exchange.providers.utils import ( messages_to_openai_spec, openai_response_to_message, @@ -37,18 +37,8 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["DatabricksProvider"]) -> "DatabricksProvider": - try: - url = os.environ["DATABRICKS_HOST"] - except KeyError: - raise RuntimeError( - "Failed to get DATABRICKS_HOST from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields" - ) - try: - key = os.environ["DATABRICKS_TOKEN"] - except KeyError: - raise RuntimeError( - "Failed to get DATABRICKS_TOKEN from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields" - ) + url = cls._get_env_variable("DATABRICKS_HOST") + key = cls._get_env_variable("DATABRICKS_TOKEN") client = httpx.Client( base_url=url, auth=("token", key), @@ -100,3 +90,8 @@ def _post(self, model: str, payload: dict) -> httpx.Response: json=payload, ) return raise_for_status(response).json() + + @classmethod + def _get_env_variable(cls:Type["DatabricksProvider"], key: str) -> str: + instruction = "https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields" + return get_provider_env_value(key, "databricks", instruction) diff --git a/src/exchange/providers/openai.py b/src/exchange/providers/openai.py index dbd293b..1364c04 100644 --- a/src/exchange/providers/openai.py +++ b/src/exchange/providers/openai.py @@ -6,6 +6,7 @@ from exchange.message import Message from exchange.providers.base import Provider, Usage from exchange.providers.utils import ( + get_provider_env_value, messages_to_openai_spec, openai_response_to_message, openai_single_message_context_length_exceeded, @@ -36,12 +37,8 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider": url = os.environ.get("OPENAI_HOST", OPENAI_HOST) - try: - key = os.environ["OPENAI_API_KEY"] - except KeyError: - raise RuntimeError( - "Failed to get OPENAI_API_KEY from the environment, see https://platform.openai.com/docs/api-reference/api-keys" - ) + api_key_instructions = "see https://platform.openai.com/docs/api-reference/api-keys" + key = get_provider_env_value("OPENAI_API_KEY", "openai", api_key_instructions) client = httpx.Client( base_url=url + "v1/", auth=("Bearer", key), diff --git a/src/exchange/providers/utils.py b/src/exchange/providers/utils.py index 4be7ac3..4297dc2 100644 --- a/src/exchange/providers/utils.py +++ b/src/exchange/providers/utils.py @@ -1,11 +1,13 @@ import base64 import json +import os import re from typing import Any, Callable, Dict, List, Optional, Tuple import httpx from exchange.content import Text, ToolResult, ToolUse from exchange.message import Message +from exchange.providers.base import MissingProviderEnvVariableError from exchange.tool import Tool from tenacity import retry_if_exception @@ -178,6 +180,11 @@ def openai_single_message_context_length_exceeded(error_dict: dict) -> None: if code == "context_length_exceeded" or code == "string_above_max_length": raise InitialMessageTooLargeError(f"Input message too long. Message: {error_dict.get('message')}") +def get_provider_env_value(env_variable: str, provider: str, instructions: Optional[str] = None) -> str: + try: + return os.environ[env_variable] + except KeyError: + raise MissingProviderEnvVariableError(env_variable, provider, instructions) class InitialMessageTooLargeError(Exception): """Custom error raised when the first input message in an exchange is too large.""" diff --git a/tests/providers/test_anthropic.py b/tests/providers/test_anthropic.py index a6f5bc6..f96e229 100644 --- a/tests/providers/test_anthropic.py +++ b/tests/providers/test_anthropic.py @@ -6,6 +6,7 @@ from exchange import Message, Text from exchange.content import ToolResult, ToolUse from exchange.providers.anthropic import AnthropicProvider +from exchange.providers.base import MissingProviderEnvVariableError from exchange.tool import Tool @@ -24,6 +25,13 @@ def example_fn(param: str) -> None: def anthropic_provider(): return AnthropicProvider.from_env() +def test_from_env_throw_error_when_missing_api_key(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(MissingProviderEnvVariableError) as context: + AnthropicProvider.from_env() + assert context.value.provider == "anthropic" + assert context.value.env_variable == "ANTHROPIC_API_KEY" + assert context.value.message == "Missing environment variable: ANTHROPIC_API_KEY for provider anthropic" def test_anthropic_response_to_text_message() -> None: response = { diff --git a/tests/providers/test_azure.py b/tests/providers/test_azure.py index adafabe..21c1ab5 100644 --- a/tests/providers/test_azure.py +++ b/tests/providers/test_azure.py @@ -1,13 +1,37 @@ import os +from unittest.mock import patch import pytest from exchange import Text, ToolUse from exchange.providers.azure import AzureProvider +from exchange.providers.base import MissingProviderEnvVariableError from .conftest import complete, tools AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini") +@pytest.mark.parametrize( + "env_var_name", + [ + ("AZURE_CHAT_COMPLETIONS_HOST_NAME"), + ("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"), + ("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"), + ("AZURE_CHAT_COMPLETIONS_KEY"), + ] +) +def test_from_env_throw_error_when_missing_env_var(env_var_name): + with patch.dict(os.environ, { + "AZURE_CHAT_COMPLETIONS_HOST_NAME": "test_host_name", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME": "test_deployment_name", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION": "test_api_version", + "AZURE_CHAT_COMPLETIONS_KEY": "test_api_key", + }, clear=True): + os.environ.pop(env_var_name) + with pytest.raises(MissingProviderEnvVariableError) as context: + AzureProvider.from_env() + assert context.value.provider == "azure" + assert context.value.env_variable == env_var_name + assert context.value.message == f"Missing environment variable: {env_var_name} for provider azure" @pytest.mark.vcr() def test_azure_complete(default_azure_env): diff --git a/tests/providers/test_bedrock.py b/tests/providers/test_bedrock.py index 2525f65..0bc2b2c 100644 --- a/tests/providers/test_bedrock.py +++ b/tests/providers/test_bedrock.py @@ -5,11 +5,32 @@ import pytest from exchange.content import Text, ToolResult, ToolUse from exchange.message import Message +from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.bedrock import BedrockProvider from exchange.tool import Tool logger = logging.getLogger(__name__) +@pytest.mark.parametrize( + "env_var_name", + [ + ("AWS_ACCESS_KEY_ID"), + ("AWS_SECRET_ACCESS_KEY"), + ("AWS_SESSION_TOKEN"), + ] +) +def test_from_env_throw_error_when_missing_env_var(env_var_name): + with patch.dict(os.environ, { + "AWS_ACCESS_KEY_ID": "test_access_key_id", + "AWS_SECRET_ACCESS_KEY": "test_secret_access_key", + "AWS_SESSION_TOKEN": "test_session_token", + }, clear=True): + os.environ.pop(env_var_name) + with pytest.raises(MissingProviderEnvVariableError) as context: + BedrockProvider.from_env() + assert context.value.provider == "bedrock" + assert context.value.env_variable == env_var_name + assert context.value.message == f"Missing environment variable: {env_var_name} for provider bedrock" @pytest.fixture @patch.dict( diff --git a/tests/providers/test_databricks.py b/tests/providers/test_databricks.py index 3c14211..4be09ac 100644 --- a/tests/providers/test_databricks.py +++ b/tests/providers/test_databricks.py @@ -3,8 +3,28 @@ import pytest from exchange import Message, Text +from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.databricks import DatabricksProvider +@pytest.mark.parametrize( + "env_var_name", + [ + ("DATABRICKS_HOST"), + ("DATABRICKS_TOKEN"), + ] +) +def test_from_env_throw_error_when_missing_env_var(env_var_name): + with patch.dict(os.environ, { + "DATABRICKS_HOST": "test_host", + "DATABRICKS_TOKEN": "test_token", + }, clear=True): + os.environ.pop(env_var_name) + with pytest.raises(MissingProviderEnvVariableError) as context: + DatabricksProvider.from_env() + assert context.value.provider == "databricks" + assert context.value.env_variable == env_var_name + assert f"Missing environment variable: {env_var_name} for provider databricks" in context.value.message + assert "https://docs.databricks.com" in context.value.message @pytest.fixture @patch.dict( diff --git a/tests/providers/test_openai.py b/tests/providers/test_openai.py index 45bc620..4a410b2 100644 --- a/tests/providers/test_openai.py +++ b/tests/providers/test_openai.py @@ -1,14 +1,25 @@ import os +from unittest.mock import patch import pytest from exchange import Text, ToolUse +from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.openai import OpenAiProvider from .conftest import complete, vision, tools OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini") +def test_from_env_throw_error_when_missing_api_key(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(MissingProviderEnvVariableError) as context: + OpenAiProvider.from_env() + assert context.value.provider == "openai" + assert context.value.env_variable == "OPENAI_API_KEY" + assert "Missing environment variable: OPENAI_API_KEY for provider openai" in context.value.message + assert "https://platform.openai.com" in context.value.message + @pytest.mark.vcr() def test_openai_complete(default_openai_env): reply_message, reply_usage = complete(OpenAiProvider, OPENAI_MODEL) From 78d49138c77f247d435fe9a38cf91b0bde35547a Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 1 Oct 2024 13:21:39 +1000 Subject: [PATCH 2/7] fixed the ruff check --- src/exchange/providers/azure.py | 9 +++------ src/exchange/providers/base.py | 3 ++- src/exchange/providers/bedrock.py | 2 +- src/exchange/providers/databricks.py | 3 +-- src/exchange/providers/utils.py | 2 ++ tests/providers/test_anthropic.py | 2 ++ tests/providers/test_azure.py | 20 +++++++++++++------- tests/providers/test_bedrock.py | 18 ++++++++++++------ tests/providers/test_databricks.py | 16 +++++++++++----- tests/providers/test_openai.py | 1 + 10 files changed, 48 insertions(+), 28 deletions(-) diff --git a/src/exchange/providers/azure.py b/src/exchange/providers/azure.py index 0ff761c..212dcca 100644 --- a/src/exchange/providers/azure.py +++ b/src/exchange/providers/azure.py @@ -1,10 +1,8 @@ -import os from typing import Type import httpx from exchange.providers import OpenAiProvider -from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.utils import get_provider_env_value @@ -13,7 +11,7 @@ class AzureProvider(OpenAiProvider): def __init__(self, client: httpx.Client) -> None: super().__init__(client) - + @classmethod def from_env(cls: Type["AzureProvider"]) -> "AzureProvider": url = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_HOST_NAME") @@ -31,8 +29,7 @@ def from_env(cls: Type["AzureProvider"]) -> "AzureProvider": timeout=httpx.Timeout(60 * 10), ) return cls(client) - + @classmethod - def _get_env_variable(cls:Type["AzureProvider"], key: str) -> str: + def _get_env_variable(cls: Type["AzureProvider"], key: str) -> str: return get_provider_env_value(key, "azure") - \ No newline at end of file diff --git a/src/exchange/providers/base.py b/src/exchange/providers/base.py index 056f93c..38bcf74 100644 --- a/src/exchange/providers/base.py +++ b/src/exchange/providers/base.py @@ -29,6 +29,7 @@ def complete( """Generate the next message using the specified model""" pass + class MissingProviderEnvVariableError(Exception): def __init__(self, env_variable: str, provider: str, instructions: Optional[str] = None) -> None: self.env_variable = env_variable @@ -37,4 +38,4 @@ def __init__(self, env_variable: str, provider: str, instructions: Optional[str] self.message = f"Missing environment variable: {env_variable} for provider {provider}" if instructions: self.message += f". {instructions}" - super().__init__(self.message) \ No newline at end of file + super().__init__(self.message) diff --git a/src/exchange/providers/bedrock.py b/src/exchange/providers/bedrock.py index 04c1ffe..c4ca400 100644 --- a/src/exchange/providers/bedrock.py +++ b/src/exchange/providers/bedrock.py @@ -325,5 +325,5 @@ def tools_to_bedrock_spec(tools: Tuple[Tool]) -> Optional[dict]: return tool_config @classmethod - def _get_env_variable(cls:Type["BedrockProvider"], key: str) -> str: + def _get_env_variable(cls: Type["BedrockProvider"], key: str) -> str: return get_provider_env_value(key, "bedrock") diff --git a/src/exchange/providers/databricks.py b/src/exchange/providers/databricks.py index e9d45b4..77d392e 100644 --- a/src/exchange/providers/databricks.py +++ b/src/exchange/providers/databricks.py @@ -1,4 +1,3 @@ -import os from typing import Any, Dict, List, Tuple, Type import httpx @@ -92,6 +91,6 @@ def _post(self, model: str, payload: dict) -> httpx.Response: return raise_for_status(response).json() @classmethod - def _get_env_variable(cls:Type["DatabricksProvider"], key: str) -> str: + def _get_env_variable(cls: Type["DatabricksProvider"], key: str) -> str: instruction = "https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields" return get_provider_env_value(key, "databricks", instruction) diff --git a/src/exchange/providers/utils.py b/src/exchange/providers/utils.py index 4297dc2..7150598 100644 --- a/src/exchange/providers/utils.py +++ b/src/exchange/providers/utils.py @@ -180,12 +180,14 @@ def openai_single_message_context_length_exceeded(error_dict: dict) -> None: if code == "context_length_exceeded" or code == "string_above_max_length": raise InitialMessageTooLargeError(f"Input message too long. Message: {error_dict.get('message')}") + def get_provider_env_value(env_variable: str, provider: str, instructions: Optional[str] = None) -> str: try: return os.environ[env_variable] except KeyError: raise MissingProviderEnvVariableError(env_variable, provider, instructions) + class InitialMessageTooLargeError(Exception): """Custom error raised when the first input message in an exchange is too large.""" diff --git a/tests/providers/test_anthropic.py b/tests/providers/test_anthropic.py index f96e229..5db3ec5 100644 --- a/tests/providers/test_anthropic.py +++ b/tests/providers/test_anthropic.py @@ -25,6 +25,7 @@ def example_fn(param: str) -> None: def anthropic_provider(): return AnthropicProvider.from_env() + def test_from_env_throw_error_when_missing_api_key(): with patch.dict(os.environ, {}, clear=True): with pytest.raises(MissingProviderEnvVariableError) as context: @@ -33,6 +34,7 @@ def test_from_env_throw_error_when_missing_api_key(): assert context.value.env_variable == "ANTHROPIC_API_KEY" assert context.value.message == "Missing environment variable: ANTHROPIC_API_KEY for provider anthropic" + def test_anthropic_response_to_text_message() -> None: response = { "content": [{"type": "text", "text": "Hello from Claude!"}], diff --git a/tests/providers/test_azure.py b/tests/providers/test_azure.py index 21c1ab5..9d4203f 100644 --- a/tests/providers/test_azure.py +++ b/tests/providers/test_azure.py @@ -10,6 +10,7 @@ AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini") + @pytest.mark.parametrize( "env_var_name", [ @@ -17,15 +18,19 @@ ("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"), ("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"), ("AZURE_CHAT_COMPLETIONS_KEY"), - ] + ], ) def test_from_env_throw_error_when_missing_env_var(env_var_name): - with patch.dict(os.environ, { - "AZURE_CHAT_COMPLETIONS_HOST_NAME": "test_host_name", - "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME": "test_deployment_name", - "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION": "test_api_version", - "AZURE_CHAT_COMPLETIONS_KEY": "test_api_key", - }, clear=True): + with patch.dict( + os.environ, + { + "AZURE_CHAT_COMPLETIONS_HOST_NAME": "test_host_name", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME": "test_deployment_name", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION": "test_api_version", + "AZURE_CHAT_COMPLETIONS_KEY": "test_api_key", + }, + clear=True, + ): os.environ.pop(env_var_name) with pytest.raises(MissingProviderEnvVariableError) as context: AzureProvider.from_env() @@ -33,6 +38,7 @@ def test_from_env_throw_error_when_missing_env_var(env_var_name): assert context.value.env_variable == env_var_name assert context.value.message == f"Missing environment variable: {env_var_name} for provider azure" + @pytest.mark.vcr() def test_azure_complete(default_azure_env): reply_message, reply_usage = complete(AzureProvider, AZURE_MODEL) diff --git a/tests/providers/test_bedrock.py b/tests/providers/test_bedrock.py index 0bc2b2c..d31a738 100644 --- a/tests/providers/test_bedrock.py +++ b/tests/providers/test_bedrock.py @@ -11,20 +11,25 @@ logger = logging.getLogger(__name__) + @pytest.mark.parametrize( "env_var_name", [ ("AWS_ACCESS_KEY_ID"), ("AWS_SECRET_ACCESS_KEY"), ("AWS_SESSION_TOKEN"), - ] + ], ) def test_from_env_throw_error_when_missing_env_var(env_var_name): - with patch.dict(os.environ, { - "AWS_ACCESS_KEY_ID": "test_access_key_id", - "AWS_SECRET_ACCESS_KEY": "test_secret_access_key", - "AWS_SESSION_TOKEN": "test_session_token", - }, clear=True): + with patch.dict( + os.environ, + { + "AWS_ACCESS_KEY_ID": "test_access_key_id", + "AWS_SECRET_ACCESS_KEY": "test_secret_access_key", + "AWS_SESSION_TOKEN": "test_session_token", + }, + clear=True, + ): os.environ.pop(env_var_name) with pytest.raises(MissingProviderEnvVariableError) as context: BedrockProvider.from_env() @@ -32,6 +37,7 @@ def test_from_env_throw_error_when_missing_env_var(env_var_name): assert context.value.env_variable == env_var_name assert context.value.message == f"Missing environment variable: {env_var_name} for provider bedrock" + @pytest.fixture @patch.dict( os.environ, diff --git a/tests/providers/test_databricks.py b/tests/providers/test_databricks.py index 4be09ac..4b6793a 100644 --- a/tests/providers/test_databricks.py +++ b/tests/providers/test_databricks.py @@ -6,18 +6,23 @@ from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.databricks import DatabricksProvider + @pytest.mark.parametrize( "env_var_name", [ ("DATABRICKS_HOST"), ("DATABRICKS_TOKEN"), - ] + ], ) def test_from_env_throw_error_when_missing_env_var(env_var_name): - with patch.dict(os.environ, { - "DATABRICKS_HOST": "test_host", - "DATABRICKS_TOKEN": "test_token", - }, clear=True): + with patch.dict( + os.environ, + { + "DATABRICKS_HOST": "test_host", + "DATABRICKS_TOKEN": "test_token", + }, + clear=True, + ): os.environ.pop(env_var_name) with pytest.raises(MissingProviderEnvVariableError) as context: DatabricksProvider.from_env() @@ -26,6 +31,7 @@ def test_from_env_throw_error_when_missing_env_var(env_var_name): assert f"Missing environment variable: {env_var_name} for provider databricks" in context.value.message assert "https://docs.databricks.com" in context.value.message + @pytest.fixture @patch.dict( os.environ, diff --git a/tests/providers/test_openai.py b/tests/providers/test_openai.py index 4a410b2..ea979ab 100644 --- a/tests/providers/test_openai.py +++ b/tests/providers/test_openai.py @@ -20,6 +20,7 @@ def test_from_env_throw_error_when_missing_api_key(): assert "Missing environment variable: OPENAI_API_KEY for provider openai" in context.value.message assert "https://platform.openai.com" in context.value.message + @pytest.mark.vcr() def test_openai_complete(default_openai_env): reply_message, reply_usage = complete(OpenAiProvider, OPENAI_MODEL) From 2e0e2f8e5b9589b34f862967655116ba0d5bf13b Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 2 Oct 2024 10:41:27 +1000 Subject: [PATCH 3/7] throw error for unknown providers or moderator --- src/exchange/load_exchange_attribute_error.py | 6 ++++++ src/exchange/moderators/__init__.py | 6 +++++- src/exchange/providers/__init__.py | 6 +++++- tests/providers/test_provider.py | 16 ++++++++++++++++ tests/test_load_exchange_attribute_error.py | 11 +++++++++++ tests/test_moderators.py | 15 +++++++++++++++ 6 files changed, 58 insertions(+), 2 deletions(-) create mode 100644 src/exchange/load_exchange_attribute_error.py create mode 100644 tests/providers/test_provider.py create mode 100644 tests/test_load_exchange_attribute_error.py create mode 100644 tests/test_moderators.py diff --git a/src/exchange/load_exchange_attribute_error.py b/src/exchange/load_exchange_attribute_error.py new file mode 100644 index 0000000..131f6f9 --- /dev/null +++ b/src/exchange/load_exchange_attribute_error.py @@ -0,0 +1,6 @@ +class LoadExchangeAttributeError(Exception): + def __init__(self, attribute_name: str, attribute_value: str) -> None: + self.attribute_name = attribute_name + self.attribute_value = attribute_value + self.message = f"Unknown {attribute_name}: {attribute_value}" + super().__init__(self.message) diff --git a/src/exchange/moderators/__init__.py b/src/exchange/moderators/__init__.py index 56b198a..5835f88 100644 --- a/src/exchange/moderators/__init__.py +++ b/src/exchange/moderators/__init__.py @@ -1,6 +1,7 @@ from functools import cache from typing import Type +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError from exchange.moderators.base import Moderator from exchange.utils import load_plugins from exchange.moderators.passive import PassiveModerator # noqa @@ -10,4 +11,7 @@ @cache def get_moderator(name: str) -> Type[Moderator]: - return load_plugins(group="exchange.moderator")[name] + moderators = load_plugins(group="exchange.moderator") + if name not in moderators: + raise LoadExchangeAttributeError("moderator", name) + return moderators[name] diff --git a/src/exchange/providers/__init__.py b/src/exchange/providers/__init__.py index 177ea63..ca83c5b 100644 --- a/src/exchange/providers/__init__.py +++ b/src/exchange/providers/__init__.py @@ -1,6 +1,7 @@ from functools import cache from typing import Type +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError from exchange.providers.anthropic import AnthropicProvider # noqa from exchange.providers.base import Provider, Usage # noqa from exchange.providers.databricks import DatabricksProvider # noqa @@ -13,4 +14,7 @@ @cache def get_provider(name: str) -> Type[Provider]: - return load_plugins(group="exchange.provider")[name] + providers = load_plugins(group="exchange.provider") + if name not in providers: + raise LoadExchangeAttributeError("provider", name) + return providers[name] diff --git a/tests/providers/test_provider.py b/tests/providers/test_provider.py new file mode 100644 index 0000000..39b0992 --- /dev/null +++ b/tests/providers/test_provider.py @@ -0,0 +1,16 @@ +import pytest +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError +from exchange.providers import get_provider + + +def test_get_provider_valid(): + provider_name = "openai" + provider = get_provider(provider_name) + assert provider.__name__ == "OpenAiProvider" + + +def test_get_provider_throw_error_for_unknown_provider(): + with pytest.raises(LoadExchangeAttributeError) as error: + get_provider("nonexistent") + assert error.value.attribute_name == "provider" + assert error.value.attribute_value == "nonexistent" diff --git a/tests/test_load_exchange_attribute_error.py b/tests/test_load_exchange_attribute_error.py new file mode 100644 index 0000000..aa9e106 --- /dev/null +++ b/tests/test_load_exchange_attribute_error.py @@ -0,0 +1,11 @@ +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError + + +def test_load_exchange_attribute_error(): + attribute_name = "provider" + attribute_value = "not_exist" + error = LoadExchangeAttributeError(attribute_name, attribute_value) + + assert error.attribute_name == attribute_name + assert error.attribute_value == attribute_value + assert error.message == "Unknown provider: not_exist" diff --git a/tests/test_moderators.py b/tests/test_moderators.py new file mode 100644 index 0000000..f9ef0b9 --- /dev/null +++ b/tests/test_moderators.py @@ -0,0 +1,15 @@ +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError +from exchange.moderators import get_moderator +import pytest + + +def test_get_moderator(): + moderator = get_moderator("truncate") + assert moderator.__name__ == "ContextTruncate" + + +def test_get_moderator_raise_error_for_unknown_moderator(): + with pytest.raises(LoadExchangeAttributeError) as error: + get_moderator("nonexistent") + assert error.value.attribute_name == "moderator" + assert error.value.attribute_value == "nonexistent" From 5b2c94bd6d10e81d39be2ed9708fd715c6ba8563 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 2 Oct 2024 11:20:33 +1000 Subject: [PATCH 4/7] pass available values for moderator and provider --- src/exchange/load_exchange_attribute_error.py | 11 +++++++++-- src/exchange/moderators/__init__.py | 2 +- src/exchange/providers/__init__.py | 2 +- tests/providers/test_provider.py | 2 ++ tests/test_load_exchange_attribute_error.py | 8 +++++--- tests/test_moderators.py | 2 ++ 6 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/exchange/load_exchange_attribute_error.py b/src/exchange/load_exchange_attribute_error.py index 131f6f9..c9fdd26 100644 --- a/src/exchange/load_exchange_attribute_error.py +++ b/src/exchange/load_exchange_attribute_error.py @@ -1,6 +1,13 @@ +from typing import List + + class LoadExchangeAttributeError(Exception): - def __init__(self, attribute_name: str, attribute_value: str) -> None: + def __init__(self, attribute_name: str, attribute_value: str, available_values: List[str]) -> None: self.attribute_name = attribute_name self.attribute_value = attribute_value - self.message = f"Unknown {attribute_name}: {attribute_value}" + self.available_values = available_values + self.message = ( + f"Unknown {attribute_name}: {attribute_value}." + + f"Available {attribute_name}s: {', '.join(available_values)}" + ) super().__init__(self.message) diff --git a/src/exchange/moderators/__init__.py b/src/exchange/moderators/__init__.py index 5835f88..8bcdae8 100644 --- a/src/exchange/moderators/__init__.py +++ b/src/exchange/moderators/__init__.py @@ -13,5 +13,5 @@ def get_moderator(name: str) -> Type[Moderator]: moderators = load_plugins(group="exchange.moderator") if name not in moderators: - raise LoadExchangeAttributeError("moderator", name) + raise LoadExchangeAttributeError("moderator", name, moderators.keys()) return moderators[name] diff --git a/src/exchange/providers/__init__.py b/src/exchange/providers/__init__.py index ca83c5b..127ac16 100644 --- a/src/exchange/providers/__init__.py +++ b/src/exchange/providers/__init__.py @@ -16,5 +16,5 @@ def get_provider(name: str) -> Type[Provider]: providers = load_plugins(group="exchange.provider") if name not in providers: - raise LoadExchangeAttributeError("provider", name) + raise LoadExchangeAttributeError("provider", name, providers.keys()) return providers[name] diff --git a/tests/providers/test_provider.py b/tests/providers/test_provider.py index 39b0992..0038d2e 100644 --- a/tests/providers/test_provider.py +++ b/tests/providers/test_provider.py @@ -14,3 +14,5 @@ def test_get_provider_throw_error_for_unknown_provider(): get_provider("nonexistent") assert error.value.attribute_name == "provider" assert error.value.attribute_value == "nonexistent" + assert "openai" in error.value.available_values + assert "openai" in error.value.message diff --git a/tests/test_load_exchange_attribute_error.py b/tests/test_load_exchange_attribute_error.py index aa9e106..f6f17f1 100644 --- a/tests/test_load_exchange_attribute_error.py +++ b/tests/test_load_exchange_attribute_error.py @@ -2,10 +2,12 @@ def test_load_exchange_attribute_error(): - attribute_name = "provider" + attribute_name = "moderator" attribute_value = "not_exist" - error = LoadExchangeAttributeError(attribute_name, attribute_value) + available_values = ["truncate", "summarizer"] + error = LoadExchangeAttributeError(attribute_name, attribute_value, available_values) assert error.attribute_name == attribute_name assert error.attribute_value == attribute_value - assert error.message == "Unknown provider: not_exist" + assert error.attribute_value == attribute_value + assert error.message == "Unknown moderator: not_exist. Available moderators: truncate, summarizer" diff --git a/tests/test_moderators.py b/tests/test_moderators.py index f9ef0b9..8d8478f 100644 --- a/tests/test_moderators.py +++ b/tests/test_moderators.py @@ -13,3 +13,5 @@ def test_get_moderator_raise_error_for_unknown_moderator(): get_moderator("nonexistent") assert error.value.attribute_name == "moderator" assert error.value.attribute_value == "nonexistent" + assert "truncate" in error.value.available_values + assert "truncate" in error.value.message From bb2214b1498aa153582629377a51a825b94c19f6 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 2 Oct 2024 11:25:18 +1000 Subject: [PATCH 5/7] check api key in google --- src/exchange/providers/google.py | 11 +++-------- tests/providers/test_google.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/exchange/providers/google.py b/src/exchange/providers/google.py index 426aa79..b58ddb6 100644 --- a/src/exchange/providers/google.py +++ b/src/exchange/providers/google.py @@ -7,7 +7,7 @@ from exchange.content import Text, ToolResult, ToolUse from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import retry_if_status +from exchange.providers.utils import get_provider_env_value, retry_if_status from exchange.providers.utils import raise_for_status GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta" @@ -27,13 +27,8 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider": url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST) - try: - key = os.environ["GOOGLE_API_KEY"] - except KeyError: - raise RuntimeError( - "Failed to get GOOGLE_API_KEY from the environment, see https://ai.google.dev/gemini-api/docs/api-key" - ) - + api_key_instructions = "see https://ai.google.dev/gemini-api/docs/api-key" + key = get_provider_env_value("GOOGLE_API_KEY", "google", api_key_instructions) client = httpx.Client( base_url=url, headers={ diff --git a/tests/providers/test_google.py b/tests/providers/test_google.py index 47ad46b..76ae4c8 100644 --- a/tests/providers/test_google.py +++ b/tests/providers/test_google.py @@ -5,6 +5,7 @@ import pytest from exchange import Message, Text from exchange.content import ToolResult, ToolUse +from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.google import GoogleProvider from exchange.tool import Tool @@ -19,6 +20,16 @@ def example_fn(param: str) -> None: pass +def test_from_env_throw_error_when_missing_api_key(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(MissingProviderEnvVariableError) as context: + GoogleProvider.from_env() + assert context.value.provider == "google" + assert context.value.env_variable == "GOOGLE_API_KEY" + assert "Missing environment variable: GOOGLE_API_KEY for provider google" in context.value.message + assert "https://ai.google.dev/gemini-api/docs/api-key" in context.value.message + + @pytest.fixture @patch.dict(os.environ, {"GOOGLE_API_KEY": "test_api_key"}) def google_provider(): From ba328b5d8acf87074c2cd88163a27cbb13dad44c Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 2 Oct 2024 11:33:42 +1000 Subject: [PATCH 6/7] fixed the test --- src/exchange/load_exchange_attribute_error.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/exchange/load_exchange_attribute_error.py b/src/exchange/load_exchange_attribute_error.py index c9fdd26..bf1c4b6 100644 --- a/src/exchange/load_exchange_attribute_error.py +++ b/src/exchange/load_exchange_attribute_error.py @@ -8,6 +8,6 @@ def __init__(self, attribute_name: str, attribute_value: str, available_values: self.available_values = available_values self.message = ( f"Unknown {attribute_name}: {attribute_value}." - + f"Available {attribute_name}s: {', '.join(available_values)}" + + f" Available {attribute_name}s: {', '.join(available_values)}" ) super().__init__(self.message) From 329274e47bb1dc311c5735be3516406c30696c65 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 2 Oct 2024 12:24:09 +1000 Subject: [PATCH 7/7] refactor instructions to instructions_url --- src/exchange/providers/base.py | 10 +++++----- src/exchange/providers/google.py | 4 ++-- src/exchange/providers/openai.py | 4 ++-- src/exchange/providers/utils.py | 4 ++-- tests/providers/test_anthropic.py | 2 +- tests/providers/test_azure.py | 2 +- tests/providers/test_bedrock.py | 2 +- tests/test_base.py | 27 +++++++++++++++++++++++++++ 8 files changed, 41 insertions(+), 14 deletions(-) create mode 100644 tests/test_base.py diff --git a/src/exchange/providers/base.py b/src/exchange/providers/base.py index 38bcf74..7ec8745 100644 --- a/src/exchange/providers/base.py +++ b/src/exchange/providers/base.py @@ -31,11 +31,11 @@ def complete( class MissingProviderEnvVariableError(Exception): - def __init__(self, env_variable: str, provider: str, instructions: Optional[str] = None) -> None: + def __init__(self, env_variable: str, provider: str, instructions_url: Optional[str] = None) -> None: self.env_variable = env_variable self.provider = provider - self.instructions = instructions - self.message = f"Missing environment variable: {env_variable} for provider {provider}" - if instructions: - self.message += f". {instructions}" + self.instructions_url = instructions_url + self.message = f"Missing environment variable: {env_variable} for provider {provider}." + if instructions_url: + self.message += f"\n Please see {instructions_url} for instructions" super().__init__(self.message) diff --git a/src/exchange/providers/google.py b/src/exchange/providers/google.py index b58ddb6..349b803 100644 --- a/src/exchange/providers/google.py +++ b/src/exchange/providers/google.py @@ -27,8 +27,8 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider": url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST) - api_key_instructions = "see https://ai.google.dev/gemini-api/docs/api-key" - key = get_provider_env_value("GOOGLE_API_KEY", "google", api_key_instructions) + api_key_instructions_url = "see https://ai.google.dev/gemini-api/docs/api-key" + key = get_provider_env_value("GOOGLE_API_KEY", "google", api_key_instructions_url) client = httpx.Client( base_url=url, headers={ diff --git a/src/exchange/providers/openai.py b/src/exchange/providers/openai.py index 1364c04..d921020 100644 --- a/src/exchange/providers/openai.py +++ b/src/exchange/providers/openai.py @@ -37,8 +37,8 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider": url = os.environ.get("OPENAI_HOST", OPENAI_HOST) - api_key_instructions = "see https://platform.openai.com/docs/api-reference/api-keys" - key = get_provider_env_value("OPENAI_API_KEY", "openai", api_key_instructions) + api_key_instructions_url = "see https://platform.openai.com/docs/api-reference/api-keys" + key = get_provider_env_value("OPENAI_API_KEY", "openai", api_key_instructions_url) client = httpx.Client( base_url=url + "v1/", auth=("Bearer", key), diff --git a/src/exchange/providers/utils.py b/src/exchange/providers/utils.py index 7150598..0150430 100644 --- a/src/exchange/providers/utils.py +++ b/src/exchange/providers/utils.py @@ -181,11 +181,11 @@ def openai_single_message_context_length_exceeded(error_dict: dict) -> None: raise InitialMessageTooLargeError(f"Input message too long. Message: {error_dict.get('message')}") -def get_provider_env_value(env_variable: str, provider: str, instructions: Optional[str] = None) -> str: +def get_provider_env_value(env_variable: str, provider: str, instructions_url: Optional[str] = None) -> str: try: return os.environ[env_variable] except KeyError: - raise MissingProviderEnvVariableError(env_variable, provider, instructions) + raise MissingProviderEnvVariableError(env_variable, provider, instructions_url) class InitialMessageTooLargeError(Exception): diff --git a/tests/providers/test_anthropic.py b/tests/providers/test_anthropic.py index 5db3ec5..272ebcb 100644 --- a/tests/providers/test_anthropic.py +++ b/tests/providers/test_anthropic.py @@ -32,7 +32,7 @@ def test_from_env_throw_error_when_missing_api_key(): AnthropicProvider.from_env() assert context.value.provider == "anthropic" assert context.value.env_variable == "ANTHROPIC_API_KEY" - assert context.value.message == "Missing environment variable: ANTHROPIC_API_KEY for provider anthropic" + assert context.value.message == "Missing environment variable: ANTHROPIC_API_KEY for provider anthropic." def test_anthropic_response_to_text_message() -> None: diff --git a/tests/providers/test_azure.py b/tests/providers/test_azure.py index 9d4203f..b46be30 100644 --- a/tests/providers/test_azure.py +++ b/tests/providers/test_azure.py @@ -36,7 +36,7 @@ def test_from_env_throw_error_when_missing_env_var(env_var_name): AzureProvider.from_env() assert context.value.provider == "azure" assert context.value.env_variable == env_var_name - assert context.value.message == f"Missing environment variable: {env_var_name} for provider azure" + assert context.value.message == f"Missing environment variable: {env_var_name} for provider azure." @pytest.mark.vcr() diff --git a/tests/providers/test_bedrock.py b/tests/providers/test_bedrock.py index d31a738..f8fcaa4 100644 --- a/tests/providers/test_bedrock.py +++ b/tests/providers/test_bedrock.py @@ -35,7 +35,7 @@ def test_from_env_throw_error_when_missing_env_var(env_var_name): BedrockProvider.from_env() assert context.value.provider == "bedrock" assert context.value.env_variable == env_var_name - assert context.value.message == f"Missing environment variable: {env_var_name} for provider bedrock" + assert context.value.message == f"Missing environment variable: {env_var_name} for provider bedrock." @pytest.fixture diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 0000000..9b3d6d4 --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,27 @@ +from exchange.providers.base import MissingProviderEnvVariableError + + +def test_missing_provider_env_variable_error_without_instructions_url(): + env_variable = "API_KEY" + provider = "TestProvider" + error = MissingProviderEnvVariableError(env_variable, provider) + + assert error.env_variable == env_variable + assert error.provider == provider + assert error.instructions_url is None + assert error.message == "Missing environment variable: API_KEY for provider TestProvider." + + +def test_missing_provider_env_variable_error_with_instructions_url(): + env_variable = "API_KEY" + provider = "TestProvider" + instructions_url = "http://example.com/instructions" + error = MissingProviderEnvVariableError(env_variable, provider, instructions_url) + + assert error.env_variable == env_variable + assert error.provider == provider + assert error.instructions_url == instructions_url + assert error.message == ( + "Missing environment variable: API_KEY for provider TestProvider.\n" + " Please see http://example.com/instructions for instructions" + )