Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions packages/exchange/src/exchange/providers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from exchange.content import Text, ToolResult, ToolUse
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import get_provider_env_value, retry_if_status, raise_for_status
from exchange.providers.utils import retry_if_status, raise_for_status

ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages"

Expand All @@ -20,13 +20,19 @@


class AnthropicProvider(Provider):
"""Provides chat completions for models hosted directly by Anthropic."""

PROVIDER_NAME = "anthropic"
REQUIRED_ENV_VARS = ["ANTHROPIC_API_KEY"]

def __init__(self, client: httpx.Client) -> None:
self.client = client

@classmethod
def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider":
cls.check_env_vars()
url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST)
key = get_provider_env_value("ANTHROPIC_API_KEY", "anthropic")
key = os.environ.get("ANTHROPIC_API_KEY")
client = httpx.Client(
base_url=url,
headers={
Expand Down
24 changes: 15 additions & 9 deletions packages/exchange/src/exchange/providers/azure.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,32 @@
from typing import Type

import httpx
import os

from exchange.providers import OpenAiProvider
from exchange.providers.utils import get_provider_env_value

PROVIDER_NAME = "azure"


class AzureProvider(OpenAiProvider):
"""Provides chat completions for models hosted by the Azure OpenAI Service"""
"""Provides chat completions for models hosted by the Azure OpenAI Service."""

PROVIDER_NAME = "azure"
REQUIRED_ENV_VARS = [
"AZURE_CHAT_COMPLETIONS_HOST_NAME",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION",
"AZURE_CHAT_COMPLETIONS_KEY",
]

def __init__(self, client: httpx.Client) -> None:
super().__init__(client)

@classmethod
def from_env(cls: Type["AzureProvider"]) -> "AzureProvider":
url = get_provider_env_value("AZURE_CHAT_COMPLETIONS_HOST_NAME", PROVIDER_NAME)
deployment_name = get_provider_env_value("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME", PROVIDER_NAME)

api_version = get_provider_env_value("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION", PROVIDER_NAME)
key = get_provider_env_value("AZURE_CHAT_COMPLETIONS_KEY", PROVIDER_NAME)
cls.check_env_vars()
url = os.environ.get("AZURE_CHAT_COMPLETIONS_HOST_NAME")
deployment_name = os.environ.get("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME")
api_version = os.environ.get("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION")
key = os.environ.get("AZURE_CHAT_COMPLETIONS_KEY")

# format the url host/"openai/deployments/" + deployment_name + "/?api-version=" + api_version
url = f"{url}/openai/deployments/{deployment_name}/"
Expand Down
10 changes: 10 additions & 0 deletions packages/exchange/src/exchange/providers/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from abc import ABC, abstractmethod
from attrs import define, field
from typing import List, Optional, Tuple, Type
Expand All @@ -14,10 +15,19 @@ class Usage:


class Provider(ABC):
PROVIDER_NAME: str
REQUIRED_ENV_VARS: list[str] = []

@classmethod
def from_env(cls: Type["Provider"]) -> "Provider":
return cls()

@classmethod
def check_env_vars(cls: Type["Provider"], instructions_url: Optional[str] = None) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

for env_var in cls.REQUIRED_ENV_VARS:
if env_var not in os.environ:
raise MissingProviderEnvVariableError(env_var, cls.PROVIDER_NAME, instructions_url)

@abstractmethod
def complete(
self,
Expand Down
19 changes: 13 additions & 6 deletions packages/exchange/src/exchange/providers/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from exchange.message import Message
from exchange.providers import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status
from exchange.providers.utils import raise_for_status, retry_if_status
from exchange.tool import Tool

SERVICE = "bedrock-runtime"
Expand Down Expand Up @@ -146,19 +146,26 @@ def get_signature_key(key: str, date_stamp: str, region_name: str, service_name:
return headers


PROVIDER_NAME = "bedrock"
class BedrockProvider(Provider):
"""Provides chat completions for models hosted by the Amazon Bedrock Service"""

PROVIDER_NAME = "bedrock"
REQUIRED_ENV_VARS = [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_SESSION_TOKEN",
]

class BedrockProvider(Provider):
def __init__(self, client: AwsClient) -> None:
self.client = client

@classmethod
def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider":
cls.check_env_vars()
aws_region = os.environ.get("AWS_REGION", "us-east-1")
aws_access_key = get_provider_env_value("AWS_ACCESS_KEY_ID", PROVIDER_NAME)
aws_secret_key = get_provider_env_value("AWS_SECRET_ACCESS_KEY", PROVIDER_NAME)
aws_session_token = get_provider_env_value("AWS_SESSION_TOKEN", PROVIDER_NAME)
aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID")
aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
aws_session_token = os.environ.get("AWS_SESSION_TOKEN")

client = AwsClient(
aws_region=aws_region,
Expand Down
24 changes: 14 additions & 10 deletions packages/exchange/src/exchange/providers/databricks.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Any, Dict, List, Tuple, Type

import httpx
import os

from exchange.message import Message
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status
from exchange.providers.utils import raise_for_status, retry_if_status
from exchange.providers.utils import (
messages_to_openai_spec,
openai_response_to_message,
Expand All @@ -23,21 +24,29 @@


class DatabricksProvider(Provider):
"""Provides chat completions for models on Databricks serving endpoints
"""Provides chat completions for models on Databricks serving endpoints.

Models are expected to follow the llm/v1/chat "task". This includes support
for foundation and external model endpoints
https://docs.databricks.com/en/machine-learning/model-serving/create-foundation-model-endpoints.html#create-generative-ai-model-serving-endpoints

"""

PROVIDER_NAME = "databricks"
REQUIRED_ENV_VARS = [
"DATABRICKS_HOST",
"DATABRICKS_TOKEN",
]
instructions_url = "https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"

def __init__(self, client: httpx.Client) -> None:
super().__init__()
self.client = client

@classmethod
def from_env(cls: Type["DatabricksProvider"]) -> "DatabricksProvider":
url = cls._get_env_variable("DATABRICKS_HOST")
key = cls._get_env_variable("DATABRICKS_TOKEN")
cls.check_env_vars(cls.instructions_url)
url = os.environ.get("DATABRICKS_HOST")
key = os.environ.get("DATABRICKS_TOKEN")
client = httpx.Client(
base_url=url,
auth=("token", key),
Expand Down Expand Up @@ -89,8 +98,3 @@ def _post(self, model: str, payload: dict) -> httpx.Response:
json=payload,
)
return raise_for_status(response).json()

@classmethod
def _get_env_variable(cls: Type["DatabricksProvider"], key: str) -> str:
instruction = "https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"
return get_provider_env_value(key, "databricks", instruction)
13 changes: 9 additions & 4 deletions packages/exchange/src/exchange/providers/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from exchange.content import Text, ToolResult, ToolUse
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status
from exchange.providers.utils import raise_for_status, retry_if_status

GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta"

Expand All @@ -20,15 +20,20 @@


class GoogleProvider(Provider):
"""Provides chat completions for models hosted by Google, including Gemini and other experimental models."""

PROVIDER_NAME = "google"
REQUIRED_ENV_VARS = ["GOOGLE_API_KEY"]
instructions_url = "https://ai.google.dev/gemini-api/docs/api-key"

def __init__(self, client: httpx.Client) -> None:
self.client = client

@classmethod
def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider":
cls.check_env_vars(cls.instructions_url)
url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST)
api_key_instructions_url = "https://ai.google.dev/gemini-api/docs/api-key"
key = get_provider_env_value("GOOGLE_API_KEY", "google", api_key_instructions_url)

key = os.environ.get("GOOGLE_API_KEY")
client = httpx.Client(
base_url=url,
headers={
Expand Down
2 changes: 1 addition & 1 deletion packages/exchange/src/exchange/providers/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


class OllamaProvider(OpenAiProvider):
"""Provides chat completions for models hosted by Ollama"""
"""Provides chat completions for models hosted by Ollama."""

__doc__ += f"""

Expand Down
13 changes: 8 additions & 5 deletions packages/exchange/src/exchange/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from exchange.message import Message
from exchange.providers.base import Provider, Usage
from exchange.providers.utils import (
get_provider_env_value,
messages_to_openai_spec,
openai_response_to_message,
openai_single_message_context_length_exceeded,
Expand All @@ -28,17 +27,21 @@


class OpenAiProvider(Provider):
"""Provides chat completions for models hosted directly by OpenAI"""
"""Provides chat completions for models hosted directly by OpenAI."""

PROVIDER_NAME = "openai"
REQUIRED_ENV_VARS = ["OPENAI_API_KEY"]
instructions_url = "https://platform.openai.com/docs/api-reference/api-keys"

def __init__(self, client: httpx.Client) -> None:
super().__init__()
self.client = client

@classmethod
def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider":
cls.check_env_vars(cls.instructions_url)
url = os.environ.get("OPENAI_HOST", OPENAI_HOST)
api_key_instructions_url = "https://platform.openai.com/docs/api-reference/api-keys"
key = get_provider_env_value("OPENAI_API_KEY", "openai", api_key_instructions_url)
key = os.environ.get("OPENAI_API_KEY")

client = httpx.Client(
base_url=url + "v1/",
auth=("Bearer", key),
Expand Down
9 changes: 0 additions & 9 deletions packages/exchange/src/exchange/providers/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import base64
import json
import os
import re
from typing import Any, Callable, Dict, List, Optional, Tuple

import httpx
from exchange.content import Text, ToolResult, ToolUse
from exchange.message import Message
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.tool import Tool
from tenacity import retry_if_exception

Expand Down Expand Up @@ -181,13 +179,6 @@ def openai_single_message_context_length_exceeded(error_dict: dict) -> None:
raise InitialMessageTooLargeError(f"Input message too long. Message: {error_dict.get('message')}")


def get_provider_env_value(env_variable: str, provider: str, instructions_url: Optional[str] = None) -> str:
try:
return os.environ[env_variable]
except KeyError:
raise MissingProviderEnvVariableError(env_variable, provider, instructions_url)


class InitialMessageTooLargeError(Exception):
"""Custom error raised when the first input message in an exchange is too large."""

Expand Down
22 changes: 22 additions & 0 deletions src/goose/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,28 @@ def list_toolkits() -> None:
print(f" - [bold]{toolkit_name}[/bold]: {first_line_of_doc}")


@goose_cli.group()
def providers() -> None:
"""Manage providers"""
pass


@providers.command(name="list")
def list_providers() -> None:
Copy link
Collaborator

@lifeizhou-ap lifeizhou-ap Oct 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe write a test for this function?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kinda hacky but if the exchange/providers project is visible via the goose distro but might be possible to render this with an os.listdirs()

providers = load_plugins(group="exchange.provider")

for provider_name, provider in providers.items():
lines_doc = provider.__doc__.split("\n")
first_line_of_doc = lines_doc[0]
print(f" - [bold]{provider_name}[/bold]: {first_line_of_doc}")
envs = provider.REQUIRED_ENV_VARS
if envs:
env_required_str = ", ".join(envs)
print(f" [dim]env vars required: {env_required_str}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:nit: since these are required i think they should be emphasized instead of dim


print("\n")


def autocomplete_session_files(ctx: click.Context, args: str, incomplete: str) -> None:
return [
f"{session_name}"
Expand Down
11 changes: 11 additions & 0 deletions src/goose/toolkit/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,14 @@ def lint_toolkits() -> None:
assert first_line_of_docstring[
0
].isupper(), f"`{toolkit_name}` toolkit docstring must start with a capital letter"


def lint_providers() -> None:
for provider_name, provider in load_plugins(group="exchange.provider").items():
assert provider.__doc__ is not None, f"`{provider_name}` provider must have a docstring"
first_line_of_docstring = provider.__doc__.split("\n")[0]
assert len(first_line_of_docstring.split(" ")) > 5, f"`{provider_name}` provider docstring is too short"
assert len(first_line_of_docstring.split(" ")) < 20, f"`{provider_name}` provider docstring is too long"
assert first_line_of_docstring[
0
].isupper(), f"`{provider_name}` provider docstring must start with a capital letter"
6 changes: 6 additions & 0 deletions tests/test_linting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from goose.toolkit.lint import lint_toolkits

from goose.toolkit.lint import lint_providers


def test_lint_toolkits():
lint_toolkits()


def test_lint_providers():
lint_providers()
Copy link
Collaborator

@lifeizhou-ap lifeizhou-ap Oct 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you followed the existing implementation pattern.

I am not sure whether lint_providers function is used anywhere else. If it does not, we could move the implementation of lint_providers into this test_linting.py as the implementation is just to check the doc string is correct. I feel it will be easier for people to read and understand

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lukealvoeiro can describe it better, but the reason behind this is that lint_providers could be used in tests for the goose-plugins-block internal library. Currently, we only have one internal provider, but if there are more in the future, they should follow the same structure. Do you think we should reconsider this logic?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, we want this function to be exported so we can lint this in any open source or internal plugin repos as well.