diff --git a/llmebench/models/Anthropic.py b/llmebench/models/Anthropic.py
index a5436de8..fc3bd0f4 100644
--- a/llmebench/models/Anthropic.py
+++ b/llmebench/models/Anthropic.py
@@ -7,23 +7,6 @@
 from llmebench.models.model_base import ModelBase
 
 
-class AnthropicFailure(Exception):
-    """Exception class to map various failure types from the AzureModel server"""
-
-    def __init__(self, failure_type, failure_message):
-        self.type_mapping = {
-            "processing": "Model Inference failure",
-            "connection": "Failed to connect to the API endpoint",
-        }
-        self.type = failure_type
-        self.failure_message = failure_message
-
-    def __str__(self):
-        return (
-            f"{self.type_mapping.get(self.type, self.type)}: \n {self.failure_message}"
-        )
-
-
 class AnthropicModel(ModelBase):
     """
     Anthropic Model interface.
@@ -43,7 +26,6 @@ class AnthropicModel(ModelBase):
 
     def __init__(
         self,
-        api_base=None,
         api_key=None,
         model_name=None,
         timeout=20,
@@ -53,7 +35,6 @@ def __init__(
         **kwargs,
     ):
         # API parameters
-        self.api_base = api_base or os.getenv("ANTHROPIC_API_URL")
         self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
         self.model_name = model_name or os.getenv("ANTHROPIC_MODEL")
 
@@ -91,20 +72,18 @@ def __init__(
         self.client = anthropic.Anthropic(api_key=self.api_key)
 
         super(AnthropicModel, self).__init__(
-            retry_exceptions=(TimeoutError, AnthropicFailure), **kwargs
+            retry_exceptions=(
+                TimeoutError,
+                anthropic.APIStatusError,
+                anthropic.RateLimitError,
+                anthropic.APITimeoutError,
+                anthropic.APIConnectionError,
+            ),
+            **kwargs,
         )
 
     def summarize_response(self, response):
-        """Returns the first reply from the "assistant", if available"""
-        if (
-            "choices" in response
-            and isinstance(response["choices"], list)
-            and len(response["choices"]) > 0
-            and "message" in response["choices"][0]
-            and "content" in response["choices"][0]["message"]
-            and response["choices"][0]["message"]["role"] == "assistant"
-        ):
-            return response["choices"][0]["message"]["content"]
+        """Returns the raw response without modification"""
         return response
 
@@ -114,13 +93,19 @@
 
         Arguments
         ---------
-        processed_input : dictionary
-            Must be a dictionary with one key "prompt", the value of which
-            must be a string.
+        processed_input : list
+            Must be a list of dictionaries, where each dictionary has two keys:
+            "role" defines the role in the chat (e.g. "user"), and
+            "content" is either a string or a list of content blocks for that turn.
+            If it is a list, each element must match one of the following:
+            - {"type": "text", "text": "....."} for text input/prompt
+            - {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": "media_file"}} for image input
+            - a mix of the above formats is allowed for multimodal input (image + text)
 
         Returns
         -------
         response : AnthropicModel API response
+            Raw response from the anthropic Python library
         """
 
         response = self.client.messages.create(
diff --git a/tests/models/test_AnthropicModel.py b/tests/models/test_AnthropicModel.py
new file mode 100644
index 00000000..82afad60
--- /dev/null
+++ b/tests/models/test_AnthropicModel.py
@@ -0,0 +1,109 @@
+import unittest
+from unittest.mock import patch
+
+from llmebench import Benchmark
+from llmebench.models import AnthropicModel
+
+from llmebench.utils import is_fewshot_asset
+
+
+class TestAssetsForAnthropicPrompts(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        # Load the benchmark assets
+        benchmark = Benchmark(benchmark_dir="assets")
+        all_assets = benchmark.find_assets()
+
+        # Keep only assets that use the Anthropic model
+        cls.assets = [
+            asset
+            for asset in all_assets
+            if asset["config"]["model"] in [AnthropicModel]
+        ]
+
+    def test_anthropic_prompts(self):
+        "Test if all assets using this model return data in an appropriate format for prompting"
+        n_shots = 3  # Number of samples to use for few-shot prompts
+
+        for asset in self.assets:
+            with self.subTest(msg=asset["name"]):
+                config = asset["config"]
+                dataset_args = config.get("dataset_args", {})
+                dataset_args["data_dir"] = ""
+                dataset = config["dataset"](**dataset_args)
+                data_sample = dataset.get_data_sample()
+                if is_fewshot_asset(config, asset["module"].prompt):
+                    prompt = asset["module"].prompt(
+                        data_sample["input"],
+                        [data_sample for _ in range(n_shots)],
+                    )
+                else:
+                    prompt = asset["module"].prompt(data_sample["input"])
+
+                self.assertIsInstance(prompt, list)
+
+                for message in prompt:
+                    self.assertIsInstance(message, dict)
+                    self.assertIn("role", message)
+                    self.assertIsInstance(message["role"], str)
+                    self.assertIn("content", message)
+                    self.assertIsInstance(message["content"], (str, list))
+
+                    # Multi-modal input
+                    if isinstance(message["content"], list):
+                        for elem in message["content"]:
+                            self.assertIsInstance(elem, dict)
+                            self.assertIn("type", elem)
+
+                            if elem["type"] == "text":
+                                self.assertIn("text", elem)
+                                self.assertIsInstance(elem["text"], str)
+                            elif elem["type"] == "image":
+                                self.assertIn("source", elem)
+                                self.assertIsInstance(elem["source"], dict)
+
+                                # Current support is for base64-encoded images
+                                self.assertIn("type", elem["source"])
+                                self.assertIsInstance(elem["source"]["type"], str)
+                                self.assertIn("data", elem["source"])
+                                self.assertIsInstance(elem["source"]["data"], str)
+
+                                self.assertIn("media_type", elem["source"])
+                                self.assertIsInstance(
+                                    elem["source"]["media_type"], str
+                                )
+
+
+class TestAnthropicConfig(unittest.TestCase):
+    def test_anthropic_config(self):
+        "Test if model config parameters passed as arguments are used"
+        model = AnthropicModel(api_key="secret-key", model_name="private-model")
+        self.assertEqual(model.api_key, "secret-key")
+        self.assertEqual(model.model, "private-model")
+
+    @patch.dict(
+        "os.environ",
+        {
+            "ANTHROPIC_API_KEY": "secret-env-key",
+            "ANTHROPIC_MODEL": "private-env-model",
+        },
+    )
+    def test_anthropic_config_env_var(self):
+        "Test if model config parameters passed as environment variables are used"
+        model = AnthropicModel()
+
+        self.assertEqual(model.api_key, "secret-env-key")
+        self.assertEqual(model.model, "private-env-model")
+
+    @patch.dict(
+        "os.environ",
+        {
+            "ANTHROPIC_API_KEY": "secret-env-key",
+            "ANTHROPIC_MODEL": "private-env-model",
+        },
+    )
+    def test_anthropic_config_priority(self):
+        "Test that explicitly passed arguments take priority over environment variables"
+        model = AnthropicModel(api_key="secret-key", model_name="private-model")
+
+        self.assertEqual(model.api_key, "secret-key")
+        self.assertEqual(model.model, "private-model")