Feat/anthropic #326

Merged
merged 13 commits (Aug 5, 2024)
51 changes: 18 additions & 33 deletions llmebench/models/Anthropic.py
Collaborator (review comment):

@firojalam Can you please update the docstring under def prompt(self, processed_input): to indicate what the required format is (and maybe also mention multimodal input)?

Right now, there are only two ways a user can find out what their asset should return: this docstring or a failing test. You have already added the test, but it would be nice for the docstring to reflect the correct format as well. Feel free to take a look at the other model docstrings for reference.
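
For reference, here is a minimal sketch (hypothetical asset code, not part of this PR; the field names on input_sample are made up for illustration) of the two return shapes the new tests accept — a plain string "content" for text-only prompts and a list of content blocks for multimodal (image + text) prompts:

```python
import base64


def prompt(input_sample):
    # Text-only turn: "content" may simply be a string
    return [
        {
            "role": "user",
            "content": f"Classify the following text: {input_sample['text']}",
        }
    ]


def multimodal_prompt(input_sample):
    # Multimodal turn: "content" is a list mixing image and text blocks
    with open(input_sample["image_path"], "rb") as f:
        image_data = base64.b64encode(f.read()).decode("utf-8")
    return [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": image_data,
                    },
                },
                {"type": "text", "text": input_sample["question"]},
            ],
        }
    ]
```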

@@ -7,23 +7,6 @@
from llmebench.models.model_base import ModelBase


class AnthropicFailure(Exception):
"""Exception class to map various failure types from the AzureModel server"""

def __init__(self, failure_type, failure_message):
self.type_mapping = {
"processing": "Model Inference failure",
"connection": "Failed to connect to the API endpoint",
}
self.type = failure_type
self.failure_message = failure_message

def __str__(self):
return (
f"{self.type_mapping.get(self.type, self.type)}: \n {self.failure_message}"
)


class AnthropicModel(ModelBase):
"""
Anthropic Model interface.
@@ -43,7 +26,6 @@ class AnthropicModel(ModelBase):

def __init__(
self,
api_base=None,
api_key=None,
model_name=None,
timeout=20,
@@ -53,7 +35,6 @@ def __init__(
**kwargs,
):
# API parameters
self.api_base = api_base or os.getenv("ANTHROPIC_API_URL")
self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
self.model_name = model_name or os.getenv("ANTHROPIC_MODEL")

@@ -91,20 +72,18 @@ def __init__(
self.client = anthropic.Anthropic(api_key=self.api_key)

super(AnthropicModel, self).__init__(
retry_exceptions=(TimeoutError, AnthropicFailure), **kwargs
retry_exceptions=(
TimeoutError,
anthropic.APIStatusError,
anthropic.RateLimitError,
anthropic.APITimeoutError,
anthropic.APIConnectionError,
),
**kwargs,
)

def summarize_response(self, response):
"""Returns the first reply from the "assistant", if available"""
if (
"choices" in response
and isinstance(response["choices"], list)
and len(response["choices"]) > 0
and "message" in response["choices"][0]
and "content" in response["choices"][0]["message"]
and response["choices"][0]["message"]["role"] == "assistant"
):
return response["choices"][0]["message"]["content"]
"""Returns the response"""

return response

@@ -114,13 +93,19 @@ def prompt(self, processed_input):

Arguments
---------
processed_input : dictionary
Must be a dictionary with one key "prompt", the value of which
must be a string.
processed_input : list
Must be a list of dictionaries, where each dictionary has two keys:
"role" defines the role in the chat (e.g. "user"), and
"content" is the message for that turn, either a string or a list. If it is a list, each element must match one of the following:
- {"type": "text", "text": "....."} for text input/prompt
- {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": "media_file"}} for image input
- the list can contain a mix of the above formats for multimodal input (image + text)

Returns
-------
response : AnthropicModel API response
Response from the anthropic Python library

"""

response = self.client.messages.create(
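# --- Editor's sketch (not part of the PR diff) ---
# The rest of the messages.create call is collapsed in the diff above. A typical
# call to the anthropic SDK's Messages API looks like the following; self.max_tokens
# and the keyword values shown are assumptions, not the PR's actual settings.
response = self.client.messages.create(
    model=self.model_name,
    max_tokens=self.max_tokens,  # required by the Messages API; assumed attribute
    messages=processed_input,    # the list of {"role", "content"} dicts described above
)
return response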
109 changes: 109 additions & 0 deletions tests/models/test_AnthropicModel.py
@@ -0,0 +1,109 @@
import unittest
from unittest.mock import patch

from llmebench import Benchmark
from llmebench.models import AnthropicModel

from llmebench.utils import is_fewshot_asset


class TestAssetsForAnthropicPrompts(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Load the benchmark assets
benchmark = Benchmark(benchmark_dir="assets")
all_assets = benchmark.find_assets()

# Filter out assets not using the Anthropic model
cls.assets = [
asset
for asset in all_assets
if asset["config"]["model"] in [AnthropicModel]
]

def test_anthropic_prompts(self):
"Test if all assets using this model return data in an appropriate format for prompting"
# self.test_openai_prompts()
n_shots = 3 # Sample for few shot prompts

for asset in self.assets:
with self.subTest(msg=asset["name"]):
config = asset["config"]
dataset_args = config.get("dataset_args", {})
dataset_args["data_dir"] = ""
dataset = config["dataset"](**dataset_args)
data_sample = dataset.get_data_sample()
if is_fewshot_asset(config, asset["module"].prompt):
prompt = asset["module"].prompt(
data_sample["input"],
[data_sample for _ in range(n_shots)],
)
else:
prompt = asset["module"].prompt(data_sample["input"])

self.assertIsInstance(prompt, list)

for message in prompt:
self.assertIsInstance(message, dict)
self.assertIn("role", message)
self.assertIsInstance(message["role"], str)
self.assertIn("content", message)
self.assertIsInstance(message["content"], (str, list))

# Multi-modal input
if isinstance(message["content"], list):
for elem in message["content"]:
self.assertIsInstance(elem, dict)
self.assertIn("type", elem)

if elem["type"] == "text":
self.assertIn("text", elem)
self.assertIsInstance(elem["text"], str)
elif elem["type"] == "image":
self.assertIn("source", elem)
self.assertIsInstance(elem["source"], dict)

# Current support is for base64
self.assertIn("type", elem["source"])
self.assertIsInstance(elem["source"]["type"], str)
self.assertIn("data", elem["source"])
self.assertIsInstance(elem["source"]["data"], str)

self.assertIn("media_type", elem["source"])
self.assertIsInstance(elem["source"]["media_type"], str)


class TestAnthropicConfig(unittest.TestCase):
def test_anthropic_config(self):
"Test if model config parameters passed as arguments are used"
model = AnthropicModel(api_key="secret-key", model_name="private-model")
self.assertEqual(model.api_key, "secret-key")
self.assertEqual(model.model, "private-model")

@patch.dict(
"os.environ",
{
"ANTHROPIC_API_KEY": "secret-env-key",
"ANTHROPIC_MODEL": "private-env-model",
},
)
def test_anthropic_config_env_var(self):
"Test if model config parameters passed as environment variables are used"
model = AnthropicModel()

self.assertEqual(model.api_key, "secret-env-key")
self.assertEqual(model.model, "private-env-model")

@patch.dict(
"os.environ",
{
"ANTHROPIC_API_KEY": "secret-env-key",
"ANTHROPIC_MODEL": "private-env-model",
},
)
def test_anthropic_config_priority(self):
"Test if model config parameters passed as environment variables are used"
model = AnthropicModel(api_key="secret-key", model_name="private-model")

self.assertEqual(model.api_key, "secret-key")
self.assertEqual(model.model, "private-model")