From c321c8cf8a8769b3a21305be3fca8c96141b0d2e Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 1 Oct 2024 13:19:38 +1000 Subject: [PATCH 01/12] exit the goose and show the error message when provider environment is not set --- src/goose/cli/config.py | 5 ----- src/goose/cli/session.py | 17 ++++++++++++++--- tests/cli/test_session.py | 15 +++++++++++++++ 3 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/goose/cli/config.py b/src/goose/cli/config.py index 2005c46895b7..bef458a69105 100644 --- a/src/goose/cli/config.py +++ b/src/goose/cli/config.py @@ -87,11 +87,6 @@ def default_model_configuration() -> Tuple[str, str, str]: break except Exception: pass - else: - raise ValueError( - "Could not detect an available provider," - + " make sure to plugin a provider or set an env var such as OPENAI_API_KEY" - ) recommended = { "ollama": (OLLAMA_MODEL, OLLAMA_MODEL), diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index bfa869a3cfd6..e9ab7294fc7a 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -1,8 +1,10 @@ +import sys import traceback from pathlib import Path from typing import Any, Dict, List, Optional -from exchange import Message, ToolResult, ToolUse, Text +from exchange import Message, ToolResult, ToolUse, Text, Exchange +from exchange.providers.base import MissingProviderEnvVariableError from rich import print from rich.console import RenderableType from rich.live import Live @@ -89,8 +91,7 @@ def __init__( self.profile = profile self.status_indicator = Status("", spinner="dots") self.notifier = SessionNotifier(self.status_indicator) - - self.exchange = build_exchange(profile=load_profile(profile), notifier=self.notifier) + self.exchange = self._create_exchange() setup_logging(log_file_directory=LOG_PATH, log_level=log_level) self.exchange.messages.extend(self._get_initial_messages()) @@ -100,6 +101,16 @@ def __init__( self.prompt_session = GoosePromptSession() + def _create_exchange(self) -> Exchange: + try: + return build_exchange(profile=load_profile(self.profile), notifier=self.notifier) + except MissingProviderEnvVariableError as e: + error_message = ( + f"Missing environment variable: {e.message}. Please set the required environment variable to continue." + ) + print(Panel(error_message, style="red")) + sys.exit(1) + def _get_initial_messages(self) -> List[Message]: messages = self.load_session() diff --git a/tests/cli/test_session.py b/tests/cli/test_session.py index 6d9086bcd8b3..ae71b62e4637 100644 --- a/tests/cli/test_session.py +++ b/tests/cli/test_session.py @@ -2,6 +2,7 @@ import pytest from exchange import Exchange, Message, ToolUse, ToolResult +from exchange.providers.base import MissingProviderEnvVariableError from goose.cli.prompt.goose_prompt_session import GoosePromptSession from goose.cli.prompt.user_input import PromptAction, UserInput from goose.cli.session import Session @@ -151,3 +152,17 @@ def test_set_generated_session_name(create_session_with_mock_configs, mock_sessi with patch("goose.cli.session.droid", return_value=generated_session_name): session = create_session_with_mock_configs({"name": None}) assert session.name == generated_session_name + + +def test_create_exchange_exit_when_env_var_does_not_exist(create_session_with_mock_configs, mock_sessions_path): + session = create_session_with_mock_configs() + expected_error = MissingProviderEnvVariableError(env_variable="OPENAI_API_KEY", provider="openai") + with patch("goose.cli.session.build_exchange", side_effect=expected_error), patch( + "goose.cli.session.print" + ) as mock_print, patch("sys.exit") as mock_exit: + session._create_exchange() + mock_print.call_args_list[0][0][0].renderable == ( + "Missing environment variable OPENAI_API_KEY for provider openai. ", + "Please set the required environment variable to continue.", + ) + mock_exit.assert_called_once_with(1) From 6cb7cfcba7c98ec86b38d1064f0369f758281e0e Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 1 Oct 2024 13:42:02 +1000 Subject: [PATCH 02/12] fixed the message content --- src/goose/cli/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index e9ab7294fc7a..c58aa5f8edac 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -106,7 +106,7 @@ def _create_exchange(self) -> Exchange: return build_exchange(profile=load_profile(self.profile), notifier=self.notifier) except MissingProviderEnvVariableError as e: error_message = ( - f"Missing environment variable: {e.message}. Please set the required environment variable to continue." + f"{e.message}. Please set the required environment variable to continue." ) print(Panel(error_message, style="red")) sys.exit(1) From dfb4522a87f6751003c55d67660e517e2e07a961 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 2 Oct 2024 11:47:52 +1000 Subject: [PATCH 03/12] show error message when configuration is incorrect --- src/goose/cli/session.py | 12 +++++++++--- src/goose/toolkit/__init__.py | 7 +++++-- tests/cli/test_session.py | 14 ++++++++++++++ 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index c58aa5f8edac..bdec42d10143 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -5,6 +5,7 @@ from exchange import Message, ToolResult, ToolUse, Text, Exchange from exchange.providers.base import MissingProviderEnvVariableError +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError from rich import print from rich.console import RenderableType from rich.live import Live @@ -13,7 +14,7 @@ from rich.status import Status from goose.build import build_exchange -from goose.cli.config import ensure_config, session_path, LOG_PATH +from goose.cli.config import PROFILES_CONFIG_PATH, ensure_config, session_path, LOG_PATH from goose._logger import get_logger, setup_logging from goose.cli.prompt.goose_prompt_session import GoosePromptSession from goose.notifier import Notifier @@ -105,10 +106,15 @@ def _create_exchange(self) -> Exchange: try: return build_exchange(profile=load_profile(self.profile), notifier=self.notifier) except MissingProviderEnvVariableError as e: + error_message = f"{e.message}. Please set the required environment variable to continue." + print(Panel(error_message, style="red")) + sys.exit(1) + except LoadExchangeAttributeError as e: error_message = ( - f"{e.message}. Please set the required environment variable to continue." + f"[bold red]{e.message}[/bold red].\nPlease check your configuration file at {PROFILES_CONFIG_PATH}. " + + "Configuration doc: https://block-open-source.github.io/goose/configuration.html" ) - print(Panel(error_message, style="red")) + print(error_message) sys.exit(1) def _get_initial_messages(self) -> List[Message]: diff --git a/src/goose/toolkit/__init__.py b/src/goose/toolkit/__init__.py index a3a97d41f4dd..6820cc24cd54 100644 --- a/src/goose/toolkit/__init__.py +++ b/src/goose/toolkit/__init__.py @@ -1,9 +1,12 @@ from functools import cache - +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError from goose.toolkit.base import Toolkit from goose.utils import load_plugins @cache def get_toolkit(name: str) -> type[Toolkit]: - return load_plugins(group="goose.toolkit")[name] + toolkits = load_plugins(group="goose.toolkit") + if name not in toolkits: + raise LoadExchangeAttributeError("toolkit", name, toolkits.keys()) + return toolkits[name] diff --git a/tests/cli/test_session.py b/tests/cli/test_session.py index ae71b62e4637..41d9a33d872e 100644 --- a/tests/cli/test_session.py +++ b/tests/cli/test_session.py @@ -3,6 +3,7 @@ import pytest from exchange import Exchange, Message, ToolUse, ToolResult from exchange.providers.base import MissingProviderEnvVariableError +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError from goose.cli.prompt.goose_prompt_session import GoosePromptSession from goose.cli.prompt.user_input import PromptAction, UserInput from goose.cli.session import Session @@ -166,3 +167,16 @@ def test_create_exchange_exit_when_env_var_does_not_exist(create_session_with_mo "Please set the required environment variable to continue.", ) mock_exit.assert_called_once_with(1) + + +def test_create_exchange_exit_when_configuration_is_incorrect(create_session_with_mock_configs, mock_sessions_path): + session = create_session_with_mock_configs() + expected_error = LoadExchangeAttributeError( + attribute_name="provider", attribute_value="wrong_provider", available_values=["openai"] + ) + with patch("goose.cli.session.build_exchange", side_effect=expected_error), patch( + "goose.cli.session.print" + ) as mock_print, patch("sys.exit") as mock_exit: + session._create_exchange() + assert "Unknown provider: wrong_provider. Available providers: openai" in mock_print.call_args_list[0][0][0] + mock_exit.assert_called_once_with(1) From b52f3a373cbb6f3e92463e8292b316f2ccddce5e Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Thu, 3 Oct 2024 10:36:44 +1000 Subject: [PATCH 04/12] created custom errors when env var for providers does not exist or provider/moderator is unknown --- .../exchange/load_exchange_attribute_error.py | 13 +++++++++ .../src/exchange/moderators/__init__.py | 6 +++- .../src/exchange/providers/__init__.py | 6 +++- .../src/exchange/providers/anthropic.py | 8 ++--- .../exchange/src/exchange/providers/azure.py | 29 +++++++------------ .../exchange/src/exchange/providers/base.py | 12 +++++++- .../src/exchange/providers/bedrock.py | 16 +++++----- .../src/exchange/providers/databricks.py | 21 +++++--------- .../exchange/src/exchange/providers/google.py | 11 ++----- .../exchange/src/exchange/providers/openai.py | 9 ++---- .../exchange/src/exchange/providers/utils.py | 8 +++++ .../tests/providers/test_anthropic.py | 8 +++++ .../exchange/tests/providers/test_azure.py | 28 ++++++++++++++++++ .../exchange/tests/providers/test_bedrock.py | 25 ++++++++++++++++ .../tests/providers/test_databricks.py | 24 +++++++++++++++ .../exchange/tests/providers/test_google.py | 9 ++++++ .../exchange/tests/providers/test_openai.py | 10 +++++++ .../exchange/tests/providers/test_provider.py | 18 ++++++++++++ packages/exchange/tests/test_base.py | 27 +++++++++++++++++ .../test_load_exchange_attribute_error.py | 13 +++++++++ packages/exchange/tests/test_moderators.py | 17 +++++++++++ 21 files changed, 255 insertions(+), 63 deletions(-) create mode 100644 packages/exchange/src/exchange/load_exchange_attribute_error.py create mode 100644 packages/exchange/tests/providers/test_provider.py create mode 100644 packages/exchange/tests/test_base.py create mode 100644 packages/exchange/tests/test_load_exchange_attribute_error.py create mode 100644 packages/exchange/tests/test_moderators.py diff --git a/packages/exchange/src/exchange/load_exchange_attribute_error.py b/packages/exchange/src/exchange/load_exchange_attribute_error.py new file mode 100644 index 000000000000..bf1c4b607caa --- /dev/null +++ b/packages/exchange/src/exchange/load_exchange_attribute_error.py @@ -0,0 +1,13 @@ +from typing import List + + +class LoadExchangeAttributeError(Exception): + def __init__(self, attribute_name: str, attribute_value: str, available_values: List[str]) -> None: + self.attribute_name = attribute_name + self.attribute_value = attribute_value + self.available_values = available_values + self.message = ( + f"Unknown {attribute_name}: {attribute_value}." + + f" Available {attribute_name}s: {', '.join(available_values)}" + ) + super().__init__(self.message) diff --git a/packages/exchange/src/exchange/moderators/__init__.py b/packages/exchange/src/exchange/moderators/__init__.py index 56b198a75a1a..8bcdae820136 100644 --- a/packages/exchange/src/exchange/moderators/__init__.py +++ b/packages/exchange/src/exchange/moderators/__init__.py @@ -1,6 +1,7 @@ from functools import cache from typing import Type +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError from exchange.moderators.base import Moderator from exchange.utils import load_plugins from exchange.moderators.passive import PassiveModerator # noqa @@ -10,4 +11,7 @@ @cache def get_moderator(name: str) -> Type[Moderator]: - return load_plugins(group="exchange.moderator")[name] + moderators = load_plugins(group="exchange.moderator") + if name not in moderators: + raise LoadExchangeAttributeError("moderator", name, moderators.keys()) + return moderators[name] diff --git a/packages/exchange/src/exchange/providers/__init__.py b/packages/exchange/src/exchange/providers/__init__.py index ac7ed07a047d..65e83746871a 100644 --- a/packages/exchange/src/exchange/providers/__init__.py +++ b/packages/exchange/src/exchange/providers/__init__.py @@ -1,6 +1,7 @@ from functools import cache from typing import Type +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError from exchange.providers.anthropic import AnthropicProvider # noqa from exchange.providers.base import Provider, Usage # noqa from exchange.providers.databricks import DatabricksProvider # noqa @@ -14,4 +15,7 @@ @cache def get_provider(name: str) -> Type[Provider]: - return load_plugins(group="exchange.provider")[name] + providers = load_plugins(group="exchange.provider") + if name not in providers: + raise LoadExchangeAttributeError("provider", name, providers.keys()) + return providers[name] diff --git a/packages/exchange/src/exchange/providers/anthropic.py b/packages/exchange/src/exchange/providers/anthropic.py index 154ec5f79f50..bf052b20b814 100644 --- a/packages/exchange/src/exchange/providers/anthropic.py +++ b/packages/exchange/src/exchange/providers/anthropic.py @@ -7,8 +7,7 @@ from exchange.content import Text, ToolResult, ToolUse from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import retry_if_status -from exchange.providers.utils import raise_for_status +from exchange.providers.utils import get_provider_env_value, retry_if_status, raise_for_status ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages" @@ -27,10 +26,7 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider": url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST) - try: - key = os.environ["ANTHROPIC_API_KEY"] - except KeyError: - raise RuntimeError("Failed to get ANTHROPIC_API_KEY from the environment") + key = get_provider_env_value("ANTHROPIC_API_KEY", "anthropic") client = httpx.Client( base_url=url, headers={ diff --git a/packages/exchange/src/exchange/providers/azure.py b/packages/exchange/src/exchange/providers/azure.py index 7bacb9ddc75d..dc84429e1a63 100644 --- a/packages/exchange/src/exchange/providers/azure.py +++ b/packages/exchange/src/exchange/providers/azure.py @@ -4,6 +4,7 @@ import httpx from exchange.providers import OpenAiProvider +from exchange.providers.utils import get_provider_env_value class AzureProvider(OpenAiProvider): @@ -14,25 +15,11 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["AzureProvider"]) -> "AzureProvider": - try: - url = os.environ["AZURE_CHAT_COMPLETIONS_HOST_NAME"] - except KeyError: - raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_HOST_NAME from the environment.") - - try: - deployment_name = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"] - except KeyError: - raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME from the environment.") - - try: - api_version = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"] - except KeyError: - raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION from the environment.") - - try: - key = os.environ["AZURE_CHAT_COMPLETIONS_KEY"] - except KeyError: - raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_KEY from the environment.") + url = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_HOST_NAME") + deployment_name = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME") + + api_version = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION") + key = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_KEY") # format the url host/"openai/deployments/" + deployment_name + "/?api-version=" + api_version url = f"{url}/openai/deployments/{deployment_name}/" @@ -43,3 +30,7 @@ def from_env(cls: Type["AzureProvider"]) -> "AzureProvider": timeout=httpx.Timeout(60 * 10), ) return cls(client) + + @classmethod + def _get_env_variable(cls: Type["AzureProvider"], key: str) -> str: + return get_provider_env_value(key, "azure") diff --git a/packages/exchange/src/exchange/providers/base.py b/packages/exchange/src/exchange/providers/base.py index 7b7ff88bcfff..d577e364ffe7 100644 --- a/packages/exchange/src/exchange/providers/base.py +++ b/packages/exchange/src/exchange/providers/base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from attrs import define, field -from typing import List, Tuple, Type +from typing import List, Optional, Tuple, Type from exchange.message import Message from exchange.tool import Tool @@ -28,3 +28,13 @@ def complete( ) -> Tuple[Message, Usage]: """Generate the next message using the specified model""" pass + +class MissingProviderEnvVariableError(Exception): + def __init__(self, env_variable: str, provider: str, instructions_url: Optional[str] = None) -> None: + self.env_variable = env_variable + self.provider = provider + self.instructions_url = instructions_url + self.message = f"Missing environment variable: {env_variable} for provider {provider}." + if instructions_url: + self.message += f"\n Please see {instructions_url} for instructions" + super().__init__(self.message) diff --git a/packages/exchange/src/exchange/providers/bedrock.py b/packages/exchange/src/exchange/providers/bedrock.py index 2a5f53dc8621..4b02f2391aae 100644 --- a/packages/exchange/src/exchange/providers/bedrock.py +++ b/packages/exchange/src/exchange/providers/bedrock.py @@ -13,8 +13,7 @@ from exchange.message import Message from exchange.providers import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import retry_if_status -from exchange.providers.utils import raise_for_status +from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status from exchange.tool import Tool SERVICE = "bedrock-runtime" @@ -154,12 +153,9 @@ def __init__(self, client: AwsClient) -> None: @classmethod def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider": aws_region = os.environ.get("AWS_REGION", "us-east-1") - try: - aws_access_key = os.environ["AWS_ACCESS_KEY_ID"] - aws_secret_key = os.environ["AWS_SECRET_ACCESS_KEY"] - aws_session_token = os.environ.get("AWS_SESSION_TOKEN") - except KeyError: - raise RuntimeError("Failed to get AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY from the environment") + aws_access_key = cls._get_env_variable("AWS_ACCESS_KEY_ID") + aws_secret_key = cls._get_env_variable("AWS_SECRET_ACCESS_KEY") + aws_session_token = cls._get_env_variable("AWS_SESSION_TOKEN") client = AwsClient( aws_region=aws_region, @@ -326,3 +322,7 @@ def tools_to_bedrock_spec(tools: Tuple[Tool]) -> Optional[dict]: tools_added.add(tool.name) tool_config = {"tools": tool_config_list} return tool_config + + @classmethod + def _get_env_variable(cls: Type["BedrockProvider"], key: str) -> str: + return get_provider_env_value(key, "bedrock") diff --git a/packages/exchange/src/exchange/providers/databricks.py b/packages/exchange/src/exchange/providers/databricks.py index 84dc7515cae2..c036b90806b5 100644 --- a/packages/exchange/src/exchange/providers/databricks.py +++ b/packages/exchange/src/exchange/providers/databricks.py @@ -6,7 +6,7 @@ from exchange.message import Message from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import raise_for_status, retry_if_status +from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status from exchange.providers.utils import ( messages_to_openai_spec, openai_response_to_message, @@ -37,18 +37,8 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["DatabricksProvider"]) -> "DatabricksProvider": - try: - url = os.environ["DATABRICKS_HOST"] - except KeyError: - raise RuntimeError( - "Failed to get DATABRICKS_HOST from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields" - ) - try: - key = os.environ["DATABRICKS_TOKEN"] - except KeyError: - raise RuntimeError( - "Failed to get DATABRICKS_TOKEN from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields" - ) + url = cls._get_env_variable("DATABRICKS_HOST") + key = cls._get_env_variable("DATABRICKS_TOKEN") client = httpx.Client( base_url=url, auth=("token", key), @@ -100,3 +90,8 @@ def _post(self, model: str, payload: dict) -> httpx.Response: json=payload, ) return raise_for_status(response).json() + + @classmethod + def _get_env_variable(cls: Type["DatabricksProvider"], key: str) -> str: + instruction = "https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields" + return get_provider_env_value(key, "databricks", instruction) diff --git a/packages/exchange/src/exchange/providers/google.py b/packages/exchange/src/exchange/providers/google.py index 426aa79d5d80..0c4b43caf32c 100644 --- a/packages/exchange/src/exchange/providers/google.py +++ b/packages/exchange/src/exchange/providers/google.py @@ -7,8 +7,7 @@ from exchange.content import Text, ToolResult, ToolUse from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import retry_if_status -from exchange.providers.utils import raise_for_status +from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta" @@ -27,12 +26,8 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider": url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST) - try: - key = os.environ["GOOGLE_API_KEY"] - except KeyError: - raise RuntimeError( - "Failed to get GOOGLE_API_KEY from the environment, see https://ai.google.dev/gemini-api/docs/api-key" - ) + api_key_instructions_url = "see https://ai.google.dev/gemini-api/docs/api-key" + key = get_provider_env_value("GOOGLE_API_KEY", "google", api_key_instructions_url) client = httpx.Client( base_url=url, diff --git a/packages/exchange/src/exchange/providers/openai.py b/packages/exchange/src/exchange/providers/openai.py index dbd293b47b75..d9210203d2ea 100644 --- a/packages/exchange/src/exchange/providers/openai.py +++ b/packages/exchange/src/exchange/providers/openai.py @@ -6,6 +6,7 @@ from exchange.message import Message from exchange.providers.base import Provider, Usage from exchange.providers.utils import ( + get_provider_env_value, messages_to_openai_spec, openai_response_to_message, openai_single_message_context_length_exceeded, @@ -36,12 +37,8 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider": url = os.environ.get("OPENAI_HOST", OPENAI_HOST) - try: - key = os.environ["OPENAI_API_KEY"] - except KeyError: - raise RuntimeError( - "Failed to get OPENAI_API_KEY from the environment, see https://platform.openai.com/docs/api-reference/api-keys" - ) + api_key_instructions_url = "see https://platform.openai.com/docs/api-reference/api-keys" + key = get_provider_env_value("OPENAI_API_KEY", "openai", api_key_instructions_url) client = httpx.Client( base_url=url + "v1/", auth=("Bearer", key), diff --git a/packages/exchange/src/exchange/providers/utils.py b/packages/exchange/src/exchange/providers/utils.py index 4be7ac31e4a9..d02499cc41f3 100644 --- a/packages/exchange/src/exchange/providers/utils.py +++ b/packages/exchange/src/exchange/providers/utils.py @@ -1,11 +1,13 @@ import base64 import json +import os import re from typing import Any, Callable, Dict, List, Optional, Tuple import httpx from exchange.content import Text, ToolResult, ToolUse from exchange.message import Message +from exchange.providers.base import MissingProviderEnvVariableError from exchange.tool import Tool from tenacity import retry_if_exception @@ -179,6 +181,12 @@ def openai_single_message_context_length_exceeded(error_dict: dict) -> None: raise InitialMessageTooLargeError(f"Input message too long. Message: {error_dict.get('message')}") +def get_provider_env_value(env_variable: str, provider: str, instructions_url: Optional[str] = None) -> str: + try: + return os.environ[env_variable] + except KeyError: + raise MissingProviderEnvVariableError(env_variable, provider, instructions_url) + class InitialMessageTooLargeError(Exception): """Custom error raised when the first input message in an exchange is too large.""" diff --git a/packages/exchange/tests/providers/test_anthropic.py b/packages/exchange/tests/providers/test_anthropic.py index a6f5bc68973f..ef98110fce38 100644 --- a/packages/exchange/tests/providers/test_anthropic.py +++ b/packages/exchange/tests/providers/test_anthropic.py @@ -6,6 +6,7 @@ from exchange import Message, Text from exchange.content import ToolResult, ToolUse from exchange.providers.anthropic import AnthropicProvider +from exchange.providers.base import MissingProviderEnvVariableError from exchange.tool import Tool @@ -24,6 +25,13 @@ def example_fn(param: str) -> None: def anthropic_provider(): return AnthropicProvider.from_env() +def test_from_env_throw_error_when_missing_api_key(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(MissingProviderEnvVariableError) as context: + AnthropicProvider.from_env() + assert context.value.provider == "anthropic" + assert context.value.env_variable == "ANTHROPIC_API_KEY" + assert context.value.message == "Missing environment variable: ANTHROPIC_API_KEY for provider anthropic." def test_anthropic_response_to_text_message() -> None: response = { diff --git a/packages/exchange/tests/providers/test_azure.py b/packages/exchange/tests/providers/test_azure.py index adafabedb21e..9d77a3998dad 100644 --- a/packages/exchange/tests/providers/test_azure.py +++ b/packages/exchange/tests/providers/test_azure.py @@ -1,13 +1,41 @@ import os +from unittest.mock import patch import pytest from exchange import Text, ToolUse from exchange.providers.azure import AzureProvider +from exchange.providers.base import MissingProviderEnvVariableError from .conftest import complete, tools AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini") +@pytest.mark.parametrize( + "env_var_name", + [ + ("AZURE_CHAT_COMPLETIONS_HOST_NAME"), + ("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"), + ("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"), + ("AZURE_CHAT_COMPLETIONS_KEY"), + ], +) +def test_from_env_throw_error_when_missing_env_var(env_var_name): + with patch.dict( + os.environ, + { + "AZURE_CHAT_COMPLETIONS_HOST_NAME": "test_host_name", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME": "test_deployment_name", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION": "test_api_version", + "AZURE_CHAT_COMPLETIONS_KEY": "test_api_key", + }, + clear=True, + ): + os.environ.pop(env_var_name) + with pytest.raises(MissingProviderEnvVariableError) as context: + AzureProvider.from_env() + assert context.value.provider == "azure" + assert context.value.env_variable == env_var_name + assert context.value.message == f"Missing environment variable: {env_var_name} for provider azure." @pytest.mark.vcr() def test_azure_complete(default_azure_env): diff --git a/packages/exchange/tests/providers/test_bedrock.py b/packages/exchange/tests/providers/test_bedrock.py index 2525f650bfcd..d6ac5ebba6b7 100644 --- a/packages/exchange/tests/providers/test_bedrock.py +++ b/packages/exchange/tests/providers/test_bedrock.py @@ -5,11 +5,36 @@ import pytest from exchange.content import Text, ToolResult, ToolUse from exchange.message import Message +from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.bedrock import BedrockProvider from exchange.tool import Tool logger = logging.getLogger(__name__) +@pytest.mark.parametrize( + "env_var_name", + [ + ("AWS_ACCESS_KEY_ID"), + ("AWS_SECRET_ACCESS_KEY"), + ("AWS_SESSION_TOKEN"), + ], +) +def test_from_env_throw_error_when_missing_env_var(env_var_name): + with patch.dict( + os.environ, + { + "AWS_ACCESS_KEY_ID": "test_access_key_id", + "AWS_SECRET_ACCESS_KEY": "test_secret_access_key", + "AWS_SESSION_TOKEN": "test_session_token", + }, + clear=True, + ): + os.environ.pop(env_var_name) + with pytest.raises(MissingProviderEnvVariableError) as context: + BedrockProvider.from_env() + assert context.value.provider == "bedrock" + assert context.value.env_variable == env_var_name + assert context.value.message == f"Missing environment variable: {env_var_name} for provider bedrock." @pytest.fixture @patch.dict( diff --git a/packages/exchange/tests/providers/test_databricks.py b/packages/exchange/tests/providers/test_databricks.py index 3c1421146d43..e2e989f05f61 100644 --- a/packages/exchange/tests/providers/test_databricks.py +++ b/packages/exchange/tests/providers/test_databricks.py @@ -3,8 +3,32 @@ import pytest from exchange import Message, Text +from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.databricks import DatabricksProvider +@pytest.mark.parametrize( + "env_var_name", + [ + ("DATABRICKS_HOST"), + ("DATABRICKS_TOKEN"), + ], +) +def test_from_env_throw_error_when_missing_env_var(env_var_name): + with patch.dict( + os.environ, + { + "DATABRICKS_HOST": "test_host", + "DATABRICKS_TOKEN": "test_token", + }, + clear=True, + ): + os.environ.pop(env_var_name) + with pytest.raises(MissingProviderEnvVariableError) as context: + DatabricksProvider.from_env() + assert context.value.provider == "databricks" + assert context.value.env_variable == env_var_name + assert f"Missing environment variable: {env_var_name} for provider databricks" in context.value.message + assert "https://docs.databricks.com" in context.value.message @pytest.fixture @patch.dict( diff --git a/packages/exchange/tests/providers/test_google.py b/packages/exchange/tests/providers/test_google.py index 47ad46b43fe3..3eb59f1c442b 100644 --- a/packages/exchange/tests/providers/test_google.py +++ b/packages/exchange/tests/providers/test_google.py @@ -5,6 +5,7 @@ import pytest from exchange import Message, Text from exchange.content import ToolResult, ToolUse +from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.google import GoogleProvider from exchange.tool import Tool @@ -18,6 +19,14 @@ def example_fn(param: str) -> None: """ pass +def test_from_env_throw_error_when_missing_api_key(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(MissingProviderEnvVariableError) as context: + GoogleProvider.from_env() + assert context.value.provider == "google" + assert context.value.env_variable == "GOOGLE_API_KEY" + assert "Missing environment variable: GOOGLE_API_KEY for provider google" in context.value.message + assert "https://ai.google.dev/gemini-api/docs/api-key" in context.value.message @pytest.fixture @patch.dict(os.environ, {"GOOGLE_API_KEY": "test_api_key"}) diff --git a/packages/exchange/tests/providers/test_openai.py b/packages/exchange/tests/providers/test_openai.py index 45bc620500f6..ddf92ab74e90 100644 --- a/packages/exchange/tests/providers/test_openai.py +++ b/packages/exchange/tests/providers/test_openai.py @@ -1,13 +1,23 @@ import os +from unittest.mock import patch import pytest from exchange import Text, ToolUse +from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.openai import OpenAiProvider from .conftest import complete, vision, tools OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini") +def test_from_env_throw_error_when_missing_api_key(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(MissingProviderEnvVariableError) as context: + OpenAiProvider.from_env() + assert context.value.provider == "openai" + assert context.value.env_variable == "OPENAI_API_KEY" + assert "Missing environment variable: OPENAI_API_KEY for provider openai" in context.value.message + assert "https://platform.openai.com" in context.value.message @pytest.mark.vcr() def test_openai_complete(default_openai_env): diff --git a/packages/exchange/tests/providers/test_provider.py b/packages/exchange/tests/providers/test_provider.py new file mode 100644 index 000000000000..0038d2ea2bb1 --- /dev/null +++ b/packages/exchange/tests/providers/test_provider.py @@ -0,0 +1,18 @@ +import pytest +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError +from exchange.providers import get_provider + + +def test_get_provider_valid(): + provider_name = "openai" + provider = get_provider(provider_name) + assert provider.__name__ == "OpenAiProvider" + + +def test_get_provider_throw_error_for_unknown_provider(): + with pytest.raises(LoadExchangeAttributeError) as error: + get_provider("nonexistent") + assert error.value.attribute_name == "provider" + assert error.value.attribute_value == "nonexistent" + assert "openai" in error.value.available_values + assert "openai" in error.value.message diff --git a/packages/exchange/tests/test_base.py b/packages/exchange/tests/test_base.py new file mode 100644 index 000000000000..9b3d6d49b85b --- /dev/null +++ b/packages/exchange/tests/test_base.py @@ -0,0 +1,27 @@ +from exchange.providers.base import MissingProviderEnvVariableError + + +def test_missing_provider_env_variable_error_without_instructions_url(): + env_variable = "API_KEY" + provider = "TestProvider" + error = MissingProviderEnvVariableError(env_variable, provider) + + assert error.env_variable == env_variable + assert error.provider == provider + assert error.instructions_url is None + assert error.message == "Missing environment variable: API_KEY for provider TestProvider." + + +def test_missing_provider_env_variable_error_with_instructions_url(): + env_variable = "API_KEY" + provider = "TestProvider" + instructions_url = "http://example.com/instructions" + error = MissingProviderEnvVariableError(env_variable, provider, instructions_url) + + assert error.env_variable == env_variable + assert error.provider == provider + assert error.instructions_url == instructions_url + assert error.message == ( + "Missing environment variable: API_KEY for provider TestProvider.\n" + " Please see http://example.com/instructions for instructions" + ) diff --git a/packages/exchange/tests/test_load_exchange_attribute_error.py b/packages/exchange/tests/test_load_exchange_attribute_error.py new file mode 100644 index 000000000000..f6f17f14459f --- /dev/null +++ b/packages/exchange/tests/test_load_exchange_attribute_error.py @@ -0,0 +1,13 @@ +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError + + +def test_load_exchange_attribute_error(): + attribute_name = "moderator" + attribute_value = "not_exist" + available_values = ["truncate", "summarizer"] + error = LoadExchangeAttributeError(attribute_name, attribute_value, available_values) + + assert error.attribute_name == attribute_name + assert error.attribute_value == attribute_value + assert error.attribute_value == attribute_value + assert error.message == "Unknown moderator: not_exist. Available moderators: truncate, summarizer" diff --git a/packages/exchange/tests/test_moderators.py b/packages/exchange/tests/test_moderators.py new file mode 100644 index 000000000000..8d8478f84300 --- /dev/null +++ b/packages/exchange/tests/test_moderators.py @@ -0,0 +1,17 @@ +from exchange.load_exchange_attribute_error import LoadExchangeAttributeError +from exchange.moderators import get_moderator +import pytest + + +def test_get_moderator(): + moderator = get_moderator("truncate") + assert moderator.__name__ == "ContextTruncate" + + +def test_get_moderator_raise_error_for_unknown_moderator(): + with pytest.raises(LoadExchangeAttributeError) as error: + get_moderator("nonexistent") + assert error.value.attribute_name == "moderator" + assert error.value.attribute_value == "nonexistent" + assert "truncate" in error.value.available_values + assert "truncate" in error.value.message From 40fab9bedb76071b7dc29c72d2653209dbabf987 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Thu, 3 Oct 2024 10:37:58 +1000 Subject: [PATCH 05/12] fixed format issues --- packages/exchange/src/exchange/providers/azure.py | 1 - packages/exchange/src/exchange/providers/base.py | 1 + packages/exchange/src/exchange/providers/databricks.py | 1 - packages/exchange/src/exchange/providers/utils.py | 1 + packages/exchange/tests/providers/test_anthropic.py | 2 ++ packages/exchange/tests/providers/test_azure.py | 2 ++ packages/exchange/tests/providers/test_bedrock.py | 2 ++ packages/exchange/tests/providers/test_databricks.py | 2 ++ packages/exchange/tests/providers/test_google.py | 2 ++ packages/exchange/tests/providers/test_openai.py | 2 ++ 10 files changed, 14 insertions(+), 2 deletions(-) diff --git a/packages/exchange/src/exchange/providers/azure.py b/packages/exchange/src/exchange/providers/azure.py index dc84429e1a63..212dcca5e7c5 100644 --- a/packages/exchange/src/exchange/providers/azure.py +++ b/packages/exchange/src/exchange/providers/azure.py @@ -1,4 +1,3 @@ -import os from typing import Type import httpx diff --git a/packages/exchange/src/exchange/providers/base.py b/packages/exchange/src/exchange/providers/base.py index d577e364ffe7..7ec8745bc410 100644 --- a/packages/exchange/src/exchange/providers/base.py +++ b/packages/exchange/src/exchange/providers/base.py @@ -29,6 +29,7 @@ def complete( """Generate the next message using the specified model""" pass + class MissingProviderEnvVariableError(Exception): def __init__(self, env_variable: str, provider: str, instructions_url: Optional[str] = None) -> None: self.env_variable = env_variable diff --git a/packages/exchange/src/exchange/providers/databricks.py b/packages/exchange/src/exchange/providers/databricks.py index c036b90806b5..77d392e8d6a2 100644 --- a/packages/exchange/src/exchange/providers/databricks.py +++ b/packages/exchange/src/exchange/providers/databricks.py @@ -1,4 +1,3 @@ -import os from typing import Any, Dict, List, Tuple, Type import httpx diff --git a/packages/exchange/src/exchange/providers/utils.py b/packages/exchange/src/exchange/providers/utils.py index d02499cc41f3..01504305644e 100644 --- a/packages/exchange/src/exchange/providers/utils.py +++ b/packages/exchange/src/exchange/providers/utils.py @@ -187,6 +187,7 @@ def get_provider_env_value(env_variable: str, provider: str, instructions_url: O except KeyError: raise MissingProviderEnvVariableError(env_variable, provider, instructions_url) + class InitialMessageTooLargeError(Exception): """Custom error raised when the first input message in an exchange is too large.""" diff --git a/packages/exchange/tests/providers/test_anthropic.py b/packages/exchange/tests/providers/test_anthropic.py index ef98110fce38..272ebcb0f0d2 100644 --- a/packages/exchange/tests/providers/test_anthropic.py +++ b/packages/exchange/tests/providers/test_anthropic.py @@ -25,6 +25,7 @@ def example_fn(param: str) -> None: def anthropic_provider(): return AnthropicProvider.from_env() + def test_from_env_throw_error_when_missing_api_key(): with patch.dict(os.environ, {}, clear=True): with pytest.raises(MissingProviderEnvVariableError) as context: @@ -33,6 +34,7 @@ def test_from_env_throw_error_when_missing_api_key(): assert context.value.env_variable == "ANTHROPIC_API_KEY" assert context.value.message == "Missing environment variable: ANTHROPIC_API_KEY for provider anthropic." + def test_anthropic_response_to_text_message() -> None: response = { "content": [{"type": "text", "text": "Hello from Claude!"}], diff --git a/packages/exchange/tests/providers/test_azure.py b/packages/exchange/tests/providers/test_azure.py index 9d77a3998dad..b46be30b99d8 100644 --- a/packages/exchange/tests/providers/test_azure.py +++ b/packages/exchange/tests/providers/test_azure.py @@ -10,6 +10,7 @@ AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini") + @pytest.mark.parametrize( "env_var_name", [ @@ -37,6 +38,7 @@ def test_from_env_throw_error_when_missing_env_var(env_var_name): assert context.value.env_variable == env_var_name assert context.value.message == f"Missing environment variable: {env_var_name} for provider azure." + @pytest.mark.vcr() def test_azure_complete(default_azure_env): reply_message, reply_usage = complete(AzureProvider, AZURE_MODEL) diff --git a/packages/exchange/tests/providers/test_bedrock.py b/packages/exchange/tests/providers/test_bedrock.py index d6ac5ebba6b7..f8fcaa4b8753 100644 --- a/packages/exchange/tests/providers/test_bedrock.py +++ b/packages/exchange/tests/providers/test_bedrock.py @@ -11,6 +11,7 @@ logger = logging.getLogger(__name__) + @pytest.mark.parametrize( "env_var_name", [ @@ -36,6 +37,7 @@ def test_from_env_throw_error_when_missing_env_var(env_var_name): assert context.value.env_variable == env_var_name assert context.value.message == f"Missing environment variable: {env_var_name} for provider bedrock." + @pytest.fixture @patch.dict( os.environ, diff --git a/packages/exchange/tests/providers/test_databricks.py b/packages/exchange/tests/providers/test_databricks.py index e2e989f05f61..4b6793abc67a 100644 --- a/packages/exchange/tests/providers/test_databricks.py +++ b/packages/exchange/tests/providers/test_databricks.py @@ -6,6 +6,7 @@ from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.databricks import DatabricksProvider + @pytest.mark.parametrize( "env_var_name", [ @@ -30,6 +31,7 @@ def test_from_env_throw_error_when_missing_env_var(env_var_name): assert f"Missing environment variable: {env_var_name} for provider databricks" in context.value.message assert "https://docs.databricks.com" in context.value.message + @pytest.fixture @patch.dict( os.environ, diff --git a/packages/exchange/tests/providers/test_google.py b/packages/exchange/tests/providers/test_google.py index 3eb59f1c442b..76ae4c8d7216 100644 --- a/packages/exchange/tests/providers/test_google.py +++ b/packages/exchange/tests/providers/test_google.py @@ -19,6 +19,7 @@ def example_fn(param: str) -> None: """ pass + def test_from_env_throw_error_when_missing_api_key(): with patch.dict(os.environ, {}, clear=True): with pytest.raises(MissingProviderEnvVariableError) as context: @@ -28,6 +29,7 @@ def test_from_env_throw_error_when_missing_api_key(): assert "Missing environment variable: GOOGLE_API_KEY for provider google" in context.value.message assert "https://ai.google.dev/gemini-api/docs/api-key" in context.value.message + @pytest.fixture @patch.dict(os.environ, {"GOOGLE_API_KEY": "test_api_key"}) def google_provider(): diff --git a/packages/exchange/tests/providers/test_openai.py b/packages/exchange/tests/providers/test_openai.py index ddf92ab74e90..ea979abeb417 100644 --- a/packages/exchange/tests/providers/test_openai.py +++ b/packages/exchange/tests/providers/test_openai.py @@ -10,6 +10,7 @@ OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini") + def test_from_env_throw_error_when_missing_api_key(): with patch.dict(os.environ, {}, clear=True): with pytest.raises(MissingProviderEnvVariableError) as context: @@ -19,6 +20,7 @@ def test_from_env_throw_error_when_missing_api_key(): assert "Missing environment variable: OPENAI_API_KEY for provider openai" in context.value.message assert "https://platform.openai.com" in context.value.message + @pytest.mark.vcr() def test_openai_complete(default_openai_env): reply_message, reply_usage = complete(OpenAiProvider, OPENAI_MODEL) From 1b725f755f6ebdbdce0bbd3c3c1877590ef291e7 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Thu, 3 Oct 2024 11:31:47 +1000 Subject: [PATCH 06/12] fixed the errror message --- packages/exchange/src/exchange/providers/base.py | 2 +- packages/exchange/src/exchange/providers/google.py | 2 +- packages/exchange/src/exchange/providers/openai.py | 2 +- packages/exchange/tests/test_base.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/exchange/src/exchange/providers/base.py b/packages/exchange/src/exchange/providers/base.py index 7ec8745bc410..78c267e7bacd 100644 --- a/packages/exchange/src/exchange/providers/base.py +++ b/packages/exchange/src/exchange/providers/base.py @@ -37,5 +37,5 @@ def __init__(self, env_variable: str, provider: str, instructions_url: Optional[ self.instructions_url = instructions_url self.message = f"Missing environment variable: {env_variable} for provider {provider}." if instructions_url: - self.message += f"\n Please see {instructions_url} for instructions" + self.message += f"\nPlease see {instructions_url} for instructions" super().__init__(self.message) diff --git a/packages/exchange/src/exchange/providers/google.py b/packages/exchange/src/exchange/providers/google.py index 0c4b43caf32c..4fc020f348d9 100644 --- a/packages/exchange/src/exchange/providers/google.py +++ b/packages/exchange/src/exchange/providers/google.py @@ -26,7 +26,7 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider": url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST) - api_key_instructions_url = "see https://ai.google.dev/gemini-api/docs/api-key" + api_key_instructions_url = "https://ai.google.dev/gemini-api/docs/api-key" key = get_provider_env_value("GOOGLE_API_KEY", "google", api_key_instructions_url) client = httpx.Client( diff --git a/packages/exchange/src/exchange/providers/openai.py b/packages/exchange/src/exchange/providers/openai.py index d9210203d2ea..c30558b8597c 100644 --- a/packages/exchange/src/exchange/providers/openai.py +++ b/packages/exchange/src/exchange/providers/openai.py @@ -37,7 +37,7 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider": url = os.environ.get("OPENAI_HOST", OPENAI_HOST) - api_key_instructions_url = "see https://platform.openai.com/docs/api-reference/api-keys" + api_key_instructions_url = "https://platform.openai.com/docs/api-reference/api-keys" key = get_provider_env_value("OPENAI_API_KEY", "openai", api_key_instructions_url) client = httpx.Client( base_url=url + "v1/", diff --git a/packages/exchange/tests/test_base.py b/packages/exchange/tests/test_base.py index 9b3d6d49b85b..4aae8bde5b39 100644 --- a/packages/exchange/tests/test_base.py +++ b/packages/exchange/tests/test_base.py @@ -23,5 +23,5 @@ def test_missing_provider_env_variable_error_with_instructions_url(): assert error.instructions_url == instructions_url assert error.message == ( "Missing environment variable: API_KEY for provider TestProvider.\n" - " Please see http://example.com/instructions for instructions" + "Please see http://example.com/instructions for instructions" ) From 451cda4b1312a5a274e444993ae1e1ad6df32e9c Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Thu, 3 Oct 2024 11:41:54 +1000 Subject: [PATCH 07/12] add a new line in message --- .gitignore | 5 ++++- src/goose/cli/session.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index c9142a70d055..f799b7221698 100644 --- a/.gitignore +++ b/.gitignore @@ -121,4 +121,7 @@ celerybeat.pid .vscode # Autogenerated docs files -docs/docs/reference \ No newline at end of file +docs/docs/reference + +# uv lock file +uv.lock diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index bdec42d10143..882bc800b150 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -111,7 +111,7 @@ def _create_exchange(self) -> Exchange: sys.exit(1) except LoadExchangeAttributeError as e: error_message = ( - f"[bold red]{e.message}[/bold red].\nPlease check your configuration file at {PROFILES_CONFIG_PATH}. " + f"[bold red]{e.message}[/bold red].\nPlease check your configuration file at {PROFILES_CONFIG_PATH}.\n" + "Configuration doc: https://block-open-source.github.io/goose/configuration.html" ) print(error_message) From 78292623f456e4b3a630cf48e0dafaa177ddefbe Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Fri, 4 Oct 2024 08:26:25 +1000 Subject: [PATCH 08/12] format the code --- tests/cli/test_session.py | 41 ++++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/tests/cli/test_session.py b/tests/cli/test_session.py index 41d9a33d872e..ddaeaa2fd9e1 100644 --- a/tests/cli/test_session.py +++ b/tests/cli/test_session.py @@ -21,10 +21,11 @@ def mock_specified_session_name(): @pytest.fixture def create_session_with_mock_configs(mock_sessions_path, exchange_factory, profile_factory): - with patch("goose.cli.session.build_exchange") as mock_exchange, patch( - "goose.cli.session.load_profile", return_value=profile_factory() - ), patch("goose.cli.session.SessionNotifier") as mock_session_notifier, patch( - "goose.cli.session.load_provider", return_value="provider" + with ( + patch("goose.cli.session.build_exchange") as mock_exchange, + patch("goose.cli.session.load_profile", return_value=profile_factory()), + patch("goose.cli.session.SessionNotifier") as mock_session_notifier, + patch("goose.cli.session.load_provider", return_value="provider"), ): mock_session_notifier.return_value = MagicMock() mock_exchange.return_value = exchange_factory() @@ -115,9 +116,11 @@ def test_log_log_cost(create_session_with_mock_configs): session = create_session_with_mock_configs() mock_logger = MagicMock() cost_message = "You have used 100 tokens" - with patch("exchange.Exchange.get_token_usage", return_value={}), patch( - "goose.cli.session.get_total_cost_message", return_value=cost_message - ), patch("goose.cli.session.get_logger", return_value=mock_logger): + with ( + patch("exchange.Exchange.get_token_usage", return_value={}), + patch("goose.cli.session.get_total_cost_message", return_value=cost_message), + patch("goose.cli.session.get_logger", return_value=mock_logger), + ): session._log_cost() mock_logger.info.assert_called_once_with(cost_message) @@ -135,9 +138,11 @@ def custom_exchange_generate(self, *args, **kwargs): ] session = create_session_with_mock_configs({"name": SESSION_NAME}) - with patch.object(GoosePromptSession, "get_user_input", side_effect=user_inputs), patch.object( - Exchange, "generate" - ) as mock_generate, patch("goose.cli.session.save_latest_session") as mock_save_latest_session: + with ( + patch.object(GoosePromptSession, "get_user_input", side_effect=user_inputs), + patch.object(Exchange, "generate") as mock_generate, + patch("goose.cli.session.save_latest_session") as mock_save_latest_session, + ): mock_generate.side_effect = lambda *args, **kwargs: custom_exchange_generate(session.exchange, *args, **kwargs) session.run() @@ -158,9 +163,11 @@ def test_set_generated_session_name(create_session_with_mock_configs, mock_sessi def test_create_exchange_exit_when_env_var_does_not_exist(create_session_with_mock_configs, mock_sessions_path): session = create_session_with_mock_configs() expected_error = MissingProviderEnvVariableError(env_variable="OPENAI_API_KEY", provider="openai") - with patch("goose.cli.session.build_exchange", side_effect=expected_error), patch( - "goose.cli.session.print" - ) as mock_print, patch("sys.exit") as mock_exit: + with ( + patch("goose.cli.session.build_exchange", side_effect=expected_error), + patch("goose.cli.session.print") as mock_print, + patch("sys.exit") as mock_exit, + ): session._create_exchange() mock_print.call_args_list[0][0][0].renderable == ( "Missing environment variable OPENAI_API_KEY for provider openai. ", @@ -174,9 +181,11 @@ def test_create_exchange_exit_when_configuration_is_incorrect(create_session_wit expected_error = LoadExchangeAttributeError( attribute_name="provider", attribute_value="wrong_provider", available_values=["openai"] ) - with patch("goose.cli.session.build_exchange", side_effect=expected_error), patch( - "goose.cli.session.print" - ) as mock_print, patch("sys.exit") as mock_exit: + with ( + patch("goose.cli.session.build_exchange", side_effect=expected_error), + patch("goose.cli.session.print") as mock_print, + patch("sys.exit") as mock_exit, + ): session._create_exchange() assert "Unknown provider: wrong_provider. Available providers: openai" in mock_print.call_args_list[0][0][0] mock_exit.assert_called_once_with(1) From 03af0c6534ce1ce1837939c4e1143944fa402cfb Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Fri, 4 Oct 2024 10:31:34 +1000 Subject: [PATCH 09/12] renamed to invalidChoiceError --- ..._exchange_attribute_error.py => invalid_choice_error.py} | 2 +- packages/exchange/src/exchange/moderators/__init__.py | 4 ++-- packages/exchange/src/exchange/providers/__init__.py | 4 ++-- packages/exchange/tests/providers/test_provider.py | 4 ++-- ...ange_attribute_error.py => test_invalid_choice_error.py} | 6 +++--- packages/exchange/tests/test_moderators.py | 4 ++-- src/goose/cli/session.py | 4 ++-- src/goose/toolkit/__init__.py | 4 ++-- tests/cli/test_session.py | 4 ++-- 9 files changed, 18 insertions(+), 18 deletions(-) rename packages/exchange/src/exchange/{load_exchange_attribute_error.py => invalid_choice_error.py} (91%) rename packages/exchange/tests/{test_load_exchange_attribute_error.py => test_invalid_choice_error.py} (64%) diff --git a/packages/exchange/src/exchange/load_exchange_attribute_error.py b/packages/exchange/src/exchange/invalid_choice_error.py similarity index 91% rename from packages/exchange/src/exchange/load_exchange_attribute_error.py rename to packages/exchange/src/exchange/invalid_choice_error.py index bf1c4b607caa..ffbb9899f2a1 100644 --- a/packages/exchange/src/exchange/load_exchange_attribute_error.py +++ b/packages/exchange/src/exchange/invalid_choice_error.py @@ -1,7 +1,7 @@ from typing import List -class LoadExchangeAttributeError(Exception): +class InvalidChoiceError(Exception): def __init__(self, attribute_name: str, attribute_value: str, available_values: List[str]) -> None: self.attribute_name = attribute_name self.attribute_value = attribute_value diff --git a/packages/exchange/src/exchange/moderators/__init__.py b/packages/exchange/src/exchange/moderators/__init__.py index 8bcdae820136..82d032e426ad 100644 --- a/packages/exchange/src/exchange/moderators/__init__.py +++ b/packages/exchange/src/exchange/moderators/__init__.py @@ -1,7 +1,7 @@ from functools import cache from typing import Type -from exchange.load_exchange_attribute_error import LoadExchangeAttributeError +from exchange.invalid_choice_error import InvalidChoiceError from exchange.moderators.base import Moderator from exchange.utils import load_plugins from exchange.moderators.passive import PassiveModerator # noqa @@ -13,5 +13,5 @@ def get_moderator(name: str) -> Type[Moderator]: moderators = load_plugins(group="exchange.moderator") if name not in moderators: - raise LoadExchangeAttributeError("moderator", name, moderators.keys()) + raise InvalidChoiceError("moderator", name, moderators.keys()) return moderators[name] diff --git a/packages/exchange/src/exchange/providers/__init__.py b/packages/exchange/src/exchange/providers/__init__.py index 65e83746871a..f92d4f769d44 100644 --- a/packages/exchange/src/exchange/providers/__init__.py +++ b/packages/exchange/src/exchange/providers/__init__.py @@ -1,7 +1,7 @@ from functools import cache from typing import Type -from exchange.load_exchange_attribute_error import LoadExchangeAttributeError +from exchange.invalid_choice_error import InvalidChoiceError from exchange.providers.anthropic import AnthropicProvider # noqa from exchange.providers.base import Provider, Usage # noqa from exchange.providers.databricks import DatabricksProvider # noqa @@ -17,5 +17,5 @@ def get_provider(name: str) -> Type[Provider]: providers = load_plugins(group="exchange.provider") if name not in providers: - raise LoadExchangeAttributeError("provider", name, providers.keys()) + raise InvalidChoiceError("provider", name, providers.keys()) return providers[name] diff --git a/packages/exchange/tests/providers/test_provider.py b/packages/exchange/tests/providers/test_provider.py index 0038d2ea2bb1..fb7d15ce0dbb 100644 --- a/packages/exchange/tests/providers/test_provider.py +++ b/packages/exchange/tests/providers/test_provider.py @@ -1,5 +1,5 @@ import pytest -from exchange.load_exchange_attribute_error import LoadExchangeAttributeError +from exchange.invalid_choice_error import InvalidChoiceError from exchange.providers import get_provider @@ -10,7 +10,7 @@ def test_get_provider_valid(): def test_get_provider_throw_error_for_unknown_provider(): - with pytest.raises(LoadExchangeAttributeError) as error: + with pytest.raises(InvalidChoiceError) as error: get_provider("nonexistent") assert error.value.attribute_name == "provider" assert error.value.attribute_value == "nonexistent" diff --git a/packages/exchange/tests/test_load_exchange_attribute_error.py b/packages/exchange/tests/test_invalid_choice_error.py similarity index 64% rename from packages/exchange/tests/test_load_exchange_attribute_error.py rename to packages/exchange/tests/test_invalid_choice_error.py index f6f17f14459f..9fad8b12f7f2 100644 --- a/packages/exchange/tests/test_load_exchange_attribute_error.py +++ b/packages/exchange/tests/test_invalid_choice_error.py @@ -1,11 +1,11 @@ -from exchange.load_exchange_attribute_error import LoadExchangeAttributeError +from exchange.invalid_choice_error import InvalidChoiceError -def test_load_exchange_attribute_error(): +def test_load_invalid_choice_error(): attribute_name = "moderator" attribute_value = "not_exist" available_values = ["truncate", "summarizer"] - error = LoadExchangeAttributeError(attribute_name, attribute_value, available_values) + error = InvalidChoiceError(attribute_name, attribute_value, available_values) assert error.attribute_name == attribute_name assert error.attribute_value == attribute_value diff --git a/packages/exchange/tests/test_moderators.py b/packages/exchange/tests/test_moderators.py index 8d8478f84300..16bcaa13bed4 100644 --- a/packages/exchange/tests/test_moderators.py +++ b/packages/exchange/tests/test_moderators.py @@ -1,4 +1,4 @@ -from exchange.load_exchange_attribute_error import LoadExchangeAttributeError +from exchange.invalid_choice_error import InvalidChoiceError from exchange.moderators import get_moderator import pytest @@ -9,7 +9,7 @@ def test_get_moderator(): def test_get_moderator_raise_error_for_unknown_moderator(): - with pytest.raises(LoadExchangeAttributeError) as error: + with pytest.raises(InvalidChoiceError) as error: get_moderator("nonexistent") assert error.value.attribute_name == "moderator" assert error.value.attribute_value == "nonexistent" diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index 882bc800b150..f9fe13b9a38e 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -5,7 +5,7 @@ from exchange import Message, ToolResult, ToolUse, Text, Exchange from exchange.providers.base import MissingProviderEnvVariableError -from exchange.load_exchange_attribute_error import LoadExchangeAttributeError +from exchange.invalid_choice_error import InvalidChoiceError from rich import print from rich.console import RenderableType from rich.live import Live @@ -109,7 +109,7 @@ def _create_exchange(self) -> Exchange: error_message = f"{e.message}. Please set the required environment variable to continue." print(Panel(error_message, style="red")) sys.exit(1) - except LoadExchangeAttributeError as e: + except InvalidChoiceError as e: error_message = ( f"[bold red]{e.message}[/bold red].\nPlease check your configuration file at {PROFILES_CONFIG_PATH}.\n" + "Configuration doc: https://block-open-source.github.io/goose/configuration.html" diff --git a/src/goose/toolkit/__init__.py b/src/goose/toolkit/__init__.py index 6820cc24cd54..fc561ee67f49 100644 --- a/src/goose/toolkit/__init__.py +++ b/src/goose/toolkit/__init__.py @@ -1,5 +1,5 @@ from functools import cache -from exchange.load_exchange_attribute_error import LoadExchangeAttributeError +from exchange.invalid_choice_error import InvalidChoiceError from goose.toolkit.base import Toolkit from goose.utils import load_plugins @@ -8,5 +8,5 @@ def get_toolkit(name: str) -> type[Toolkit]: toolkits = load_plugins(group="goose.toolkit") if name not in toolkits: - raise LoadExchangeAttributeError("toolkit", name, toolkits.keys()) + raise InvalidChoiceError("toolkit", name, toolkits.keys()) return toolkits[name] diff --git a/tests/cli/test_session.py b/tests/cli/test_session.py index ddaeaa2fd9e1..81519512d99e 100644 --- a/tests/cli/test_session.py +++ b/tests/cli/test_session.py @@ -3,7 +3,7 @@ import pytest from exchange import Exchange, Message, ToolUse, ToolResult from exchange.providers.base import MissingProviderEnvVariableError -from exchange.load_exchange_attribute_error import LoadExchangeAttributeError +from exchange.invalid_choice_error import InvalidChoiceError from goose.cli.prompt.goose_prompt_session import GoosePromptSession from goose.cli.prompt.user_input import PromptAction, UserInput from goose.cli.session import Session @@ -178,7 +178,7 @@ def test_create_exchange_exit_when_env_var_does_not_exist(create_session_with_mo def test_create_exchange_exit_when_configuration_is_incorrect(create_session_with_mock_configs, mock_sessions_path): session = create_session_with_mock_configs() - expected_error = LoadExchangeAttributeError( + expected_error = InvalidChoiceError( attribute_name="provider", attribute_value="wrong_provider", available_values=["openai"] ) with ( From 4e97f0a449b6733f6829440ce7ed1b6679e00efa Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Fri, 4 Oct 2024 10:37:41 +1000 Subject: [PATCH 10/12] inline methods --- packages/exchange/src/exchange/providers/azure.py | 14 ++++++-------- .../exchange/src/exchange/providers/bedrock.py | 13 ++++++------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/packages/exchange/src/exchange/providers/azure.py b/packages/exchange/src/exchange/providers/azure.py index 212dcca5e7c5..a06a557d1187 100644 --- a/packages/exchange/src/exchange/providers/azure.py +++ b/packages/exchange/src/exchange/providers/azure.py @@ -5,6 +5,8 @@ from exchange.providers import OpenAiProvider from exchange.providers.utils import get_provider_env_value +PROVIDER_NAME = "azure" + class AzureProvider(OpenAiProvider): """Provides chat completions for models hosted by the Azure OpenAI Service""" @@ -14,11 +16,11 @@ def __init__(self, client: httpx.Client) -> None: @classmethod def from_env(cls: Type["AzureProvider"]) -> "AzureProvider": - url = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_HOST_NAME") - deployment_name = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME") + url = get_provider_env_value("AZURE_CHAT_COMPLETIONS_HOST_NAME", PROVIDER_NAME) + deployment_name = get_provider_env_value("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME", PROVIDER_NAME) - api_version = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION") - key = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_KEY") + api_version = get_provider_env_value("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION", PROVIDER_NAME) + key = get_provider_env_value("AZURE_CHAT_COMPLETIONS_KEY", PROVIDER_NAME) # format the url host/"openai/deployments/" + deployment_name + "/?api-version=" + api_version url = f"{url}/openai/deployments/{deployment_name}/" @@ -29,7 +31,3 @@ def from_env(cls: Type["AzureProvider"]) -> "AzureProvider": timeout=httpx.Timeout(60 * 10), ) return cls(client) - - @classmethod - def _get_env_variable(cls: Type["AzureProvider"], key: str) -> str: - return get_provider_env_value(key, "azure") diff --git a/packages/exchange/src/exchange/providers/bedrock.py b/packages/exchange/src/exchange/providers/bedrock.py index 4b02f2391aae..c8c1d6816dc7 100644 --- a/packages/exchange/src/exchange/providers/bedrock.py +++ b/packages/exchange/src/exchange/providers/bedrock.py @@ -146,6 +146,9 @@ def get_signature_key(key: str, date_stamp: str, region_name: str, service_name: return headers +PROVIDER_NAME = "bedrock" + + class BedrockProvider(Provider): def __init__(self, client: AwsClient) -> None: self.client = client @@ -153,9 +156,9 @@ def __init__(self, client: AwsClient) -> None: @classmethod def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider": aws_region = os.environ.get("AWS_REGION", "us-east-1") - aws_access_key = cls._get_env_variable("AWS_ACCESS_KEY_ID") - aws_secret_key = cls._get_env_variable("AWS_SECRET_ACCESS_KEY") - aws_session_token = cls._get_env_variable("AWS_SESSION_TOKEN") + aws_access_key = get_provider_env_value("AWS_ACCESS_KEY_ID", PROVIDER_NAME) + aws_secret_key = get_provider_env_value("AWS_SECRET_ACCESS_KEY", PROVIDER_NAME) + aws_session_token = get_provider_env_value("AWS_SESSION_TOKEN", PROVIDER_NAME) client = AwsClient( aws_region=aws_region, @@ -322,7 +325,3 @@ def tools_to_bedrock_spec(tools: Tuple[Tool]) -> Optional[dict]: tools_added.add(tool.name) tool_config = {"tools": tool_config_list} return tool_config - - @classmethod - def _get_env_variable(cls: Type["BedrockProvider"], key: str) -> str: - return get_provider_env_value(key, "bedrock") From 36b5d2c58f8bc9577b73d6328492d1b3b6ecd7b3 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Fri, 4 Oct 2024 11:13:36 +1000 Subject: [PATCH 11/12] set default provider --- src/goose/cli/config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/goose/cli/config.py b/src/goose/cli/config.py index bef458a69105..403ad2cd34f4 100644 --- a/src/goose/cli/config.py +++ b/src/goose/cli/config.py @@ -16,6 +16,7 @@ SESSIONS_PATH = GOOSE_GLOBAL_PATH.joinpath("sessions") SESSION_FILE_SUFFIX = ".jsonl" LOG_PATH = GOOSE_GLOBAL_PATH.joinpath("logs") +DEFAULT_PROVIDER = "openai" @cache @@ -87,7 +88,8 @@ def default_model_configuration() -> Tuple[str, str, str]: break except Exception: pass - + else: + provider = DEFAULT_PROVIDER recommended = { "ollama": (OLLAMA_MODEL, OLLAMA_MODEL), "openai": ("gpt-4o", "gpt-4o-mini"), From 208b2650cd60e31617f1f28c8e1bd5cd360fd58c Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Fri, 4 Oct 2024 11:15:25 +1000 Subject: [PATCH 12/12] rename variable --- src/goose/cli/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/goose/cli/config.py b/src/goose/cli/config.py index 403ad2cd34f4..7bede0be56f1 100644 --- a/src/goose/cli/config.py +++ b/src/goose/cli/config.py @@ -16,7 +16,7 @@ SESSIONS_PATH = GOOSE_GLOBAL_PATH.joinpath("sessions") SESSION_FILE_SUFFIX = ".jsonl" LOG_PATH = GOOSE_GLOBAL_PATH.joinpath("logs") -DEFAULT_PROVIDER = "openai" +RECOMMENDED_DEFAULT_PROVIDER = "openai" @cache @@ -89,7 +89,7 @@ def default_model_configuration() -> Tuple[str, str, str]: except Exception: pass else: - provider = DEFAULT_PROVIDER + provider = RECOMMENDED_DEFAULT_PROVIDER recommended = { "ollama": (OLLAMA_MODEL, OLLAMA_MODEL), "openai": ("gpt-4o", "gpt-4o-mini"),