diff --git a/src/exchange/load_exchange_attribute_error.py b/src/exchange/load_exchange_attribute_error.py new file mode 100644 index 0000000..bf1c4b6 --- /dev/null +++ b/src/exchange/load_exchange_attribute_error.py @@ -0,0 +1,13 @@ +from typing import List + + +class LoadExchangeAttributeError(Exception): + def __init__(self, attribute_name: str, attribute_value: str, available_values: List[str]) -> None: + self.attribute_name = attribute_name + self.attribute_value = attribute_value + self.available_values = available_values + self.message = ( + f"Unknown {attribute_name}: {attribute_value}." + + f" Available {attribute_name}s: {', '.join(available_values)}" + ) + super().__init__(self.message) diff --git a/src/exchange/moderators/__init__.py b/src/exchange/moderators/__init__.py index 56b198a..8bcdae8 100644 --- a/src/exchange/moderators/__init__.py +++ b/src/exchange/moderators/__init__.py @@ -1,6 +1,7 @@ from functools import cache from typing import Type +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError from exchange.moderators.base import Moderator from exchange.utils import load_plugins from exchange.moderators.passive import PassiveModerator # noqa @@ -10,4 +11,7 @@ @cache def get_moderator(name: str) -> Type[Moderator]: - return load_plugins(group="exchange.moderator")[name] + moderators = load_plugins(group="exchange.moderator") + if name not in moderators: + raise LoadExchangeAttributeError("moderator", name, moderators.keys()) + return moderators[name] diff --git a/src/exchange/providers/__init__.py b/src/exchange/providers/__init__.py index ac7ed07..65e8374 100644 --- a/src/exchange/providers/__init__.py +++ b/src/exchange/providers/__init__.py @@ -1,6 +1,7 @@ from functools import cache from typing import Type +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError from exchange.providers.anthropic import AnthropicProvider # noqa from exchange.providers.base import Provider, Usage # noqa from exchange.providers.databricks import DatabricksProvider # noqa @@ -14,4 +15,7 @@ @cache def get_provider(name: str) -> Type[Provider]: - return load_plugins(group="exchange.provider")[name] + providers = load_plugins(group="exchange.provider") + if name not in providers: + raise LoadExchangeAttributeError("provider", name, providers.keys()) + return providers[name] diff --git a/src/exchange/providers/anthropic.py b/src/exchange/providers/anthropic.py index 154ec5f..05a980c 100644 --- a/src/exchange/providers/anthropic.py +++ b/src/exchange/providers/anthropic.py @@ -7,7 +7,7 @@ from exchange.content import Text, ToolResult, ToolUse from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import retry_if_status +from exchange.providers.utils import get_provider_env_value, retry_if_status from exchange.providers.utils import raise_for_status ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages" @@ -27,10 +27,7 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider": url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST) - try: - key = os.environ["ANTHROPIC_API_KEY"] - except KeyError: - raise RuntimeError("Failed to get ANTHROPIC_API_KEY from the environment") + key = get_provider_env_value("ANTHROPIC_API_KEY", "anthropic") client = httpx.Client( base_url=url, headers={ diff --git a/src/exchange/providers/azure.py b/src/exchange/providers/azure.py index 7bacb9d..212dcca 100644 --- a/src/exchange/providers/azure.py +++ b/src/exchange/providers/azure.py @@ -1,9 +1,9 @@ -import os from typing import Type import httpx from exchange.providers import OpenAiProvider +from exchange.providers.utils import get_provider_env_value class AzureProvider(OpenAiProvider): @@ -14,25 +14,11 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["AzureProvider"]) -> "AzureProvider": - try: - url = os.environ["AZURE_CHAT_COMPLETIONS_HOST_NAME"] - except KeyError: - raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_HOST_NAME from the environment.") - - try: - deployment_name = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"] - except KeyError: - raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME from the environment.") - - try: - api_version = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"] - except KeyError: - raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION from the environment.") - - try: - key = os.environ["AZURE_CHAT_COMPLETIONS_KEY"] - except KeyError: - raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_KEY from the environment.") + url = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_HOST_NAME") + deployment_name = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME") + + api_version = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION") + key = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_KEY") # format the url host/"openai/deployments/" + deployment_name + "/?api-version=" + api_version url = f"{url}/openai/deployments/{deployment_name}/" @@ -43,3 +29,7 @@ def from_env(cls: Type["AzureProvider"]) -> "AzureProvider": timeout=httpx.Timeout(60 * 10), ) return cls(client) + + @classmethod + def _get_env_variable(cls: Type["AzureProvider"], key: str) -> str: + return get_provider_env_value(key, "azure") diff --git a/src/exchange/providers/base.py b/src/exchange/providers/base.py index 7b7ff88..7ec8745 100644 --- a/src/exchange/providers/base.py +++ b/src/exchange/providers/base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from attrs import define, field -from typing import List, Tuple, Type +from typing import List, Optional, Tuple, Type from exchange.message import Message from exchange.tool import Tool @@ -28,3 +28,14 @@ def complete( ) -> Tuple[Message, Usage]: """Generate the next message using the specified model""" pass + + +class MissingProviderEnvVariableError(Exception): + def __init__(self, env_variable: str, provider: str, instructions_url: Optional[str] = None) -> None: + self.env_variable = env_variable + self.provider = provider + self.instructions_url = instructions_url + self.message = f"Missing environment variable: {env_variable} for provider {provider}." + if instructions_url: + self.message += f"\n Please see {instructions_url} for instructions" + super().__init__(self.message) diff --git a/src/exchange/providers/bedrock.py b/src/exchange/providers/bedrock.py index 2a5f53d..c4ca400 100644 --- a/src/exchange/providers/bedrock.py +++ b/src/exchange/providers/bedrock.py @@ -13,7 +13,7 @@ from exchange.message import Message from exchange.providers import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import retry_if_status +from exchange.providers.utils import get_provider_env_value, retry_if_status from exchange.providers.utils import raise_for_status from exchange.tool import Tool @@ -154,12 +154,9 @@ def __init__(self, client: AwsClient) -> None: @classmethod def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider": aws_region = os.environ.get("AWS_REGION", "us-east-1") - try: - aws_access_key = os.environ["AWS_ACCESS_KEY_ID"] - aws_secret_key = os.environ["AWS_SECRET_ACCESS_KEY"] - aws_session_token = os.environ.get("AWS_SESSION_TOKEN") - except KeyError: - raise RuntimeError("Failed to get AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY from the environment") + aws_access_key = cls._get_env_variable("AWS_ACCESS_KEY_ID") + aws_secret_key = cls._get_env_variable("AWS_SECRET_ACCESS_KEY") + aws_session_token = cls._get_env_variable("AWS_SESSION_TOKEN") client = AwsClient( aws_region=aws_region, @@ -326,3 +323,7 @@ def tools_to_bedrock_spec(tools: Tuple[Tool]) -> Optional[dict]: tools_added.add(tool.name) tool_config = {"tools": tool_config_list} return tool_config + + @classmethod + def _get_env_variable(cls: Type["BedrockProvider"], key: str) -> str: + return get_provider_env_value(key, "bedrock") diff --git a/src/exchange/providers/databricks.py b/src/exchange/providers/databricks.py index 84dc751..77d392e 100644 --- a/src/exchange/providers/databricks.py +++ b/src/exchange/providers/databricks.py @@ -1,4 +1,3 @@ -import os from typing import Any, Dict, List, Tuple, Type import httpx @@ -6,7 +5,7 @@ from exchange.message import Message from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import raise_for_status, retry_if_status +from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status from exchange.providers.utils import ( messages_to_openai_spec, openai_response_to_message, @@ -37,18 +36,8 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["DatabricksProvider"]) -> "DatabricksProvider": - try: - url = os.environ["DATABRICKS_HOST"] - except KeyError: - raise RuntimeError( - "Failed to get DATABRICKS_HOST from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields" - ) - try: - key = os.environ["DATABRICKS_TOKEN"] - except KeyError: - raise RuntimeError( - "Failed to get DATABRICKS_TOKEN from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields" - ) + url = cls._get_env_variable("DATABRICKS_HOST") + key = cls._get_env_variable("DATABRICKS_TOKEN") client = httpx.Client( base_url=url, auth=("token", key), @@ -100,3 +89,8 @@ def _post(self, model: str, payload: dict) -> httpx.Response: json=payload, ) return raise_for_status(response).json() + + @classmethod + def _get_env_variable(cls: Type["DatabricksProvider"], key: str) -> str: + instruction = "https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields" + return get_provider_env_value(key, "databricks", instruction) diff --git a/src/exchange/providers/google.py b/src/exchange/providers/google.py index 426aa79..349b803 100644 --- a/src/exchange/providers/google.py +++ b/src/exchange/providers/google.py @@ -7,7 +7,7 @@ from exchange.content import Text, ToolResult, ToolUse from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import retry_if_status +from exchange.providers.utils import get_provider_env_value, retry_if_status from exchange.providers.utils import raise_for_status GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta" @@ -27,13 +27,8 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider": url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST) - try: - key = os.environ["GOOGLE_API_KEY"] - except KeyError: - raise RuntimeError( - "Failed to get GOOGLE_API_KEY from the environment, see https://ai.google.dev/gemini-api/docs/api-key" - ) - + api_key_instructions_url = "see https://ai.google.dev/gemini-api/docs/api-key" + key = get_provider_env_value("GOOGLE_API_KEY", "google", api_key_instructions_url) client = httpx.Client( base_url=url, headers={ diff --git a/src/exchange/providers/openai.py b/src/exchange/providers/openai.py index dbd293b..d921020 100644 --- a/src/exchange/providers/openai.py +++ b/src/exchange/providers/openai.py @@ -6,6 +6,7 @@ from exchange.message import Message from exchange.providers.base import Provider, Usage from exchange.providers.utils import ( + get_provider_env_value, messages_to_openai_spec, openai_response_to_message, openai_single_message_context_length_exceeded, @@ -36,12 +37,8 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider": url = os.environ.get("OPENAI_HOST", OPENAI_HOST) - try: - key = os.environ["OPENAI_API_KEY"] - except KeyError: - raise RuntimeError( - "Failed to get OPENAI_API_KEY from the environment, see https://platform.openai.com/docs/api-reference/api-keys" - ) + api_key_instructions_url = "see https://platform.openai.com/docs/api-reference/api-keys" + key = get_provider_env_value("OPENAI_API_KEY", "openai", api_key_instructions_url) client = httpx.Client( base_url=url + "v1/", auth=("Bearer", key), diff --git a/src/exchange/providers/utils.py b/src/exchange/providers/utils.py index 4be7ac3..0150430 100644 --- a/src/exchange/providers/utils.py +++ b/src/exchange/providers/utils.py @@ -1,11 +1,13 @@ import base64 import json +import os import re from typing import Any, Callable, Dict, List, Optional, Tuple import httpx from exchange.content import Text, ToolResult, ToolUse from exchange.message import Message +from exchange.providers.base import MissingProviderEnvVariableError from exchange.tool import Tool from tenacity import retry_if_exception @@ -179,6 +181,13 @@ def openai_single_message_context_length_exceeded(error_dict: dict) -> None: raise InitialMessageTooLargeError(f"Input message too long. Message: {error_dict.get('message')}") +def get_provider_env_value(env_variable: str, provider: str, instructions_url: Optional[str] = None) -> str: + try: + return os.environ[env_variable] + except KeyError: + raise MissingProviderEnvVariableError(env_variable, provider, instructions_url) + + class InitialMessageTooLargeError(Exception): """Custom error raised when the first input message in an exchange is too large.""" diff --git a/tests/providers/test_anthropic.py b/tests/providers/test_anthropic.py index a6f5bc6..272ebcb 100644 --- a/tests/providers/test_anthropic.py +++ b/tests/providers/test_anthropic.py @@ -6,6 +6,7 @@ from exchange import Message, Text from exchange.content import ToolResult, ToolUse from exchange.providers.anthropic import AnthropicProvider +from exchange.providers.base import MissingProviderEnvVariableError from exchange.tool import Tool @@ -25,6 +26,15 @@ def anthropic_provider(): return AnthropicProvider.from_env() +def test_from_env_throw_error_when_missing_api_key(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(MissingProviderEnvVariableError) as context: + AnthropicProvider.from_env() + assert context.value.provider == "anthropic" + assert context.value.env_variable == "ANTHROPIC_API_KEY" + assert context.value.message == "Missing environment variable: ANTHROPIC_API_KEY for provider anthropic." + + def test_anthropic_response_to_text_message() -> None: response = { "content": [{"type": "text", "text": "Hello from Claude!"}], diff --git a/tests/providers/test_azure.py b/tests/providers/test_azure.py index adafabe..b46be30 100644 --- a/tests/providers/test_azure.py +++ b/tests/providers/test_azure.py @@ -1,14 +1,44 @@ import os +from unittest.mock import patch import pytest from exchange import Text, ToolUse from exchange.providers.azure import AzureProvider +from exchange.providers.base import MissingProviderEnvVariableError from .conftest import complete, tools AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini") +@pytest.mark.parametrize( + "env_var_name", + [ + ("AZURE_CHAT_COMPLETIONS_HOST_NAME"), + ("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"), + ("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"), + ("AZURE_CHAT_COMPLETIONS_KEY"), + ], +) +def test_from_env_throw_error_when_missing_env_var(env_var_name): + with patch.dict( + os.environ, + { + "AZURE_CHAT_COMPLETIONS_HOST_NAME": "test_host_name", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME": "test_deployment_name", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION": "test_api_version", + "AZURE_CHAT_COMPLETIONS_KEY": "test_api_key", + }, + clear=True, + ): + os.environ.pop(env_var_name) + with pytest.raises(MissingProviderEnvVariableError) as context: + AzureProvider.from_env() + assert context.value.provider == "azure" + assert context.value.env_variable == env_var_name + assert context.value.message == f"Missing environment variable: {env_var_name} for provider azure." + + @pytest.mark.vcr() def test_azure_complete(default_azure_env): reply_message, reply_usage = complete(AzureProvider, AZURE_MODEL) diff --git a/tests/providers/test_bedrock.py b/tests/providers/test_bedrock.py index 2525f65..f8fcaa4 100644 --- a/tests/providers/test_bedrock.py +++ b/tests/providers/test_bedrock.py @@ -5,12 +5,39 @@ import pytest from exchange.content import Text, ToolResult, ToolUse from exchange.message import Message +from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.bedrock import BedrockProvider from exchange.tool import Tool logger = logging.getLogger(__name__) +@pytest.mark.parametrize( + "env_var_name", + [ + ("AWS_ACCESS_KEY_ID"), + ("AWS_SECRET_ACCESS_KEY"), + ("AWS_SESSION_TOKEN"), + ], +) +def test_from_env_throw_error_when_missing_env_var(env_var_name): + with patch.dict( + os.environ, + { + "AWS_ACCESS_KEY_ID": "test_access_key_id", + "AWS_SECRET_ACCESS_KEY": "test_secret_access_key", + "AWS_SESSION_TOKEN": "test_session_token", + }, + clear=True, + ): + os.environ.pop(env_var_name) + with pytest.raises(MissingProviderEnvVariableError) as context: + BedrockProvider.from_env() + assert context.value.provider == "bedrock" + assert context.value.env_variable == env_var_name + assert context.value.message == f"Missing environment variable: {env_var_name} for provider bedrock." + + @pytest.fixture @patch.dict( os.environ, diff --git a/tests/providers/test_databricks.py b/tests/providers/test_databricks.py index 3c14211..4b6793a 100644 --- a/tests/providers/test_databricks.py +++ b/tests/providers/test_databricks.py @@ -3,9 +3,35 @@ import pytest from exchange import Message, Text +from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.databricks import DatabricksProvider +@pytest.mark.parametrize( + "env_var_name", + [ + ("DATABRICKS_HOST"), + ("DATABRICKS_TOKEN"), + ], +) +def test_from_env_throw_error_when_missing_env_var(env_var_name): + with patch.dict( + os.environ, + { + "DATABRICKS_HOST": "test_host", + "DATABRICKS_TOKEN": "test_token", + }, + clear=True, + ): + os.environ.pop(env_var_name) + with pytest.raises(MissingProviderEnvVariableError) as context: + DatabricksProvider.from_env() + assert context.value.provider == "databricks" + assert context.value.env_variable == env_var_name + assert f"Missing environment variable: {env_var_name} for provider databricks" in context.value.message + assert "https://docs.databricks.com" in context.value.message + + @pytest.fixture @patch.dict( os.environ, diff --git a/tests/providers/test_google.py b/tests/providers/test_google.py index 47ad46b..76ae4c8 100644 --- a/tests/providers/test_google.py +++ b/tests/providers/test_google.py @@ -5,6 +5,7 @@ import pytest from exchange import Message, Text from exchange.content import ToolResult, ToolUse +from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.google import GoogleProvider from exchange.tool import Tool @@ -19,6 +20,16 @@ def example_fn(param: str) -> None: pass +def test_from_env_throw_error_when_missing_api_key(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(MissingProviderEnvVariableError) as context: + GoogleProvider.from_env() + assert context.value.provider == "google" + assert context.value.env_variable == "GOOGLE_API_KEY" + assert "Missing environment variable: GOOGLE_API_KEY for provider google" in context.value.message + assert "https://ai.google.dev/gemini-api/docs/api-key" in context.value.message + + @pytest.fixture @patch.dict(os.environ, {"GOOGLE_API_KEY": "test_api_key"}) def google_provider(): diff --git a/tests/providers/test_openai.py b/tests/providers/test_openai.py index 45bc620..ea979ab 100644 --- a/tests/providers/test_openai.py +++ b/tests/providers/test_openai.py @@ -1,14 +1,26 @@ import os +from unittest.mock import patch import pytest from exchange import Text, ToolUse +from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.openai import OpenAiProvider from .conftest import complete, vision, tools OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini") +def test_from_env_throw_error_when_missing_api_key(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(MissingProviderEnvVariableError) as context: + OpenAiProvider.from_env() + assert context.value.provider == "openai" + assert context.value.env_variable == "OPENAI_API_KEY" + assert "Missing environment variable: OPENAI_API_KEY for provider openai" in context.value.message + assert "https://platform.openai.com" in context.value.message + + @pytest.mark.vcr() def test_openai_complete(default_openai_env): reply_message, reply_usage = complete(OpenAiProvider, OPENAI_MODEL) diff --git a/tests/providers/test_provider.py b/tests/providers/test_provider.py new file mode 100644 index 0000000..0038d2e --- /dev/null +++ b/tests/providers/test_provider.py @@ -0,0 +1,18 @@ +import pytest +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError +from exchange.providers import get_provider + + +def test_get_provider_valid(): + provider_name = "openai" + provider = get_provider(provider_name) + assert provider.__name__ == "OpenAiProvider" + + +def test_get_provider_throw_error_for_unknown_provider(): + with pytest.raises(LoadExchangeAttributeError) as error: + get_provider("nonexistent") + assert error.value.attribute_name == "provider" + assert error.value.attribute_value == "nonexistent" + assert "openai" in error.value.available_values + assert "openai" in error.value.message diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 0000000..9b3d6d4 --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,27 @@ +from exchange.providers.base import MissingProviderEnvVariableError + + +def test_missing_provider_env_variable_error_without_instructions_url(): + env_variable = "API_KEY" + provider = "TestProvider" + error = MissingProviderEnvVariableError(env_variable, provider) + + assert error.env_variable == env_variable + assert error.provider == provider + assert error.instructions_url is None + assert error.message == "Missing environment variable: API_KEY for provider TestProvider." + + +def test_missing_provider_env_variable_error_with_instructions_url(): + env_variable = "API_KEY" + provider = "TestProvider" + instructions_url = "http://example.com/instructions" + error = MissingProviderEnvVariableError(env_variable, provider, instructions_url) + + assert error.env_variable == env_variable + assert error.provider == provider + assert error.instructions_url == instructions_url + assert error.message == ( + "Missing environment variable: API_KEY for provider TestProvider.\n" + " Please see http://example.com/instructions for instructions" + ) diff --git a/tests/test_load_exchange_attribute_error.py b/tests/test_load_exchange_attribute_error.py new file mode 100644 index 0000000..f6f17f1 --- /dev/null +++ b/tests/test_load_exchange_attribute_error.py @@ -0,0 +1,13 @@ +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError + + +def test_load_exchange_attribute_error(): + attribute_name = "moderator" + attribute_value = "not_exist" + available_values = ["truncate", "summarizer"] + error = LoadExchangeAttributeError(attribute_name, attribute_value, available_values) + + assert error.attribute_name == attribute_name + assert error.attribute_value == attribute_value + assert error.attribute_value == attribute_value + assert error.message == "Unknown moderator: not_exist. Available moderators: truncate, summarizer" diff --git a/tests/test_moderators.py b/tests/test_moderators.py new file mode 100644 index 0000000..8d8478f --- /dev/null +++ b/tests/test_moderators.py @@ -0,0 +1,17 @@ +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError +from exchange.moderators import get_moderator +import pytest + + +def test_get_moderator(): + moderator = get_moderator("truncate") + assert moderator.__name__ == "ContextTruncate" + + +def test_get_moderator_raise_error_for_unknown_moderator(): + with pytest.raises(LoadExchangeAttributeError) as error: + get_moderator("nonexistent") + assert error.value.attribute_name == "moderator" + assert error.value.attribute_value == "nonexistent" + assert "truncate" in error.value.available_values + assert "truncate" in error.value.message