Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,7 @@ celerybeat.pid
.vscode

# Autogenerated docs files
docs/docs/reference
docs/docs/reference

# uv lock file
uv.lock
13 changes: 13 additions & 0 deletions packages/exchange/src/exchange/invalid_choice_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import List


class InvalidChoiceError(Exception):
def __init__(self, attribute_name: str, attribute_value: str, available_values: List[str]) -> None:
self.attribute_name = attribute_name
self.attribute_value = attribute_value
self.available_values = available_values
self.message = (
f"Unknown {attribute_name}: {attribute_value}."
+ f" Available {attribute_name}s: {', '.join(available_values)}"
)
super().__init__(self.message)
6 changes: 5 additions & 1 deletion packages/exchange/src/exchange/moderators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import cache
from typing import Type

from exchange.invalid_choice_error import InvalidChoiceError
from exchange.moderators.base import Moderator
from exchange.utils import load_plugins
from exchange.moderators.passive import PassiveModerator # noqa
Expand All @@ -10,4 +11,7 @@

@cache
def get_moderator(name: str) -> Type[Moderator]:
return load_plugins(group="exchange.moderator")[name]
moderators = load_plugins(group="exchange.moderator")
if name not in moderators:
raise InvalidChoiceError("moderator", name, moderators.keys())
return moderators[name]
6 changes: 5 additions & 1 deletion packages/exchange/src/exchange/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import cache
from typing import Type

from exchange.invalid_choice_error import InvalidChoiceError
from exchange.providers.anthropic import AnthropicProvider # noqa
from exchange.providers.base import Provider, Usage # noqa
from exchange.providers.databricks import DatabricksProvider # noqa
Expand All @@ -14,4 +15,7 @@

@cache
def get_provider(name: str) -> Type[Provider]:
return load_plugins(group="exchange.provider")[name]
providers = load_plugins(group="exchange.provider")
if name not in providers:
raise InvalidChoiceError("provider", name, providers.keys())
return providers[name]
8 changes: 2 additions & 6 deletions packages/exchange/src/exchange/providers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from exchange.content import Text, ToolResult, ToolUse
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
from exchange.providers.utils import raise_for_status
from exchange.providers.utils import get_provider_env_value, retry_if_status, raise_for_status

ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages"

Expand All @@ -27,10 +26,7 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider":
url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST)
try:
key = os.environ["ANTHROPIC_API_KEY"]
except KeyError:
raise RuntimeError("Failed to get ANTHROPIC_API_KEY from the environment")
key = get_provider_env_value("ANTHROPIC_API_KEY", "anthropic")
client = httpx.Client(
base_url=url,
headers={
Expand Down
28 changes: 8 additions & 20 deletions packages/exchange/src/exchange/providers/azure.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
from typing import Type

import httpx

from exchange.providers import OpenAiProvider
from exchange.providers.utils import get_provider_env_value

PROVIDER_NAME = "azure"


class AzureProvider(OpenAiProvider):
Expand All @@ -14,25 +16,11 @@ def __init__(self, client: httpx.Client) -> None:

@classmethod
def from_env(cls: Type["AzureProvider"]) -> "AzureProvider":
try:
url = os.environ["AZURE_CHAT_COMPLETIONS_HOST_NAME"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_HOST_NAME from the environment.")

try:
deployment_name = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME from the environment.")

try:
api_version = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION from the environment.")

try:
key = os.environ["AZURE_CHAT_COMPLETIONS_KEY"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_KEY from the environment.")
url = get_provider_env_value("AZURE_CHAT_COMPLETIONS_HOST_NAME", PROVIDER_NAME)
deployment_name = get_provider_env_value("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME", PROVIDER_NAME)

api_version = get_provider_env_value("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION", PROVIDER_NAME)
key = get_provider_env_value("AZURE_CHAT_COMPLETIONS_KEY", PROVIDER_NAME)

# format the url host/"openai/deployments/" + deployment_name + "/?api-version=" + api_version
url = f"{url}/openai/deployments/{deployment_name}/"
Expand Down
13 changes: 12 additions & 1 deletion packages/exchange/src/exchange/providers/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from attrs import define, field
from typing import List, Tuple, Type
from typing import List, Optional, Tuple, Type

from exchange.message import Message
from exchange.tool import Tool
Expand Down Expand Up @@ -28,3 +28,14 @@ def complete(
) -> Tuple[Message, Usage]:
"""Generate the next message using the specified model"""
pass


class MissingProviderEnvVariableError(Exception):
def __init__(self, env_variable: str, provider: str, instructions_url: Optional[str] = None) -> None:
self.env_variable = env_variable
self.provider = provider
self.instructions_url = instructions_url
self.message = f"Missing environment variable: {env_variable} for provider {provider}."
if instructions_url:
self.message += f"\nPlease see {instructions_url} for instructions"
super().__init__(self.message)
15 changes: 7 additions & 8 deletions packages/exchange/src/exchange/providers/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from exchange.message import Message
from exchange.providers import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
from exchange.providers.utils import raise_for_status
from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status
from exchange.tool import Tool

SERVICE = "bedrock-runtime"
Expand Down Expand Up @@ -147,19 +146,19 @@ def get_signature_key(key: str, date_stamp: str, region_name: str, service_name:
return headers


PROVIDER_NAME = "bedrock"


class BedrockProvider(Provider):
def __init__(self, client: AwsClient) -> None:
self.client = client

@classmethod
def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider":
aws_region = os.environ.get("AWS_REGION", "us-east-1")
try:
aws_access_key = os.environ["AWS_ACCESS_KEY_ID"]
aws_secret_key = os.environ["AWS_SECRET_ACCESS_KEY"]
aws_session_token = os.environ.get("AWS_SESSION_TOKEN")
except KeyError:
raise RuntimeError("Failed to get AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY from the environment")
aws_access_key = get_provider_env_value("AWS_ACCESS_KEY_ID", PROVIDER_NAME)
aws_secret_key = get_provider_env_value("AWS_SECRET_ACCESS_KEY", PROVIDER_NAME)
aws_session_token = get_provider_env_value("AWS_SESSION_TOKEN", PROVIDER_NAME)

client = AwsClient(
aws_region=aws_region,
Expand Down
22 changes: 8 additions & 14 deletions packages/exchange/src/exchange/providers/databricks.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
from typing import Any, Dict, List, Tuple, Type

import httpx

from exchange.message import Message
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import raise_for_status, retry_if_status
from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status
from exchange.providers.utils import (
messages_to_openai_spec,
openai_response_to_message,
Expand Down Expand Up @@ -37,18 +36,8 @@ def __init__(self, client: httpx.Client) -> None:

@classmethod
def from_env(cls: Type["DatabricksProvider"]) -> "DatabricksProvider":
try:
url = os.environ["DATABRICKS_HOST"]
except KeyError:
raise RuntimeError(
"Failed to get DATABRICKS_HOST from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"
)
try:
key = os.environ["DATABRICKS_TOKEN"]
except KeyError:
raise RuntimeError(
"Failed to get DATABRICKS_TOKEN from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"
)
url = cls._get_env_variable("DATABRICKS_HOST")
key = cls._get_env_variable("DATABRICKS_TOKEN")
client = httpx.Client(
base_url=url,
auth=("token", key),
Expand Down Expand Up @@ -100,3 +89,8 @@ def _post(self, model: str, payload: dict) -> httpx.Response:
json=payload,
)
return raise_for_status(response).json()

@classmethod
def _get_env_variable(cls: Type["DatabricksProvider"], key: str) -> str:
instruction = "https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"
return get_provider_env_value(key, "databricks", instruction)
11 changes: 3 additions & 8 deletions packages/exchange/src/exchange/providers/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from exchange.content import Text, ToolResult, ToolUse
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
from exchange.providers.utils import raise_for_status
from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status

GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta"

Expand All @@ -27,12 +26,8 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider":
url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST)
try:
key = os.environ["GOOGLE_API_KEY"]
except KeyError:
raise RuntimeError(
"Failed to get GOOGLE_API_KEY from the environment, see https://ai.google.dev/gemini-api/docs/api-key"
)
api_key_instructions_url = "https://ai.google.dev/gemini-api/docs/api-key"
key = get_provider_env_value("GOOGLE_API_KEY", "google", api_key_instructions_url)

client = httpx.Client(
base_url=url,
Expand Down
9 changes: 3 additions & 6 deletions packages/exchange/src/exchange/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from exchange.message import Message
from exchange.providers.base import Provider, Usage
from exchange.providers.utils import (
get_provider_env_value,
messages_to_openai_spec,
openai_response_to_message,
openai_single_message_context_length_exceeded,
Expand Down Expand Up @@ -36,12 +37,8 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider":
url = os.environ.get("OPENAI_HOST", OPENAI_HOST)
try:
key = os.environ["OPENAI_API_KEY"]
except KeyError:
raise RuntimeError(
"Failed to get OPENAI_API_KEY from the environment, see https://platform.openai.com/docs/api-reference/api-keys"
)
api_key_instructions_url = "https://platform.openai.com/docs/api-reference/api-keys"
key = get_provider_env_value("OPENAI_API_KEY", "openai", api_key_instructions_url)
client = httpx.Client(
base_url=url + "v1/",
auth=("Bearer", key),
Expand Down
9 changes: 9 additions & 0 deletions packages/exchange/src/exchange/providers/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import base64
import json
import os
import re
from typing import Any, Callable, Dict, List, Optional, Tuple

import httpx
from exchange.content import Text, ToolResult, ToolUse
from exchange.message import Message
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.tool import Tool
from tenacity import retry_if_exception

Expand Down Expand Up @@ -179,6 +181,13 @@ def openai_single_message_context_length_exceeded(error_dict: dict) -> None:
raise InitialMessageTooLargeError(f"Input message too long. Message: {error_dict.get('message')}")


def get_provider_env_value(env_variable: str, provider: str, instructions_url: Optional[str] = None) -> str:
try:
return os.environ[env_variable]
except KeyError:
raise MissingProviderEnvVariableError(env_variable, provider, instructions_url)


class InitialMessageTooLargeError(Exception):
"""Custom error raised when the first input message in an exchange is too large."""

Expand Down
10 changes: 10 additions & 0 deletions packages/exchange/tests/providers/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from exchange import Message, Text
from exchange.content import ToolResult, ToolUse
from exchange.providers.anthropic import AnthropicProvider
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.tool import Tool


Expand All @@ -25,6 +26,15 @@ def anthropic_provider():
return AnthropicProvider.from_env()


def test_from_env_throw_error_when_missing_api_key():
with patch.dict(os.environ, {}, clear=True):
with pytest.raises(MissingProviderEnvVariableError) as context:
AnthropicProvider.from_env()
assert context.value.provider == "anthropic"
assert context.value.env_variable == "ANTHROPIC_API_KEY"
assert context.value.message == "Missing environment variable: ANTHROPIC_API_KEY for provider anthropic."


def test_anthropic_response_to_text_message() -> None:
response = {
"content": [{"type": "text", "text": "Hello from Claude!"}],
Expand Down
30 changes: 30 additions & 0 deletions packages/exchange/tests/providers/test_azure.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,44 @@
import os
from unittest.mock import patch

import pytest

from exchange import Text, ToolUse
from exchange.providers.azure import AzureProvider
from exchange.providers.base import MissingProviderEnvVariableError
from .conftest import complete, tools

AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini")


@pytest.mark.parametrize(
"env_var_name",
[
("AZURE_CHAT_COMPLETIONS_HOST_NAME"),
("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"),
("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"),
("AZURE_CHAT_COMPLETIONS_KEY"),
],
)
def test_from_env_throw_error_when_missing_env_var(env_var_name):
with patch.dict(
os.environ,
{
"AZURE_CHAT_COMPLETIONS_HOST_NAME": "test_host_name",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME": "test_deployment_name",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION": "test_api_version",
"AZURE_CHAT_COMPLETIONS_KEY": "test_api_key",
},
clear=True,
):
os.environ.pop(env_var_name)
with pytest.raises(MissingProviderEnvVariableError) as context:
AzureProvider.from_env()
assert context.value.provider == "azure"
assert context.value.env_variable == env_var_name
assert context.value.message == f"Missing environment variable: {env_var_name} for provider azure."


@pytest.mark.vcr()
def test_azure_complete(default_azure_env):
reply_message, reply_usage = complete(AzureProvider, AZURE_MODEL)
Expand Down
Loading