From dba98eb164a251453412c4be04e7c532d1ee5ee3 Mon Sep 17 00:00:00 2001
From: jvmncs
Date: Tue, 5 Mar 2024 15:03:23 -0500
Subject: [PATCH] refactor plugin to use Messages API, add nextgen claude models

---
 llm_claude/__init__.py   | 104 ++++++++++++++++++++++++++-------------
 tests/test_llm_claude.py |  61 ++++++++++++-----------
 2 files changed, 104 insertions(+), 61 deletions(-)

diff --git a/llm_claude/__init__.py b/llm_claude/__init__.py
index 00a74c6..e560814 100644
--- a/llm_claude/__init__.py
+++ b/llm_claude/__init__.py
@@ -1,19 +1,64 @@
 from typing import Optional
-
-import click
 import llm
-from anthropic import AI_PROMPT, HUMAN_PROMPT, Anthropic
+from anthropic import Anthropic
 from pydantic import Field, field_validator


 @llm.hookimpl
 def register_models(register):
-    # https://docs.anthropic.com/claude/reference/selecting-a-model
-    # Family            Latest major version    Latest full version
-    # Claude Instant    claude-instant-1        claude-instant-1.1
-    # Claude            claude-2                claude-2.0
-    register(Claude("claude-instant-1"), aliases=("claude-instant",))
-    register(Claude("claude-2"), aliases=("claude",))
+    # https://docs.anthropic.com/claude/docs/models-overview
+    register(Claude("claude-instant-1.2"), aliases=("claude-instant",))
+    register(Claude("claude-2.0"))
+    register(Claude("claude-2.1"), aliases=("claude-2",))
+    register(Claude("claude-3-opus-20240229"), aliases=("claude", "claude-3", "opus", "claude-opus"))
+    register(Claude("claude-3-sonnet-20240229"), aliases=("sonnet", "claude-sonnet"))
+    # TODO haiku when it's released
+
+
+class _ClaudeOptions(llm.Options):
+    max_tokens: Optional[int] = Field(
+        description="The maximum number of tokens to generate before stopping",
+        default=4_096,
+    )
+
+    temperature: Optional[float] = Field(
+        description="Amount of randomness injected into the response. Defaults to 1.0. Ranges from 0.0 to 1.0. Use temperature closer to 0.0 for analytical / multiple choice, and closer to 1.0 for creative and generative tasks. Note that even with temperature of 0.0, the results will not be fully deterministic.",
+        default=1.0,
+    )
+
+    top_p: Optional[float] = Field(
+        description="Use nucleus sampling. In nucleus sampling, we compute the cumulative distribution over all the options for each subsequent token in decreasing probability order and cut it off once it reaches a particular probability specified by top_p. You should either alter temperature or top_p, but not both. Recommended for advanced use cases only. You usually only need to use temperature.",
+        default=None,
+    )
+
+    top_k: Optional[int] = Field(
+        description="Only sample from the top K options for each subsequent token. Used to remove 'long tail' low probability responses. Recommended for advanced use cases only. You usually only need to use temperature.",
+        default=None,
+    )
+
+    @field_validator("max_tokens")
+    def validate_max_tokens(cls, max_tokens):
+        if not (0 < max_tokens <= 4_096):
+            raise ValueError("max_tokens must be in range 1-4,096")
+        return max_tokens
+
+    @field_validator("temperature")
+    def validate_temperature(cls, temperature):
+        if not (0.0 <= temperature <= 1.0):
+            raise ValueError("temperature must be in range 0.0-1.0")
+        return temperature
+
+    @field_validator("top_p")
+    def validate_top_p(cls, top_p):
+        if top_p is not None and not (0.0 <= top_p <= 1.0):
+            raise ValueError("top_p must be in range 0.0-1.0")
+        return top_p
+
+    @field_validator("top_k")
+    def validate_top_k(cls, top_k):
+        if top_k is not None and top_k <= 0:
+            raise ValueError("top_k must be a positive integer")
+        return top_k


 class Claude(llm.Model):
@@ -21,45 +66,38 @@ class Claude(llm.Model):
     key_env_var = "ANTHROPIC_API_KEY"
     can_stream = True

-    class Options(llm.Options):
-        max_tokens_to_sample: Optional[int] = Field(
-            description="The maximum number of tokens to generate before stopping",
-            default=10_000,
-        )
-
-        @field_validator("max_tokens_to_sample")
-        def validate_length(cls, max_tokens_to_sample):
-            if not (0 < max_tokens_to_sample <= 1_000_000):
-                raise ValueError("max_tokens_to_sample must be in range 1-1,000,000")
-            return max_tokens_to_sample
+    class Options(_ClaudeOptions):
+        ...

     def __init__(self, model_id):
         self.model_id = model_id

     def generate_prompt_messages(self, prompt, conversation):
+        messages = []
         if conversation:
             for response in conversation.responses:
-                yield self.build_prompt(response.prompt.prompt, response.text())
-
-        yield self.build_prompt(prompt)
-
-    def build_prompt(self, human, ai=""):
-        return f"{HUMAN_PROMPT} {human}{AI_PROMPT}{ai}"
+                messages.append({"role": "user", "content": response.prompt.prompt})
+                messages.append({"role": "assistant", "content": response.text()})
+        messages.append({"role": "user", "content": prompt})
+        return messages

     def execute(self, prompt, stream, response, conversation):
         anthropic = Anthropic(api_key=self.get_key())
-
-        prompt_str = "".join(self.generate_prompt_messages(prompt.prompt, conversation))
-
-        completion = anthropic.completions.create(
+        messages = self.generate_prompt_messages(prompt.prompt, conversation)
+        completion = anthropic.messages.create(
             model=self.model_id,
-            max_tokens_to_sample=prompt.options.max_tokens_to_sample,
-            prompt=prompt_str,
+            max_tokens=prompt.options.max_tokens,
+            messages=messages,
             stream=stream,
         )
         if stream:
             for comp in completion:
-                yield comp.completion
+                if hasattr(comp, "content_block"):
+                    yield comp.content_block.text
+                elif hasattr(comp, "delta") and hasattr(comp.delta, "text"):
+                    yield comp.delta.text
         else:
-            yield completion.completion
+            yield completion.content[0].text
diff --git a/tests/test_llm_claude.py b/tests/test_llm_claude.py
index 52b1fa9..2df443b 100644
--- a/tests/test_llm_claude.py
+++ b/tests/test_llm_claude.py
@@ -3,7 +3,6 @@ from unittest.mock import MagicMock, Mock, patch

 import pytest
-from anthropic import AI_PROMPT, HUMAN_PROMPT
 from click.testing import CliRunner
 from llm import Prompt, Response, get_model
 from llm.cli import cli
@@ -15,52 +14,56 @@ def test_claude_response(mock_anthropic):
     mock_response = MagicMock()
-    mock_response.completion = "hello"
+    mock_response.content_block.text = "hello"

-    mock_anthropic.return_value.completions.create.return_value.__iter__.return_value = [
+    mock_anthropic.return_value.messages.create.return_value.__iter__.return_value = [
         mock_response
     ]
     prompt = Prompt("hello", "", options=Claude.Options())
     model = Claude("claude-2")
     model.key = "key"
-    items = list(model.response(prompt))
+    model_response = model.response(prompt)
+    items = list(model_response)

-    mock_anthropic.return_value.completions.create.assert_called_with(
+    mock_anthropic.return_value.messages.create.assert_called_with(
         model="claude-2",
-        max_tokens_to_sample=10_000,
-        prompt="\n\nHuman: hello\n\nAssistant:",
+        max_tokens=4_096,
+        messages=[{"role": "user", "content": "hello"}],
         stream=True,
     )
     assert items == ["hello"]


-@pytest.mark.parametrize("max_tokens_to_sample", (1, 500_000, 1_000_000))
+@pytest.mark.parametrize("max_tokens", (1, 2_048, 4_096))
 @patch("llm_claude.Anthropic")
-def test_with_max_tokens_to_sample(mock_anthropic, max_tokens_to_sample):
+def test_with_max_tokens(mock_anthropic, max_tokens):
     mock_response = MagicMock()
-    mock_response.completion = "hello"
+    mock_response.content_block.text = "hello"

-    mock_anthropic.return_value.completions.create.return_value.__iter__.return_value = [
+    mock_anthropic.return_value.messages.create.return_value.__iter__.return_value = [
         mock_response
     ]
     prompt = Prompt(
-        "hello", "", options=Claude.Options(max_tokens_to_sample=max_tokens_to_sample)
+        "hello", "", options=Claude.Options(max_tokens=max_tokens)
     )
     model = Claude("claude-2")
     model.key = "key"
-    items = list(model.response(prompt))
+    model_response = model.response(prompt)
+    items = list(model_response)

-    mock_anthropic.return_value.completions.create.assert_called_with(
+    mock_anthropic.return_value.messages.create.assert_called_with(
         model="claude-2",
-        max_tokens_to_sample=max_tokens_to_sample,
-        prompt="\n\nHuman: hello\n\nAssistant:",
+        max_tokens=max_tokens,
+        messages=[{"role": "user", "content": "hello"}],
         stream=True,
     )
     assert items == ["hello"]


-@pytest.mark.parametrize("max_tokens_to_sample", (0, 1_000_001))
+@pytest.mark.parametrize("max_tokens", (0, 4_097))
 @patch("llm_claude.Anthropic")
-def test_invalid_max_tokens_to_sample(mock_anthropic, max_tokens_to_sample):
+def test_invalid_max_tokens(mock_anthropic, max_tokens):
     runner = CliRunner()
     result = runner.invoke(
         cli,
@@ -69,14 +72,14 @@ def test_invalid_max_tokens_to_sample(mock_anthropic, max_tokens_to_sample):
             "-m",
             "claude",
             "-o",
-            "max_tokens_to_sample",
-            max_tokens_to_sample,
+            "max_tokens",
+            max_tokens,
         ],
     )
     assert result.exit_code == 1
     assert (
         result.output
-        == "Error: max_tokens_to_sample\n Value error, max_tokens_to_sample must be in range 1-1,000,000\n"
+        == "Error: max_tokens\n Value error, max_tokens must be in range 1-4,096\n"
     )


@@ -85,7 +88,7 @@ def test_invalid_max_tokens_to_sample(mock_anthropic, max_tokens_to_sample):
 def test_claude_prompt(mock_anthropic):
     mock_response = Mock()
-    mock_response.completion = "🐶🐶"
-    mock_anthropic.return_value.completions.create.return_value = mock_response
+    mock_response.content = [Mock(text="🐶🐶")]
+    mock_anthropic.return_value.messages.create.return_value = mock_response
     runner = CliRunner()
     result = runner.invoke(cli, ["two dog emoji", "-m", "claude", "--no-stream"])
     assert result.exit_code == 0, result.output
@@ -95,28 +98,31 @@
 @pytest.mark.parametrize(
     "prompt, conversation_messages, expected",
     (
-        ("hello", [], [f"{HUMAN_PROMPT} hello{AI_PROMPT}"]),
+        ("hello", [], [{"role": "user", "content": "hello"}]),
         (
             "hello 2",
             [("user 1", "response 1")],
             [
-                f"{HUMAN_PROMPT} user 1{AI_PROMPT}response 1",
-                f"{HUMAN_PROMPT} hello 2{AI_PROMPT}",
+                {"role": "user", "content": "user 1"},
+                {"role": "assistant", "content": "response 1"},
+                {"role": "user", "content": "hello 2"},
             ],
         ),
         (
             "hello 3",
[("user 1", "response 1"), ("user 2", "response 2")], [ - "\n\nHuman: user 1\n\nAssistant:response 1", - "\n\nHuman: user 2\n\nAssistant:response 2", - "\n\nHuman: hello 3\n\nAssistant:", + {"role": "user", "content": "user 1"}, + {"role": "assistant", "content": "response 1"}, + {"role": "user", "content": "user 2"}, + {"role": "assistant", "content": "response 2"}, + {"role": "user", "content": "hello 3"}, ], ), ), ) def test_generate_prompt_messages( - prompt: str, conversation_messages: List[Tuple[str, str]], expected: List[str] + prompt: str, conversation_messages: List[Tuple[str, str]], expected: List[dict] ): model = get_model("claude") conversation = None @@ -131,6 +137,5 @@ def test_generate_prompt_messages( response=prev_response, ) ) - messages = model.generate_prompt_messages(prompt, conversation) assert list(messages) == expected