Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ankit/add claude #22

Merged
merged 5 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
-e .
anthropic==0.37.1
Flask==3.0.3
Flask-SocketIO==5.3.6
Flask-Cors==4.0.1
Expand Down
2 changes: 1 addition & 1 deletion backend/spielberg/agents/prompt_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _text_prompter(self, transcript_text, prompt):
self.llm.chat_completions,
[
ContextMessage(
content=prompt, role=RoleTypes.system
content=prompt, role=RoleTypes.user
).to_llm_msg()
],
response_format={"type": "json_object"},
Expand Down
2 changes: 2 additions & 0 deletions backend/spielberg/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ class LLMType(str, Enum):
"""Enum for LLM types"""

OPENAI = "openai"
ANTHROPIC = "anthropic"


class EnvPrefix(str, Enum):
"""Enum for environment prefixes"""

OPENAI_ = "OPENAI_"
ANTHROPIC_ = "ANTHROPIC_"
185 changes: 185 additions & 0 deletions backend/spielberg/llm/anthropic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
from enum import Enum

from pydantic import Field, field_validator, FieldValidationInfo
from pydantic_settings import SettingsConfigDict

from spielberg.core.session import RoleTypes
from spielberg.llm.base import BaseLLM, BaseLLMConfig, LLMResponse, LLMResponseStatus
from spielberg.constants import (
LLMType,
EnvPrefix,
)


class AnthropicChatModel(str, Enum):
    """Identifiers for the Anthropic chat models this integration supports.

    Members are ``str`` subclasses, so a member can be passed anywhere the
    API expects the raw model-name string.
    """

    # Claude 3 generation
    CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
    CLAUDE_3_OPUS = "claude-3-opus-20240229"
    # Claude 3.5 generation
    CLAUDE_3_5_SONNET = "claude-3-5-sonnet-20240620"
    CLAUDE_3_5_SONNET_LATEST = "claude-3-5-sonnet-20241022"


class AnthropicAIConfig(BaseLLMConfig):
    """AnthropicAI Config.

    Settings are loaded from ``ANTHROPIC_``-prefixed environment variables
    (e.g. ``ANTHROPIC_API_KEY``) via pydantic-settings.
    """

    model_config = SettingsConfigDict(
        env_prefix=EnvPrefix.ANTHROPIC_,
        extra="ignore",
    )

    llm_type: str = LLMType.ANTHROPIC
    api_key: str = ""
    api_base: str = ""
    chat_model: str = Field(default=AnthropicChatModel.CLAUDE_3_5_SONNET)

    @field_validator("api_key")
    @classmethod
    def validate_non_empty(cls, v, info: FieldValidationInfo):
        """Fail fast with an actionable message when the API key is missing.

        Bug fix: the message previously pointed users at the ``OPENAI_``
        env-var prefix, but this config reads ``ANTHROPIC_``-prefixed vars.
        """
        if not v:
            raise ValueError(
                f"{info.field_name} must not be empty. please set {EnvPrefix.ANTHROPIC_.value}{info.field_name.upper()} environment variable."
            )
        return v


class AnthropicAI(BaseLLM):
    """LLM implementation backed by the Anthropic Messages API."""

    def __init__(self, config: AnthropicAIConfig = None):
        """
        :param config: AnthropicAI Config; a default (env-driven) config is
            created when none is given.
        :raises ImportError: if the ``anthropic`` python library is missing.
        """
        if config is None:
            config = AnthropicAIConfig()
        super().__init__(config=config)
        try:
            import anthropic
        except ImportError:
            raise ImportError("Please install Anthropic python library.")

        self.client = anthropic.Anthropic(api_key=self.api_key)

    def _format_messages(self, messages: list):
        """Split off the system prompt and convert messages to Anthropic's format.

        Anthropic takes the system prompt as a separate top-level parameter
        rather than as a message, and represents tool calls / tool results as
        typed content blocks.

        :param messages: provider-agnostic message dicts.
        :return: tuple of (system prompt string, formatted message list).
        """
        system = ""
        formatted_messages = []
        # Guard against an empty message list before peeking at messages[0].
        if messages and messages[0]["role"] == RoleTypes.system:
            system = messages[0]["content"]
            messages = messages[1:]

        for message in messages:
            if message["role"] == RoleTypes.assistant and message.get("tool_calls"):
                # Only the first tool call is forwarded (single-tool-use flow).
                tool_call = message["tool_calls"][0]
                formatted_messages.append(
                    {
                        "role": message["role"],
                        "content": [
                            {
                                "type": "text",
                                "text": message["content"],
                            },
                            {
                                "id": tool_call["id"],
                                "type": tool_call["type"],
                                "name": tool_call["tool"]["name"],
                                "input": tool_call["tool"]["arguments"],
                            },
                        ],
                    }
                )

            elif message["role"] == RoleTypes.tool:
                # Anthropic expects tool results as user-role tool_result blocks.
                formatted_messages.append(
                    {
                        "role": RoleTypes.user,
                        "content": [
                            {
                                "type": "tool_result",
                                "tool_use_id": message["tool_call_id"],
                                "content": message["content"],
                            }
                        ],
                    }
                )
            else:
                formatted_messages.append(message)

        return system, formatted_messages

    def _format_tools(self, tools: list):
        """Format the tools to the format that Anthropic expects.

        **Example**::

            [
                {
                    "name": "get_weather",
                    "description": "Get the current weather in a given location",
                    "input_schema": {
                        "type": "object",
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "The city and state, e.g. San Francisco, CA",
                            }
                        },
                        "required": ["location"],
                    },
                }
            ]

        :param tools: provider-agnostic tool specs with ``name``,
            ``description`` and ``parameters`` keys.
        :return: list of Anthropic ``input_schema``-style tool specs.
        """
        return [
            {
                "name": tool["name"],
                "description": tool["description"],
                "input_schema": tool["parameters"],
            }
            for tool in tools
        ]

    def chat_completions(
        self, messages: list, tools: list = None, stop=None, response_format=None
    ):
        """Get completions for chat.

        tools docs: https://docs.anthropic.com/en/docs/build-with-claude/tool-use

        :param messages: provider-agnostic message dicts.
        :param tools: optional tool specs (``None``/empty means no tools);
            default changed from a shared mutable ``[]`` to ``None`` —
            behavior for callers is unchanged.
        :param stop: unused; kept for interface compatibility.
        :param response_format: unused; kept for interface compatibility.
        :return: LLMResponse with content, tool calls and token usage, or an
            error-content response if the API call fails.
        """
        system, messages = self._format_messages(messages)
        params = {
            "model": self.chat_model,
            "messages": messages,
            "system": system,
            "max_tokens": self.max_tokens,
        }
        if tools:
            params["tools"] = self._format_tools(tools)

        try:
            response = self.client.messages.create(**params)
        except Exception as e:
            # Bug fix: the original `raise e` made the error-response return
            # unreachable; surface API failures as an error response instead.
            return LLMResponse(content=f"Error: {e}")

        # Bug fix: locate blocks by type rather than assuming text is at
        # content[0] and tool_use at content[1] — Anthropic does not
        # guarantee block positions.
        text_block = next(
            (block for block in response.content if block.type == "text"), None
        )
        tool_block = next(
            (block for block in response.content if block.type == "tool_use"), None
        )
        tool_calls = (
            [
                {
                    "id": tool_block.id,
                    "tool": {
                        "name": tool_block.name,
                        "arguments": tool_block.input,
                    },
                    "type": tool_block.type,
                }
            ]
            if tool_block is not None
            else []
        )

        return LLMResponse(
            content=text_block.text if text_block is not None else "",
            tool_calls=tool_calls,
            finish_reason=response.stop_reason,
            send_tokens=response.usage.input_tokens,
            recv_tokens=response.usage.output_tokens,
            total_tokens=(response.usage.input_tokens + response.usage.output_tokens),
            status=LLMResponseStatus.SUCCESS,
        )
8 changes: 0 additions & 8 deletions backend/spielberg/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ class BaseLLMConfig(BaseSettings):
:param str api_key: API key for the LLM.
:param str api_base: Base URL for the LLM API.
:param str chat_model: Model name for chat completions.
:param str text_model: Model name for text completions.
:param str temperature: Sampling temperature for completions.
:param float top_p: Top p sampling for completions.
:param int max_tokens: Maximum tokens to generate.
Expand All @@ -40,7 +39,6 @@ class BaseLLMConfig(BaseSettings):
api_key: str = ""
api_base: str = ""
chat_model: str = ""
text_model: str = ""
temperature: float = 0.9
top_p: float = 1
max_tokens: int = 4096
Expand All @@ -60,7 +58,6 @@ def __init__(self, config: BaseLLMConfig):
self.api_key = config.api_key
self.api_base = config.api_base
self.chat_model = config.chat_model
self.text_model = config.text_model
self.temperature = config.temperature
self.top_p = config.top_p
self.max_tokens = config.max_tokens
Expand All @@ -71,8 +68,3 @@ def __init__(self, config: BaseLLMConfig):
def chat_completions(self, messages: List[Dict], tools: List[Dict]) -> LLMResponse:
"""Abstract method for chat completions"""
pass

@abstractmethod
def text_completions(self, prompt: str) -> LLMResponse:
"""Abstract method for text completions"""
pass
32 changes: 4 additions & 28 deletions backend/spielberg/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,6 @@ class OpenAIChatModel(str, Enum):
GPT4o_MINI = "gpt-4o-mini"


class OpenAITextModel(str, Enum):
"""Enum for OpenAI Text models"""

GPT4 = "gpt-4"
GPT4_32K = "gpt-4-32k"
GPT4_TURBO = "gpt-4-turbo"
GPT4o = "gpt-4o-2024-08-06"
GPT4o_MINI = "gpt-4o-mini"


class OpenaiConfig(BaseLLMConfig):
"""OpenAI Config"""

Expand All @@ -44,9 +34,7 @@ class OpenaiConfig(BaseLLMConfig):
api_key: str = ""
api_base: str = "https://api.openai.com/v1"
chat_model: str = Field(default=OpenAIChatModel.GPT4o)
text_model: str = Field(default=OpenAITextModel.GPT4_TURBO)
max_tokens: int = 4096
enable_langfuse: bool = False

@field_validator("api_key")
@classmethod
Expand All @@ -66,18 +54,10 @@ def __init__(self, config: OpenaiConfig = None):
if config is None:
config = OpenaiConfig()
super().__init__(config=config)

if self.enable_langfuse:
try:
from langfuse.openai import openai
except ImportError:
raise ImportError("Please install Langfuse and OpenAI python library.")
self.init_langfuse()
else:
try:
import openai
except ImportError:
raise ImportError("Please install OpenAI python library.")
try:
import openai
except ImportError:
raise ImportError("Please install OpenAI python library.")

self.client = openai.OpenAI(api_key=self.api_key, base_url=self.api_base)

Expand Down Expand Up @@ -206,7 +186,3 @@ def chat_completions(
total_tokens=response.usage.total_tokens,
status=LLMResponseStatus.SUCCESS,
)

def text_completions(self):
"""Get completions for text."""
raise NotImplementedError("Not implemented yet")
15 changes: 15 additions & 0 deletions docs/llm/anthropic.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
## AnthropicAI

AnthropicAI extends the base LLM and implements the Anthropic API.

### AnthropicAI Config

AnthropicAI Config is the configuration object for AnthropicAI; it is passed to AnthropicAI at construction time to configure the client.

::: spielberg.llm.anthropic.AnthropicAIConfig

### AnthropicAI Interface

AnthropicAI is the LLM used by the agents and tools. It is used to generate responses to messages.

::: spielberg.llm.anthropic.AnthropicAI
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ nav:
- 'Interface': 'llm/interface.md'
- Integrations:
- 'OpenAI': 'llm/openai.md'
- 'AnthropicAI': 'llm/anthropic.md'
- 'Database':
- 'Interface': 'database/interface.md'
- Integrations:
Expand Down