diff --git a/llmebench/models/OpenAI.py b/llmebench/models/OpenAI.py
index 6e8b2e8d..2db8387f 100644
--- a/llmebench/models/OpenAI.py
+++ b/llmebench/models/OpenAI.py
@@ -88,9 +88,6 @@ def __init__(
 
         openai.api_type = api_type
 
-        if api_base:
-            openai.api_base = api_base
-
         if api_type == "azure" and api_version is None:
             raise Exception(
                 "API version must be provided as model config or environment variable (`AZURE_API_VERSION`)"
@@ -132,7 +129,9 @@ def __init__(
                 base_url=f"{api_base}/openai/deployments/{model_name}/",
             )
         elif api_type == "openai":
-            self.client = OpenAI(api_key=api_key)
+            if not api_base:
+                api_base = "https://api.openai.com/v1"
+            self.client = OpenAI(base_url=api_base, api_key=api_key)
         else:
             raise Exception('API type must be one of "azure" or "openai"')
diff --git a/llmebench/models/VLLM.py b/llmebench/models/VLLM.py
index bbd820ae..0fcc3d15 100644
--- a/llmebench/models/VLLM.py
+++ b/llmebench/models/VLLM.py
@@ -2,31 +2,19 @@
 import logging
 import os
 
-import requests
+from llmebench.models.OpenAI import OpenAIModel
 
-from llmebench.models.model_base import ModelBase
 
+class VLLMModel(OpenAIModel):
+    """
+    VLLM Model interface. Can be used for models hosted using https://github.com/vllm-project/vllm.
 
-class VLLMFailure(Exception):
-    """Exception class to map various failure types from the VLLM server"""
-
-    def __init__(self, failure_type, failure_message):
-        self.type_mapping = {
-            "processing": "Model Inference failure",
-            "connection": "Failed to connect to BLOOM Petal server",
-        }
-        self.type = failure_type
-        self.failure_message = failure_message
-
-    def __str__(self):
-        return (
-            f"{self.type_mapping.get(self.type, self.type)}: \n {self.failure_message}"
-        )
-
+    Accepts all arguments used by `OpenAIModel`, and overrides the arguments listed
+    below with VLLM-specific environment variables.
 
-class VLLMModel(ModelBase):
-    """
-    VLLM Model interface.
+    See the [model_support](https://docs.vllm.ai/en/latest/models/supported_models.html)
+    page in VLLM's documentation for supported models and instructions on extending
+    to custom models.
 
     Arguments
     ---------
@@ -45,7 +33,9 @@ class VLLMModel(ModelBase):
 
     def __init__(
         self,
-        api_url=None,
+        api_base=None,
+        api_key=None,
+        model_name=None,
         timeout=20,
         temperature=0,
         top_p=0.95,
@@ -53,12 +43,9 @@ def __init__(
         **kwargs,
     ):
         # API parameters
-        self.api_url = api_url or os.getenv("VLLM_API_URL")
-        self.user_session_id = os.getenv("USER_SESSION_ID")
-        if self.api_url is None:
-            raise Exception(
-                "API url must be provided as model config or environment variable (`VLLM_API_URL`)"
-            )
+        self.api_base = api_base or os.getenv("VLLM_API_URL")
+        self.api_key = api_key or os.getenv("VLLM_API_KEY")
+        self.model_name = model_name or os.getenv("VLLM_MODEL")
         self.api_timeout = timeout
         # Parameters
         tolerance = 1e-7
@@ -71,59 +58,22 @@ def __init__(
         self.top_p = top_p
         self.max_tokens = max_tokens
 
+        if self.api_base is None:
+            raise Exception(
+                "API url must be provided as model config or environment variable (`VLLM_API_URL`)"
+            )
+        if self.api_key is None:
+            raise Exception(
+                "API key must be provided as model config or environment variable (`VLLM_API_KEY`)"
+            )
+        if self.model_name is None:
+            raise Exception(
+                "Model name must be provided as model config or environment variable (`VLLM_MODEL`)"
+            )
+        # checks for valid config settings
         super(VLLMModel, self).__init__(
-            retry_exceptions=(TimeoutError, VLLMFailure), **kwargs
+            api_base=self.api_base,
+            api_key=self.api_key,
+            model_name=self.model_name,
+            **kwargs,
         )
-
-    def summarize_response(self, response):
-        """Returns the "outputs" key's value, if available"""
-        if "messages" in response:
-            return response["messages"]
-
-        return response
-
-    def prompt(self, processed_input):
-        """
-        VLLM API Implementation
-
-        Arguments
-        ---------
-        processed_input : dictionary
-            Must be a dictionary with one key "prompt", the value of which
-            must be a string.
-
-        Returns
-        -------
-        response : VLLM API response
-            Response from the VLLM server
-
-        Raises
-        ------
-        VLLMFailure : Exception
-            This method raises this exception if the server responded with a non-ok
-            response
-        """
-        headers = {"Content-Type": "application/json"}
-        params = {
-            "messages": processed_input,
-            "max_tokens": self.max_tokens,
-            "temperature": self.temperature,
-            "user_session_id": self.user_session_id,
-        }
-        try:
-            response = requests.post(
-                self.api_url, headers=headers, json=params, timeout=self.api_timeout
-            )
-            if response.status_code != 200:
-                raise VLLMFailure(
-                    "processing",
-                    "Processing failed with status: {}".format(response.status_code),
-                )
-
-            # Parse the final response
-            response_data = response.json()
-            logging.info(f"initial_response: {response_data}")
-        except VLLMFailure as e:
-            print("Error occurred:", e)
-
-        return response_data
diff --git a/tests/models/test_FastChatModel.py b/tests/models/test_FastChatModel.py
index bde0f2fe..08e75072 100644
--- a/tests/models/test_FastChatModel.py
+++ b/tests/models/test_FastChatModel.py
@@ -37,7 +37,9 @@ def test_fastchat_config(self):
         )
 
         self.assertEqual(openai.api_type, "openai")
-        self.assertEqual(openai.api_base, "llmebench.qcri.org")
+        self.assertEqual(
+            model.client.base_url.raw_path.decode("utf-8"), "llmebench.qcri.org/"
+        )
         self.assertEqual(openai.api_key, "secret-key")
         self.assertEqual(model.model_params["model"], "private-model")
 
@@ -54,7 +56,9 @@ def test_fastchat_config_env_var(self):
         model = FastChatModel()
 
         self.assertEqual(openai.api_type, "openai")
-        self.assertEqual(openai.api_base, "llmebench.qcri.org")
+        self.assertEqual(
+            model.client.base_url.raw_path.decode("utf-8"), "llmebench.qcri.org/"
+        )
         self.assertEqual(openai.api_key, "secret-key")
         self.assertEqual(model.model_params["model"], "private-model")
 
@@ -71,6 +75,8 @@ def test_fastchat_config_priority(self):
         model = FastChatModel(model_name="another-model")
 
         self.assertEqual(openai.api_type, "openai")
-        self.assertEqual(openai.api_base, "llmebench.qcri.org")
+        self.assertEqual(
+            model.client.base_url.raw_path.decode("utf-8"), "llmebench.qcri.org/"
+        )
         self.assertEqual(openai.api_key, "secret-key")
         self.assertEqual(model.model_params["model"], "another-model")
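
For reference: a minimal usage sketch of the refactored VLLMModel under the new configuration scheme, assuming a vLLM server that exposes its OpenAI-compatible API. The endpoint URL, API key, and model name below are hypothetical placeholders, not values taken from this patch.

from llmebench.models.VLLM import VLLMModel

# Hypothetical endpoint, key, and served model name; any vLLM deployment
# exposing the OpenAI-compatible API should work here.
model = VLLMModel(
    api_base="http://localhost:8000/v1",
    api_key="EMPTY",
    model_name="my-org/my-model",
)

# Equivalently, set VLLM_API_URL, VLLM_API_KEY, and VLLM_MODEL in the
# environment and call VLLMModel() with no arguments; temperature, top_p,
# and max_tokens keep the defaults from __init__.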