Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/exchange/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ anthropic = "exchange.providers.anthropic:AnthropicProvider"
bedrock = "exchange.providers.bedrock:BedrockProvider"
ollama = "exchange.providers.ollama:OllamaProvider"
google = "exchange.providers.google:GoogleProvider"
groq = "exchange.providers.groq:GroqProvider"

[project.entry-points."exchange.moderator"]
passive = "exchange.moderators.passive:PassiveModerator"
Expand Down
1 change: 1 addition & 0 deletions packages/exchange/src/exchange/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from exchange.providers.databricks import DatabricksProvider # noqa
from exchange.providers.openai import OpenAiProvider # noqa
from exchange.providers.ollama import OllamaProvider # noqa
from exchange.providers.groq import GroqProvider # noqa
from exchange.providers.azure import AzureProvider # noqa
from exchange.providers.google import GoogleProvider # noqa

Expand Down
97 changes: 97 additions & 0 deletions packages/exchange/src/exchange/providers/groq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import os
from typing import Any, Dict, List, Tuple, Type
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:nit: i think our patterns here don't realize as of python 3.9+ we generally don't need Dict, List, and Tuple anymore we can use the primitives (dict, list, and tuple respectively)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#149 to do this everywhere


import httpx

from exchange.message import Message
from exchange.providers.base import Provider, Usage
from exchange.providers.utils import (
messages_to_openai_spec,
openai_response_to_message,
openai_single_message_context_length_exceeded,
raise_for_status,
tools_to_openai_spec,
)
from exchange.tool import Tool
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status

GROQ_HOST = "https://api.groq.com/openai/"

retry_procedure = retry(
wait=wait_fixed(5),
stop=stop_after_attempt(5),
retry=retry_if_status(codes=[429], above=500),
reraise=True,
)


class GroqProvider(Provider):
"""Provides chat completions for models hosted directly by OpenAI."""

PROVIDER_NAME = "groq"
REQUIRED_ENV_VARS = ["GROQ_API_KEY"]
instructions_url = "https://console.groq.com/docs/quickstart"

def __init__(self, client: httpx.Client) -> None:
self.client = client

@classmethod
def from_env(cls: Type["GroqProvider"]) -> "GroqProvider":
cls.check_env_vars(cls.instructions_url)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@elenazherdeva whenever we revisit this, we could check this implicitly so the user's don't need to pass it in (e.g. we can do the check for if instructions_url: ...).

what do you think? simplifies the implementation of adding a new provider

url = os.environ.get("GROQ_HOST", GROQ_HOST)
key = os.environ.get("GROQ_API_KEY")

client = httpx.Client(
base_url=url + "v1/",
headers={"Authorization": "Bearer " + key},
timeout=httpx.Timeout(60 * 10),
)
return cls(client)

@staticmethod
def get_usage(data: dict) -> Usage:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lifeizhou-ap this might be a useful thing to promote to a @classmethod (not sure i understand when we'd need to use it as a static method Provider.get_usage({...})). let's consider making an abstraction to make adding providers super easy?

usage = data.pop("usage")
input_tokens = usage.get("prompt_tokens")
output_tokens = usage.get("completion_tokens")
total_tokens = usage.get("total_tokens")

if total_tokens is None and input_tokens is not None and output_tokens is not None:
total_tokens = input_tokens + output_tokens

return Usage(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
)

def complete(
self,
model: str,
system: str,
messages: List[Message],
tools: Tuple[Tool],
**kwargs: Dict[str, Any],
) -> Tuple[Message, Usage]:
system_message = [{"role": "system", "content": system}]
payload = dict(
messages=system_message + messages_to_openai_spec(messages),
model=model,
tools=tools_to_openai_spec(tools) if tools else [],
**kwargs,
)
payload = {k: v for k, v in payload.items() if v}
response = self._post(payload)

# Check for context_length_exceeded error for single, long input message
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copy pasta? maybe delete?

if "error" in response and len(messages) == 1:
openai_single_message_context_length_exceeded(response["error"])

message = openai_response_to_message(response)
usage = self.get_usage(response)
return message, usage

@retry_procedure
def _post(self, payload: dict) -> dict:
response = self.client.post("chat/completions", json=payload)
return raise_for_status(response).json()