diff --git a/packages/exchange/src/exchange/providers/anthropic.py b/packages/exchange/src/exchange/providers/anthropic.py index bf052b20b814..84ecd12fb4e2 100644 --- a/packages/exchange/src/exchange/providers/anthropic.py +++ b/packages/exchange/src/exchange/providers/anthropic.py @@ -7,7 +7,7 @@ from exchange.content import Text, ToolResult, ToolUse from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import get_provider_env_value, retry_if_status, raise_for_status +from exchange.providers.utils import retry_if_status, raise_for_status ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages" @@ -20,13 +20,19 @@ class AnthropicProvider(Provider): + """Provides chat completions for models hosted directly by Anthropic.""" + + PROVIDER_NAME = "anthropic" + REQUIRED_ENV_VARS = ["ANTHROPIC_API_KEY"] + def __init__(self, client: httpx.Client) -> None: self.client = client @classmethod def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider": + cls.check_env_vars() url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST) - key = get_provider_env_value("ANTHROPIC_API_KEY", "anthropic") + key = os.environ.get("ANTHROPIC_API_KEY") client = httpx.Client( base_url=url, headers={ diff --git a/packages/exchange/src/exchange/providers/azure.py b/packages/exchange/src/exchange/providers/azure.py index a06a557d1187..4d470f9782de 100644 --- a/packages/exchange/src/exchange/providers/azure.py +++ b/packages/exchange/src/exchange/providers/azure.py @@ -1,26 +1,32 @@ from typing import Type import httpx +import os from exchange.providers import OpenAiProvider -from exchange.providers.utils import get_provider_env_value - -PROVIDER_NAME = "azure" class AzureProvider(OpenAiProvider): - """Provides chat completions for models hosted by the Azure OpenAI Service""" + """Provides chat completions for models hosted by the Azure OpenAI Service.""" + + PROVIDER_NAME = "azure" + REQUIRED_ENV_VARS = [ + "AZURE_CHAT_COMPLETIONS_HOST_NAME", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION", + "AZURE_CHAT_COMPLETIONS_KEY", + ] def __init__(self, client: httpx.Client) -> None: super().__init__(client) @classmethod def from_env(cls: Type["AzureProvider"]) -> "AzureProvider": - url = get_provider_env_value("AZURE_CHAT_COMPLETIONS_HOST_NAME", PROVIDER_NAME) - deployment_name = get_provider_env_value("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME", PROVIDER_NAME) - - api_version = get_provider_env_value("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION", PROVIDER_NAME) - key = get_provider_env_value("AZURE_CHAT_COMPLETIONS_KEY", PROVIDER_NAME) + cls.check_env_vars() + url = os.environ.get("AZURE_CHAT_COMPLETIONS_HOST_NAME") + deployment_name = os.environ.get("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME") + api_version = os.environ.get("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION") + key = os.environ.get("AZURE_CHAT_COMPLETIONS_KEY") # format the url host/"openai/deployments/" + deployment_name + "/?api-version=" + api_version url = f"{url}/openai/deployments/{deployment_name}/" diff --git a/packages/exchange/src/exchange/providers/base.py b/packages/exchange/src/exchange/providers/base.py index 78c267e7bacd..c8d860ecc0ca 100644 --- a/packages/exchange/src/exchange/providers/base.py +++ b/packages/exchange/src/exchange/providers/base.py @@ -1,3 +1,4 @@ +import os from abc import ABC, abstractmethod from attrs import define, field from typing import List, Optional, Tuple, Type @@ -14,10 +15,19 @@ class Usage: class Provider(ABC): + PROVIDER_NAME: str + REQUIRED_ENV_VARS: list[str] = [] + @classmethod def from_env(cls: Type["Provider"]) -> "Provider": return cls() + @classmethod + def check_env_vars(cls: Type["Provider"], instructions_url: Optional[str] = None) -> None: + for env_var in cls.REQUIRED_ENV_VARS: + if env_var not in os.environ: + raise MissingProviderEnvVariableError(env_var, cls.PROVIDER_NAME, instructions_url) + @abstractmethod def complete( self, diff --git a/packages/exchange/src/exchange/providers/bedrock.py b/packages/exchange/src/exchange/providers/bedrock.py index c8c1d6816dc7..6c32d7cb3d7a 100644 --- a/packages/exchange/src/exchange/providers/bedrock.py +++ b/packages/exchange/src/exchange/providers/bedrock.py @@ -13,7 +13,7 @@ from exchange.message import Message from exchange.providers import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status +from exchange.providers.utils import raise_for_status, retry_if_status from exchange.tool import Tool SERVICE = "bedrock-runtime" @@ -146,19 +146,26 @@ def get_signature_key(key: str, date_stamp: str, region_name: str, service_name: return headers -PROVIDER_NAME = "bedrock" +class BedrockProvider(Provider): + """Provides chat completions for models hosted by the Amazon Bedrock Service""" + PROVIDER_NAME = "bedrock" + REQUIRED_ENV_VARS = [ + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", + ] -class BedrockProvider(Provider): def __init__(self, client: AwsClient) -> None: self.client = client @classmethod def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider": + cls.check_env_vars() aws_region = os.environ.get("AWS_REGION", "us-east-1") - aws_access_key = get_provider_env_value("AWS_ACCESS_KEY_ID", PROVIDER_NAME) - aws_secret_key = get_provider_env_value("AWS_SECRET_ACCESS_KEY", PROVIDER_NAME) - aws_session_token = get_provider_env_value("AWS_SESSION_TOKEN", PROVIDER_NAME) + aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID") + aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") + aws_session_token = os.environ.get("AWS_SESSION_TOKEN") client = AwsClient( aws_region=aws_region, diff --git a/packages/exchange/src/exchange/providers/databricks.py b/packages/exchange/src/exchange/providers/databricks.py index 77d392e8d6a2..9bd582dc581a 100644 --- a/packages/exchange/src/exchange/providers/databricks.py +++ b/packages/exchange/src/exchange/providers/databricks.py @@ -1,11 +1,12 @@ from typing import Any, Dict, List, Tuple, Type import httpx +import os from exchange.message import Message from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status +from exchange.providers.utils import raise_for_status, retry_if_status from exchange.providers.utils import ( messages_to_openai_spec, openai_response_to_message, @@ -23,21 +24,29 @@ class DatabricksProvider(Provider): - """Provides chat completions for models on Databricks serving endpoints + """Provides chat completions for models on Databricks serving endpoints. Models are expected to follow the llm/v1/chat "task". This includes support for foundation and external model endpoints https://docs.databricks.com/en/machine-learning/model-serving/create-foundation-model-endpoints.html#create-generative-ai-model-serving-endpoints + """ + PROVIDER_NAME = "databricks" + REQUIRED_ENV_VARS = [ + "DATABRICKS_HOST", + "DATABRICKS_TOKEN", + ] + instructions_url = "https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields" + def __init__(self, client: httpx.Client) -> None: - super().__init__() self.client = client @classmethod def from_env(cls: Type["DatabricksProvider"]) -> "DatabricksProvider": - url = cls._get_env_variable("DATABRICKS_HOST") - key = cls._get_env_variable("DATABRICKS_TOKEN") + cls.check_env_vars(cls.instructions_url) + url = os.environ.get("DATABRICKS_HOST") + key = os.environ.get("DATABRICKS_TOKEN") client = httpx.Client( base_url=url, auth=("token", key), @@ -89,8 +98,3 @@ def _post(self, model: str, payload: dict) -> httpx.Response: json=payload, ) return raise_for_status(response).json() - - @classmethod - def _get_env_variable(cls: Type["DatabricksProvider"], key: str) -> str: - instruction = "https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields" - return get_provider_env_value(key, "databricks", instruction) diff --git a/packages/exchange/src/exchange/providers/google.py b/packages/exchange/src/exchange/providers/google.py index 4fc020f348d9..fe83cd605139 100644 --- a/packages/exchange/src/exchange/providers/google.py +++ b/packages/exchange/src/exchange/providers/google.py @@ -7,7 +7,7 @@ from exchange.content import Text, ToolResult, ToolUse from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status +from exchange.providers.utils import raise_for_status, retry_if_status GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta" @@ -20,15 +20,20 @@ class GoogleProvider(Provider): + """Provides chat completions for models hosted by Google, including Gemini and other experimental models.""" + + PROVIDER_NAME = "google" + REQUIRED_ENV_VARS = ["GOOGLE_API_KEY"] + instructions_url = "https://ai.google.dev/gemini-api/docs/api-key" + def __init__(self, client: httpx.Client) -> None: self.client = client @classmethod def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider": + cls.check_env_vars(cls.instructions_url) url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST) - api_key_instructions_url = "https://ai.google.dev/gemini-api/docs/api-key" - key = get_provider_env_value("GOOGLE_API_KEY", "google", api_key_instructions_url) - + key = os.environ.get("GOOGLE_API_KEY") client = httpx.Client( base_url=url, headers={ diff --git a/packages/exchange/src/exchange/providers/ollama.py b/packages/exchange/src/exchange/providers/ollama.py index acad89d9f1e1..db4094e191f0 100644 --- a/packages/exchange/src/exchange/providers/ollama.py +++ b/packages/exchange/src/exchange/providers/ollama.py @@ -10,7 +10,7 @@ class OllamaProvider(OpenAiProvider): - """Provides chat completions for models hosted by Ollama""" + """Provides chat completions for models hosted by Ollama.""" __doc__ += f""" diff --git a/packages/exchange/src/exchange/providers/openai.py b/packages/exchange/src/exchange/providers/openai.py index c30558b8597c..b25c5a70aabc 100644 --- a/packages/exchange/src/exchange/providers/openai.py +++ b/packages/exchange/src/exchange/providers/openai.py @@ -6,7 +6,6 @@ from exchange.message import Message from exchange.providers.base import Provider, Usage from exchange.providers.utils import ( - get_provider_env_value, messages_to_openai_spec, openai_response_to_message, openai_single_message_context_length_exceeded, @@ -28,17 +27,21 @@ class OpenAiProvider(Provider): - """Provides chat completions for models hosted directly by OpenAI""" + """Provides chat completions for models hosted directly by OpenAI.""" + + PROVIDER_NAME = "openai" + REQUIRED_ENV_VARS = ["OPENAI_API_KEY"] + instructions_url = "https://platform.openai.com/docs/api-reference/api-keys" def __init__(self, client: httpx.Client) -> None: - super().__init__() self.client = client @classmethod def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider": + cls.check_env_vars(cls.instructions_url) url = os.environ.get("OPENAI_HOST", OPENAI_HOST) - api_key_instructions_url = "https://platform.openai.com/docs/api-reference/api-keys" - key = get_provider_env_value("OPENAI_API_KEY", "openai", api_key_instructions_url) + key = os.environ.get("OPENAI_API_KEY") + client = httpx.Client( base_url=url + "v1/", auth=("Bearer", key), diff --git a/packages/exchange/src/exchange/providers/utils.py b/packages/exchange/src/exchange/providers/utils.py index 01504305644e..4be7ac31e4a9 100644 --- a/packages/exchange/src/exchange/providers/utils.py +++ b/packages/exchange/src/exchange/providers/utils.py @@ -1,13 +1,11 @@ import base64 import json -import os import re from typing import Any, Callable, Dict, List, Optional, Tuple import httpx from exchange.content import Text, ToolResult, ToolUse from exchange.message import Message -from exchange.providers.base import MissingProviderEnvVariableError from exchange.tool import Tool from tenacity import retry_if_exception @@ -181,13 +179,6 @@ def openai_single_message_context_length_exceeded(error_dict: dict) -> None: raise InitialMessageTooLargeError(f"Input message too long. Message: {error_dict.get('message')}") -def get_provider_env_value(env_variable: str, provider: str, instructions_url: Optional[str] = None) -> str: - try: - return os.environ[env_variable] - except KeyError: - raise MissingProviderEnvVariableError(env_variable, provider, instructions_url) - - class InitialMessageTooLargeError(Exception): """Custom error raised when the first input message in an exchange is too large.""" diff --git a/src/goose/cli/main.py b/src/goose/cli/main.py index e60facc36bb4..866737d2012d 100644 --- a/src/goose/cli/main.py +++ b/src/goose/cli/main.py @@ -97,6 +97,28 @@ def list_toolkits() -> None: print(f" - [bold]{toolkit_name}[/bold]: {first_line_of_doc}") +@goose_cli.group() +def providers() -> None: + """Manage providers""" + pass + + +@providers.command(name="list") +def list_providers() -> None: + providers = load_plugins(group="exchange.provider") + + for provider_name, provider in providers.items(): + lines_doc = provider.__doc__.split("\n") + first_line_of_doc = lines_doc[0] + print(f" - [bold]{provider_name}[/bold]: {first_line_of_doc}") + envs = provider.REQUIRED_ENV_VARS + if envs: + env_required_str = ", ".join(envs) + print(f" [dim]env vars required: {env_required_str}") + + print("\n") + + def autocomplete_session_files(ctx: click.Context, args: str, incomplete: str) -> None: return [ f"{session_name}" diff --git a/src/goose/toolkit/lint.py b/src/goose/toolkit/lint.py index 0f08f222dcfc..a12335c740c7 100644 --- a/src/goose/toolkit/lint.py +++ b/src/goose/toolkit/lint.py @@ -10,3 +10,14 @@ def lint_toolkits() -> None: assert first_line_of_docstring[ 0 ].isupper(), f"`{toolkit_name}` toolkit docstring must start with a capital letter" + + +def lint_providers() -> None: + for provider_name, provider in load_plugins(group="exchange.provider").items(): + assert provider.__doc__ is not None, f"`{provider_name}` provider must have a docstring" + first_line_of_docstring = provider.__doc__.split("\n")[0] + assert len(first_line_of_docstring.split(" ")) > 5, f"`{provider_name}` provider docstring is too short" + assert len(first_line_of_docstring.split(" ")) < 20, f"`{provider_name}` provider docstring is too long" + assert first_line_of_docstring[ + 0 + ].isupper(), f"`{provider_name}` provider docstring must start with a capital letter" diff --git a/tests/test_linting.py b/tests/test_linting.py index f6e246ff629c..cae0d1a73e76 100644 --- a/tests/test_linting.py +++ b/tests/test_linting.py @@ -1,5 +1,11 @@ from goose.toolkit.lint import lint_toolkits +from goose.toolkit.lint import lint_providers + def test_lint_toolkits(): lint_toolkits() + + +def test_lint_providers(): + lint_providers()