Add tests for AnthropicModel (#326)
* Anthropic model and asset file added

* Fixed formatting issue

* Added anthropic package

* Fixed calling ModelBase

* Updated name and info

* Updated cases for exceptions and tests

* Clean up Anthropic error handling by using built-in exceptions

* Remove unused `api_base` in AnthropicModel

* Expand config tests

* Updated with docstring

* Update image input tests

---------

Co-authored-by: Fahim Imaduddin Dalvi <faimaduddin@hbku.edu.qa>
firojalam and fdalvi authored Aug 5, 2024
1 parent 0b16a84 commit be93fe2
Showing 2 changed files with 127 additions and 33 deletions.
51 changes: 18 additions & 33 deletions llmebench/models/Anthropic.py
@@ -7,23 +7,6 @@
from llmebench.models.model_base import ModelBase


class AnthropicFailure(Exception):
"""Exception class to map various failure types from the AzureModel server"""

def __init__(self, failure_type, failure_message):
self.type_mapping = {
"processing": "Model Inference failure",
"connection": "Failed to connect to the API endpoint",
}
self.type = failure_type
self.failure_message = failure_message

def __str__(self):
return (
f"{self.type_mapping.get(self.type, self.type)}: \n {self.failure_message}"
)


class AnthropicModel(ModelBase):
"""
Anthropic Model interface.
@@ -43,7 +26,6 @@ class AnthropicModel(ModelBase):

def __init__(
self,
api_base=None,
api_key=None,
model_name=None,
timeout=20,
@@ -53,7 +35,6 @@ def __init__(
**kwargs,
):
# API parameters
self.api_base = api_base or os.getenv("ANTHROPIC_API_URL")
self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
self.model_name = model_name or os.getenv("ANTHROPIC_MODEL")

@@ -91,20 +72,18 @@ def __init__(
self.client = anthropic.Anthropic(api_key=self.api_key)

super(AnthropicModel, self).__init__(
retry_exceptions=(TimeoutError, AnthropicFailure), **kwargs
retry_exceptions=(
TimeoutError,
anthropic.APIStatusError,
anthropic.RateLimitError,
anthropic.APITimeoutError,
anthropic.APIConnectionError,
),
**kwargs,
)

def summarize_response(self, response):
"""Returns the first reply from the "assistant", if available"""
if (
"choices" in response
and isinstance(response["choices"], list)
and len(response["choices"]) > 0
and "message" in response["choices"][0]
and "content" in response["choices"][0]["message"]
and response["choices"][0]["message"]["role"] == "assistant"
):
return response["choices"][0]["message"]["content"]
"""Returns the response"""

return response

@@ -114,13 +93,19 @@ def prompt(self, processed_input):
Arguments
---------
processed_input : dictionary
Must be a dictionary with one key "prompt", the value of which
must be a string.
processed_input : list
Must be a list of dictionaries, where each dictionary has two keys:
"role" defines a role in the chat (e.g. "user") and
"content" is the message for that turn, either a string or a list. If it is a list, it must contain objects matching one of the following:
- {"type": "text", "text": "....."} for text input/prompt
- {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": "media_file"}} for image input
- the list can contain a mix of the above formats for multimodal input (image + text)
Returns
-------
response : AnthropicModel API response
Response from the anthropic python library
"""

response = self.client.messages.create(
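For context on the new multimodal prompt format, here is a minimal usage sketch based on the updated `prompt()` docstring above. The model name, image file, and prompt text are illustrative assumptions (they are not part of this commit), and a valid Anthropic API key is required since `prompt()` issues a live `messages.create` call:

```python
import base64

from llmebench.models import AnthropicModel

# Illustrative configuration; any supported Anthropic model name works here.
model = AnthropicModel(api_key="secret-key", model_name="claude-3-haiku-20240307")

# Base64-encode a local JPEG (hypothetical file) for the image part of the turn.
with open("sample.jpg", "rb") as f:
    image_data = base64.b64encode(f.read()).decode("utf-8")

# One "user" turn mixing text and an image, matching the documented format.
processed_input = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image in one sentence."},
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/jpeg",
                    "data": image_data,
                },
            },
        ],
    }
]

response = model.prompt(processed_input)
print(model.summarize_response(response))
```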
109 changes: 109 additions & 0 deletions tests/models/test_AnthropicModel.py
@@ -0,0 +1,109 @@
import unittest
from unittest.mock import patch

from llmebench import Benchmark
from llmebench.models import AnthropicModel

from llmebench.utils import is_fewshot_asset


class TestAssetsForAnthropicPrompts(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Load the benchmark assets
benchmark = Benchmark(benchmark_dir="assets")
all_assets = benchmark.find_assets()

# Filter out assets not using the Anthropic model
cls.assets = [
asset
for asset in all_assets
if asset["config"]["model"] in [AnthropicModel]
]

def test_anthropic_prompts(self):
"Test if all assets using this model return data in an appropriate format for prompting"
n_shots = 3 # Sample for few shot prompts

for asset in self.assets:
with self.subTest(msg=asset["name"]):
config = asset["config"]
dataset_args = config.get("dataset_args", {})
dataset_args["data_dir"] = ""
dataset = config["dataset"](**dataset_args)
data_sample = dataset.get_data_sample()
if is_fewshot_asset(config, asset["module"].prompt):
prompt = asset["module"].prompt(
data_sample["input"],
[data_sample for _ in range(n_shots)],
)
else:
prompt = asset["module"].prompt(data_sample["input"])

self.assertIsInstance(prompt, list)

for message in prompt:
self.assertIsInstance(message, dict)
self.assertIn("role", message)
self.assertIsInstance(message["role"], str)
self.assertIn("content", message)
self.assertIsInstance(message["content"], (str, list))

# Multi-modal input
if isinstance(message["content"], list):
for elem in message["content"]:
self.assertIsInstance(elem, dict)
self.assertIn("type", elem)

if elem["type"] == "text":
self.assertIn("text", elem)
self.assertIsInstance(elem["text"], str)
elif elem["type"] == "image":
self.assertIn("source", elem)
self.assertIsInstance(elem["source"], dict)

# Current support is for base64
self.assertIn("type", elem["source"])
self.assertIsInstance(elem["source"]["type"], str)
self.assertIn("data", elem["source"])
self.assertIsInstance(elem["source"]["data"], str)

self.assertIn("media_type", elem["source"])
self.assertIsInstance(elem["source"]["media_type"], str)


class TestAnthropicConfig(unittest.TestCase):
def test_anthropic_config(self):
"Test if model config parameters passed as arguments are used"
model = AnthropicModel(api_key="secret-key", model_name="private-model")
self.assertEqual(model.api_key, "secret-key")
self.assertEqual(model.model, "private-model")

@patch.dict(
"os.environ",
{
"ANTHROPIC_API_KEY": "secret-env-key",
"ANTHROPIC_MODEL": "private-env-model",
},
)
def test_anthropic_config_env_var(self):
"Test if model config parameters passed as environment variables are used"
model = AnthropicModel()

self.assertEqual(model.api_key, "secret-env-key")
self.assertEqual(model.model, "private-env-model")

@patch.dict(
"os.environ",
{
"ANTHROPIC_API_KEY": "secret-env-key",
"ANTHROPIC_MODEL": "private-env-model",
},
)
def test_anthropic_config_priority(self):
"Test if model config parameters passed as environment variables are used"
model = AnthropicModel(api_key="secret-key", model_name="private-model")

self.assertEqual(model.api_key, "secret-key")
self.assertEqual(model.model, "private-model")
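
As a quick way to exercise the new test module in isolation, here is a small sketch using unittest discovery; it assumes the repository root is the current working directory, that `llmebench` and the `anthropic` package are installed, and that benchmark assets are present under `assets/` for the asset-driven test:

```python
# Sketch: discover and run only the new AnthropicModel tests.
import unittest

suite = unittest.defaultTestLoader.discover(
    start_dir="tests/models", pattern="test_AnthropicModel.py"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```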
