Skip to content
This repository has been archived by the owner on Feb 2, 2025. It is now read-only.

Async models #26

Merged
merged 3 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,15 @@ To run the tests:
```bash
pytest
```

This project uses [pytest-recording](https://github.com/kiwicom/pytest-recording) to record Anthropic API responses for the tests.

If you add a new test that calls the API you can capture the API response like this:
```bash
PYTEST_ANTHROPIC_API_KEY="$(llm keys get claude)" pytest --record-mode once
```
You will need to have stored a valid Anthropic API key using this command first:
```bash
llm keys set claude
# Paste key here
```
87 changes: 69 additions & 18 deletions llm_claude_3.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from anthropic import Anthropic
from anthropic import Anthropic, AsyncAnthropic
import llm
from pydantic import Field, field_validator, model_validator
from typing import Optional, List
Expand All @@ -7,19 +7,42 @@
@llm.hookimpl
def register_models(register):
    """Register each Claude model as a (sync, async) pair with LLM.

    Model IDs: https://docs.anthropic.com/claude/docs/models-overview
    """
    # Claude 3 models
    register(
        ClaudeMessages("claude-3-opus-20240229"),
        AsyncClaudeMessages("claude-3-opus-20240229"),
    )
    register(
        ClaudeMessages("claude-3-opus-latest"),
        AsyncClaudeMessages("claude-3-opus-latest"),
        aliases=("claude-3-opus",),
    )
    register(
        ClaudeMessages("claude-3-sonnet-20240229"),
        AsyncClaudeMessages("claude-3-sonnet-20240229"),
        aliases=("claude-3-sonnet",),
    )
    register(
        ClaudeMessages("claude-3-haiku-20240307"),
        AsyncClaudeMessages("claude-3-haiku-20240307"),
        aliases=("claude-3-haiku",),
    )
    # Claude 3.5 models
    register(
        ClaudeMessagesLong("claude-3-5-sonnet-20240620"),
        AsyncClaudeMessagesLong("claude-3-5-sonnet-20240620"),
    )
    register(
        ClaudeMessagesLong("claude-3-5-sonnet-20241022", supports_pdf=True),
        AsyncClaudeMessagesLong("claude-3-5-sonnet-20241022", supports_pdf=True),
    )
    register(
        ClaudeMessagesLong("claude-3-5-sonnet-latest", supports_pdf=True),
        AsyncClaudeMessagesLong("claude-3-5-sonnet-latest", supports_pdf=True),
        aliases=("claude-3.5-sonnet", "claude-3.5-sonnet-latest"),
    )
    register(
        ClaudeMessagesLong("claude-3-5-haiku-latest", supports_images=False),
        AsyncClaudeMessagesLong("claude-3-5-haiku-latest", supports_images=False),
        aliases=("claude-3.5-haiku",),
    )

Expand Down Expand Up @@ -86,7 +109,13 @@ def validate_temperature_top_p(self):
return self


class ClaudeMessages(llm.Model):
# Shared Options default for the "-Long" model variants: twice the
# standard 4,096-token output cap (8,192 tokens).
long_field = Field(
    default=8_192,
    description="The maximum number of tokens to generate before stopping",
)


class _Shared:
needs_key = "claude"
key_env_var = "ANTHROPIC_API_KEY"
can_stream = True
Expand Down Expand Up @@ -178,9 +207,7 @@ def build_messages(self, prompt, conversation) -> List[dict]:
messages.append({"role": "user", "content": prompt.prompt})
return messages

def execute(self, prompt, stream, response, conversation):
client = Anthropic(api_key=self.get_key())

def build_kwargs(self, prompt, conversation):
kwargs = {
"model": self.claude_model_id,
"messages": self.build_messages(prompt, conversation),
Expand All @@ -202,7 +229,17 @@ def execute(self, prompt, stream, response, conversation):

if self.extra_headers:
kwargs["extra_headers"] = self.extra_headers
return kwargs

def __str__(self):
return "Anthropic Messages: {}".format(self.model_id)


class ClaudeMessages(_Shared, llm.Model):

def execute(self, prompt, stream, response, conversation):
client = Anthropic(api_key=self.get_key())
kwargs = self.build_kwargs(prompt, conversation)
if stream:
with client.messages.stream(**kwargs) as stream:
for text in stream.text_stream:
Expand All @@ -214,13 +251,27 @@ def execute(self, prompt, stream, response, conversation):
yield completion.content[0].text
response.response_json = completion.model_dump()

def __str__(self):
return "Anthropic Messages: {}".format(self.model_id)


class ClaudeMessagesLong(ClaudeMessages):
    """ClaudeMessages variant whose default max_tokens is doubled to 8,192."""

    class Options(ClaudeOptions):
        # Shared long_field constant supplies the doubled default.
        max_tokens: Optional[int] = long_field


class AsyncClaudeMessages(_Shared, llm.AsyncModel):
    """Async counterpart of ClaudeMessages, backed by AsyncAnthropic."""

    async def execute(self, prompt, stream, response, conversation):
        """Yield the completion text for *prompt*, chunk by chunk when streaming."""
        client = AsyncAnthropic(api_key=self.get_key())
        kwargs = self.build_kwargs(prompt, conversation)
        if not stream:
            # Single-shot request: one message, one yielded text block.
            completion = await client.messages.create(**kwargs)
            yield completion.content[0].text
            response.response_json = completion.model_dump()
            return
        async with client.messages.stream(**kwargs) as message_stream:
            async for chunk in message_stream.text_stream:
                yield chunk
            # Record the fully assembled message once streaming finishes.
            final_message = await message_stream.get_final_message()
            response.response_json = final_message.model_dump()


class AsyncClaudeMessagesLong(AsyncClaudeMessages):
    """Async Claude model with the doubled default max_tokens from long_field."""

    class Options(ClaudeOptions):
        max_tokens: Optional[int] = long_field
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ classifiers = [
"License :: OSI Approved :: Apache Software License"
]
dependencies = [
"llm>=0.17",
"llm>=0.18a0",
"anthropic>=0.39.0",
]

Expand All @@ -23,4 +23,4 @@ CI = "https://github.com/simonw/llm-claude-3/actions"
claude_3 = "llm_claude_3"

[project.optional-dependencies]
test = ["pytest", "pytest-recording"]
test = ["pytest", "pytest-recording", "pytest-asyncio"]
Loading