refactor plugin to use Messages API, add nextgen claude models #8

Open · wants to merge 1 commit into base: main
104 changes: 71 additions & 33 deletions llm_claude/__init__.py
@@ -1,65 +1,103 @@
 from typing import Optional
 
 import click
 import llm
-from anthropic import AI_PROMPT, HUMAN_PROMPT, Anthropic
+from anthropic import Anthropic
 from pydantic import Field, field_validator
 
 
 @llm.hookimpl
 def register_models(register):
-    # https://docs.anthropic.com/claude/reference/selecting-a-model
-    # Family           Latest major version    Latest full version
-    # Claude Instant   claude-instant-1        claude-instant-1.1
-    # Claude           claude-2                claude-2.0
-    register(Claude("claude-instant-1"), aliases=("claude-instant",))
-    register(Claude("claude-2"), aliases=("claude",))
+    # https://docs.anthropic.com/claude/docs/models-overview
+    register(Claude("claude-instant-1.2"), aliases=("claude-instant",))
+    register(Claude("claude-2.0"))
+    register(Claude("claude-2.1"), aliases=("claude-2",))
+    register(Claude("claude-3-opus-20240229"), aliases=("claude", "claude-3", "opus", "claude-opus"))
+    register(Claude("claude-3-sonnet-20240229"), aliases=("sonnet", "claude-sonnet"))
+    # TODO haiku when it's released
 
 
+class _ClaudeOptions(llm.Options):
+    max_tokens: Optional[int] = Field(
+        description="The maximum number of tokens to generate before stopping",
+        default=4_096,
+    )
+
+    temperature: Optional[float] = Field(
+        description="Amount of randomness injected into the response. Defaults to 1.0. Ranges from 0.0 to 1.0. Use temperature closer to 0.0 for analytical / multiple choice, and closer to 1.0 for creative and generative tasks. Note that even with temperature of 0.0, the results will not be fully deterministic.",
+        default=1.0,
+    )
+
+    top_p: Optional[float] = Field(
+        description="Use nucleus sampling. In nucleus sampling, we compute the cumulative distribution over all the options for each subsequent token in decreasing probability order and cut it off once it reaches a particular probability specified by top_p. You should either alter temperature or top_p, but not both. Recommended for advanced use cases only. You usually only need to use temperature.",
+        default=None,
+    )
+
+    top_k: Optional[int] = Field(
+        description="Only sample from the top K options for each subsequent token. Used to remove 'long tail' low probability responses. Recommended for advanced use cases only. You usually only need to use temperature.",
+        default=None,
+    )
+
+    @field_validator("max_tokens")
+    def validate_max_tokens(cls, max_tokens):
+        if not (0 < max_tokens <= 4_096):
+            raise ValueError("max_tokens must be in range 1-4,096")
+        return max_tokens
+
+    @field_validator("temperature")
+    def validate_temperature(cls, temperature):
+        if not (0.0 <= temperature <= 1.0):
+            raise ValueError("temperature must be in range 0.0-1.0")
+        return temperature
+
+    @field_validator("top_p")
+    def validate_top_p(cls, top_p):
+        if top_p is not None and not (0.0 <= top_p <= 1.0):
+            raise ValueError("top_p must be in range 0.0-1.0")
+        return top_p
+
+    @field_validator("top_k")
+    def validate_top_k(cls, top_k):
+        if top_k is not None and top_k <= 0:
+            raise ValueError("top_k must be a positive integer")
+        return top_k
+
+
 class Claude(llm.Model):
     needs_key = "claude"
     key_env_var = "ANTHROPIC_API_KEY"
     can_stream = True
 
-    class Options(llm.Options):
-        max_tokens_to_sample: Optional[int] = Field(
-            description="The maximum number of tokens to generate before stopping",
-            default=10_000,
-        )
-
-        @field_validator("max_tokens_to_sample")
-        def validate_length(cls, max_tokens_to_sample):
-            if not (0 < max_tokens_to_sample <= 1_000_000):
-                raise ValueError("max_tokens_to_sample must be in range 1-1,000,000")
-            return max_tokens_to_sample
+    class Options(_ClaudeOptions):
+        ...
 
     def __init__(self, model_id):
         self.model_id = model_id
 
     def generate_prompt_messages(self, prompt, conversation):
+        messages = []
         if conversation:
             for response in conversation.responses:
-                yield self.build_prompt(response.prompt.prompt, response.text())
-
-        yield self.build_prompt(prompt)
-
-    def build_prompt(self, human, ai=""):
-        return f"{HUMAN_PROMPT} {human}{AI_PROMPT}{ai}"
+                messages.append({"role": "user", "content": response.prompt.prompt})
+                messages.append({"role": "assistant", "content": response.text()})
+        messages.append({"role": "user", "content": prompt})
+        return messages
 
     def execute(self, prompt, stream, response, conversation):
         anthropic = Anthropic(api_key=self.get_key())
 
-        prompt_str = "".join(self.generate_prompt_messages(prompt.prompt, conversation))
-
-        completion = anthropic.completions.create(
+        messages = self.generate_prompt_messages(prompt.prompt, conversation)
+        completion = anthropic.messages.create(
             model=self.model_id,
-            max_tokens_to_sample=prompt.options.max_tokens_to_sample,
-            prompt=prompt_str,
+            max_tokens=prompt.options.max_tokens,
+            messages=messages,
             stream=stream,
         )
         if stream:
             for comp in completion:
-                yield comp.completion
+                if hasattr(comp, "content_block"):
+                    response = comp.content_block.text
+                    yield response
+                elif hasattr(comp, "delta"):
+                    if hasattr(comp.delta, "text"):
+                        yield comp.delta.text
Comment on lines +95 to +100
Author
This block of code is pretty fast and loose, based on Anthropic's documented SSE format. I mostly just inspected the Python objects coming out of the completion stream and used these hasattr checks to pull out the raw text deltas. There's probably a more principled way to do it; happy to rewrite if that's a concern.
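For reference, a rough sketch of what a more explicit version might look like, dispatching on the documented event types instead of probing attributes. This is untested and assumes the streamed objects expose a `type` attribute matching the SSE event names in Anthropic's docs (e.g. `content_block_delta`):

```python
# Hypothetical sketch, not part of this PR: dispatch on documented
# SSE event types rather than hasattr() probing.
def stream_text(events):
    for event in events:
        etype = getattr(event, "type", None)
        if etype == "content_block_start":
            # The opening content block usually carries an empty string.
            text = getattr(event.content_block, "text", "")
            if text:
                yield text
        elif etype == "content_block_delta":
            # Incremental text chunks arrive on delta events.
            text = getattr(event.delta, "text", None)
            if text:
                yield text
        elif etype == "message_stop":
            # End of the stream; message_start / message_delta carry no text.
            break
```

That would tie the dispatch to the event names the API documents rather than to whichever attributes happen to be present on each object.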

         else:
             yield completion.completion

61 changes: 33 additions & 28 deletions tests/test_llm_claude.py
@@ -3,7 +3,6 @@
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
-from anthropic import AI_PROMPT, HUMAN_PROMPT
 from click.testing import CliRunner
 from llm import Prompt, Response, get_model
 from llm.cli import cli
@@ -15,52 +14,56 @@
 def test_claude_response(mock_anthropic):
     mock_response = MagicMock()
     mock_response.completion = "hello"
-    mock_anthropic.return_value.completions.create.return_value.__iter__.return_value = [
+    mock_anthropic.return_value.messages.create.return_value.__iter__.return_value = [
         mock_response
     ]
     prompt = Prompt("hello", "", options=Claude.Options())
     model = Claude("claude-2")
     model.key = "key"
-    items = list(model.response(prompt))
+    model_response = model.response(prompt)
+    # breakpoint()
+    items = list(model_response)
 
-    mock_anthropic.return_value.completions.create.assert_called_with(
+    mock_anthropic.return_value.messages.create.assert_called_with(
         model="claude-2",
-        max_tokens_to_sample=10_000,
-        prompt="\n\nHuman: hello\n\nAssistant:",
+        max_tokens=10_000,
+        messages=[{"role": "user", "content": "hello"}],
         stream=True,
     )
+    # breakpoint()
 
     assert items == ["hello"]
 
 
-@pytest.mark.parametrize("max_tokens_to_sample", (1, 500_000, 1_000_000))
+@pytest.mark.parametrize("max_tokens", (1, 500_000, 1_000_000))
 @patch("llm_claude.Anthropic")
-def test_with_max_tokens_to_sample(mock_anthropic, max_tokens_to_sample):
+def test_with_max_tokens(mock_anthropic, max_tokens):
     mock_response = MagicMock()
     mock_response.completion = "hello"
-    mock_anthropic.return_value.completions.create.return_value.__iter__.return_value = [
+    mock_anthropic.return_value.messages.create.return_value.__iter__.return_value = [
         mock_response
     ]
     prompt = Prompt(
-        "hello", "", options=Claude.Options(max_tokens_to_sample=max_tokens_to_sample)
+        "hello", "", options=Claude.Options(max_tokens=max_tokens)
     )
     model = Claude("claude-2")
     model.key = "key"
-    items = list(model.response(prompt))
+    model_response = model.response(prompt)
+    items = list(model_response)
 
-    mock_anthropic.return_value.completions.create.assert_called_with(
+    mock_anthropic.return_value.messages.create.assert_called_with(
         model="claude-2",
-        max_tokens_to_sample=max_tokens_to_sample,
-        prompt="\n\nHuman: hello\n\nAssistant:",
+        max_tokens=max_tokens,
+        messages=[{"role": "user", "content": "hello"}],
         stream=True,
     )
 
     assert items == ["hello"]
 
 
-@pytest.mark.parametrize("max_tokens_to_sample", (0, 1_000_001))
+@pytest.mark.parametrize("max_tokens", (0, 1_000_001))
 @patch("llm_claude.Anthropic")
-def test_invalid_max_tokens_to_sample(mock_anthropic, max_tokens_to_sample):
+def test_invalid_max_tokens(mock_anthropic, max_tokens):
     runner = CliRunner()
     result = runner.invoke(
         cli,
@@ -69,14 +72,14 @@ def test_invalid_max_tokens_to_sample(mock_anthropic, max_tokens_to_sample):
"-m",
"claude",
"-o",
"max_tokens_to_sample",
max_tokens_to_sample,
"max_tokens",
max_tokens,
],
)
assert result.exit_code == 1
assert (
result.output
== "Error: max_tokens_to_sample\n Value error, max_tokens_to_sample must be in range 1-1,000,000\n"
== "Error: max_tokens\n Value error, max_tokens must be in range 1-1,000,000\n"
)


@@ -85,7 +88,7 @@ def test_invalid_max_tokens_to_sample(mock_anthropic, max_tokens_to_sample):
 def test_claude_prompt(mock_anthropic):
     mock_response = Mock()
     mock_response.completion = "🐶🐶"
-    mock_anthropic.return_value.completions.create.return_value = mock_response
+    mock_anthropic.return_value.messages.create.return_value = mock_response
     runner = CliRunner()
     result = runner.invoke(cli, ["two dog emoji", "-m", "claude", "--no-stream"])
     assert result.exit_code == 0, result.output
@@ -95,28 +98,31 @@ def test_claude_prompt(mock_anthropic):
 @pytest.mark.parametrize(
     "prompt, conversation_messages, expected",
     (
-        ("hello", [], [f"{HUMAN_PROMPT} hello{AI_PROMPT}"]),
+        ("hello", [], [{"role": "user", "content": "hello"}]),
         (
             "hello 2",
             [("user 1", "response 1")],
             [
-                f"{HUMAN_PROMPT} user 1{AI_PROMPT}response 1",
-                f"{HUMAN_PROMPT} hello 2{AI_PROMPT}",
+                {"role": "user", "content": "user 1"},
+                {"role": "assistant", "content": "response 1"},
+                {"role": "user", "content": "hello 2"},
             ],
         ),
         (
             "hello 3",
             [("user 1", "response 1"), ("user 2", "response 2")],
             [
-                "\n\nHuman: user 1\n\nAssistant:response 1",
-                "\n\nHuman: user 2\n\nAssistant:response 2",
-                "\n\nHuman: hello 3\n\nAssistant:",
+                {"role": "user", "content": "user 1"},
+                {"role": "assistant", "content": "response 1"},
+                {"role": "user", "content": "user 2"},
+                {"role": "assistant", "content": "response 2"},
+                {"role": "user", "content": "hello 3"},
             ],
         ),
     ),
 )
 def test_generate_prompt_messages(
-    prompt: str, conversation_messages: List[Tuple[str, str]], expected: List[str]
+    prompt: str, conversation_messages: List[Tuple[str, str]], expected: List[dict]
 ):
     model = get_model("claude")
     conversation = None
@@ -131,6 +137,5 @@ def test_generate_prompt_messages(
                     response=prev_response,
                 )
             )
-
     messages = model.generate_prompt_messages(prompt, conversation)
     assert list(messages) == expected