From 1a78e92d0c4946fdeccebe7271d2f57e5dfafecc Mon Sep 17 00:00:00 2001
From: Fahim Imaduddin Dalvi
Date: Thu, 11 Jul 2024 09:30:39 +0300
Subject: [PATCH 1/4] VLLM updates from Firoj

---
 llmebench/models/VLLM.py     | 109 ++++++++++++++++++++++++++++++++++-
 llmebench/models/__init__.py |   2 +-
 2 files changed, 108 insertions(+), 3 deletions(-)

diff --git a/llmebench/models/VLLM.py b/llmebench/models/VLLM.py
index bbd820ae..ca8363ce 100644
--- a/llmebench/models/VLLM.py
+++ b/llmebench/models/VLLM.py
@@ -43,6 +43,112 @@ class VLLMModel(ModelBase):
         Maximum number of tokens to pass to the model. Defaults to 1512
     """
 
+    def __init__(
+        self,
+        api_url=None,
+        timeout=20,
+        temperature=0,
+        top_p=0.95,
+        max_tokens=1512,
+        **kwargs,
+    ):
+        # API parameters
+        self.api_url = api_url or os.getenv("VLLM_API_URL")
+        self.user_session_id = os.getenv("USER_SESSION_ID")
+        self.model = os.getenv("VLLM_MODEL")
+        if self.api_url is None:
+            raise Exception(
+                "API url must be provided as model config or environment variable (`VLLM_API_URL`)"
+            )
+        self.api_timeout = timeout
+        # Parameters
+        tolerance = 1e-7
+        self.temperature = temperature
+        if self.temperature < tolerance:
+            # Currently, the model inference fails if temperature
+            # is exactly 0, so we nudge it slightly to work around
+            # the issue
+            self.temperature += tolerance
+        self.top_p = top_p
+        self.max_tokens = max_tokens
+
+        super(VLLMModel, self).__init__(
+            retry_exceptions=(TimeoutError, VLLMFailure), **kwargs
+        )
+
+    def summarize_response(self, response):
+        """Returns the "outputs" key's value, if available"""
+        if "messages" in response:
+            return response["messages"]
+
+        return response
+
+    def prompt(self, processed_input):
+        """
+        VLLM API Implementation
+
+        Arguments
+        ---------
+        processed_input : dictionary
+            Must be a dictionary with one key "prompt", the value of which
+            must be a string.
+
+        Returns
+        -------
+        response : VLLM API response
+            Response from the VLLM server
+
+        Raises
+        ------
+        VLLMFailure : Exception
+            This method raises this exception if the server responded with a non-ok
+            response
+        """
+        headers = {"Content-Type": "application/json"}
+        data = {
+            "model": self.model,
+            "messages": processed_input,
+            "max_tokens": 1000,
+            "temperature": 0,
+        }
+        try:
+            response = requests.post(
+                self.api_url, headers=headers, json=data, timeout=self.api_timeout
+            )
+            if response.status_code != 200:
+                raise VLLMFailure(
+                    "processing",
+                    "Processing failed with status: {}".format(response.status_code),
+                )
+
+            # Parse the final response
+            response_data = response.json()
+        except VLLMFailure as e:
+            print("Error occurred:", e)
+            return None
+
+        return response_data
+
+
+class VLLMFanarModel(ModelBase):
+    """
+    VLLM Model interface.
+
+    Arguments
+    ---------
+    api_url : str
+        URL where the VLLM server is hosted. If not provided, the implementation will
+        look at environment variable `VLLM_API_URL`
+    timeout : int
+        Number of seconds before the request to the server is timed out
+    temperature : float
+        Temperature value to use for the model. Defaults to zero for reproducibility.
+    top_p : float
+        Top P value to use for the model. Defaults to 0.95
+    max_tokens : int
+        Maximum number of tokens to pass to the model. Defaults to 1512
Defaults to 1512 + """ + def __init__( self, api_url=None, @@ -119,11 +225,10 @@ def prompt(self, processed_input): "processing", "Processing failed with status: {}".format(response.status_code), ) - # Parse the final response response_data = response.json() - logging.info(f"initial_response: {response_data}") except VLLMFailure as e: print("Error occurred:", e) + return None return response_data diff --git a/llmebench/models/__init__.py b/llmebench/models/__init__.py index 92e44a30..990979e0 100644 --- a/llmebench/models/__init__.py +++ b/llmebench/models/__init__.py @@ -4,4 +4,4 @@ from .OpenAI import LegacyOpenAIModel, OpenAIModel from .Petals import PetalsModel from .Random import RandomModel -from .VLLM import VLLMModel +from .VLLM import VLLMFanarModel, VLLMModel From 11700a2922a76519b2b7b81183866905c1841f4d Mon Sep 17 00:00:00 2001 From: Fahim Imaduddin Dalvi Date: Sun, 21 Jul 2024 13:38:39 +0300 Subject: [PATCH 2/4] Changes from pr-321 --- llmebench/models/OpenAI.py | 7 +- llmebench/models/VLLM.py | 213 +++++-------------------------------- 2 files changed, 32 insertions(+), 188 deletions(-) diff --git a/llmebench/models/OpenAI.py b/llmebench/models/OpenAI.py index 6e8b2e8d..2db8387f 100644 --- a/llmebench/models/OpenAI.py +++ b/llmebench/models/OpenAI.py @@ -88,9 +88,6 @@ def __init__( openai.api_type = api_type - if api_base: - openai.api_base = api_base - if api_type == "azure" and api_version is None: raise Exception( "API version must be provided as model config or environment variable (`AZURE_API_VERSION`)" @@ -132,7 +129,9 @@ def __init__( base_url=f"{api_base}/openai/deployments/{model_name}/", ) elif api_type == "openai": - self.client = OpenAI(api_key=api_key) + if not api_base: + api_base = "https://api.openai.com/v1" + self.client = OpenAI(base_url=api_base, api_key=api_key) else: raise Exception('API type must be one of "azure" or "openai"') diff --git a/llmebench/models/VLLM.py b/llmebench/models/VLLM.py index ca8363ce..0fcc3d15 100644 --- a/llmebench/models/VLLM.py +++ b/llmebench/models/VLLM.py @@ -2,31 +2,19 @@ import logging import os -import requests +from llmebench.models.OpenAI import OpenAIModel -from llmebench.models.model_base import ModelBase +class VLLMModel(OpenAIModel): + """ + VLLM Model interface. Can be used for models hosted using https://github.com/vllm-project/vllm. -class VLLMFailure(Exception): - """Exception class to map various failure types from the VLLM server""" - - def __init__(self, failure_type, failure_message): - self.type_mapping = { - "processing": "Model Inference failure", - "connection": "Failed to connect to BLOOM Petal server", - } - self.type = failure_type - self.failure_message = failure_message + Accepts all arguments used by `OpenAIModel`, and overrides the arguments listed + below with VLLM variables. - def __str__(self): - return ( - f"{self.type_mapping.get(self.type, self.type)}: \n {self.failure_message}" - ) - - -class VLLMModel(ModelBase): - """ - VLLM Model interface. + See the [https://docs.vllm.ai/en/latest/models/supported_models.html](model_support) + page in VLLM's documentation for supported models and instructions on extending + to custom models. 
 
     Arguments
     ---------
@@ -45,7 +33,9 @@ class VLLMModel(ModelBase):
 
     def __init__(
         self,
-        api_url=None,
+        api_base=None,
+        api_key=None,
+        model_name=None,
         timeout=20,
         temperature=0,
         top_p=0.95,
@@ -53,13 +43,9 @@ def __init__(
         **kwargs,
     ):
         # API parameters
-        self.api_url = api_url or os.getenv("VLLM_API_URL")
-        self.user_session_id = os.getenv("USER_SESSION_ID")
-        self.model = os.getenv("VLLM_MODEL")
-        if self.api_url is None:
-            raise Exception(
-                "API url must be provided as model config or environment variable (`VLLM_API_URL`)"
-            )
+        self.api_base = api_base or os.getenv("VLLM_API_URL")
+        self.api_key = api_key or os.getenv("VLLM_API_KEY")
+        self.model_name = model_name or os.getenv("VLLM_MODEL")
         self.api_timeout = timeout
         # Parameters
         tolerance = 1e-7
@@ -72,163 +58,22 @@ def __init__(
         self.top_p = top_p
         self.max_tokens = max_tokens
 
-        super(VLLMModel, self).__init__(
-            retry_exceptions=(TimeoutError, VLLMFailure), **kwargs
-        )
-
-    def summarize_response(self, response):
-        """Returns the "outputs" key's value, if available"""
-        if "messages" in response:
-            return response["messages"]
-
-        return response
-
-    def prompt(self, processed_input):
-        """
-        VLLM API Implementation
-
-        Arguments
-        ---------
-        processed_input : dictionary
-            Must be a dictionary with one key "prompt", the value of which
-            must be a string.
-
-        Returns
-        -------
-        response : VLLM API response
-            Response from the VLLM server
-
-        Raises
-        ------
-        VLLMFailure : Exception
-            This method raises this exception if the server responded with a non-ok
-            response
-        """
-        headers = {"Content-Type": "application/json"}
-        data = {
-            "model": self.model,
-            "messages": processed_input,
-            "max_tokens": 1000,
-            "temperature": 0,
-        }
-        try:
-            response = requests.post(
-                self.api_url, headers=headers, json=data, timeout=self.api_timeout
+        if self.api_base is None:
+            raise Exception(
+                "API url must be provided as model config or environment variable (`VLLM_API_URL`)"
             )
-            if response.status_code != 200:
-                raise VLLMFailure(
-                    "processing",
-                    "Processing failed with status: {}".format(response.status_code),
-                )
-
-            # Parse the final response
-            response_data = response.json()
-        except VLLMFailure as e:
-            print("Error occurred:", e)
-            return None
-
-        return response_data
-
-
-class VLLMFanarModel(ModelBase):
-    """
-    VLLM Model interface.
-
-    Arguments
-    ---------
-    api_url : str
-        URL where the VLLM server is hosted. If not provided, the implementation will
-        look at environment variable `VLLM_API_URL`
-    timeout : int
-        Number of seconds before the request to the server is timed out
-    temperature : float
-        Temperature value to use for the model. Defaults to zero for reproducibility.
-    top_p : float
-        Top P value to use for the model. Defaults to 0.95
-    max_tokens : int
-        Maximum number of tokens to pass to the model. Defaults to 1512
Defaults to 1512 - """ - - def __init__( - self, - api_url=None, - timeout=20, - temperature=0, - top_p=0.95, - max_tokens=1512, - **kwargs, - ): - # API parameters - self.api_url = api_url or os.getenv("VLLM_API_URL") - self.user_session_id = os.getenv("USER_SESSION_ID") - if self.api_url is None: + if self.api_key is None: raise Exception( - "API url must be provided as model config or environment variable (`VLLM_API_URL`)" + "API key must be provided as model config or environment variable (`VLLM_API_KEY`)" ) - self.api_timeout = timeout - # Parameters - tolerance = 1e-7 - self.temperature = temperature - if self.temperature < tolerance: - # Currently, the model inference fails if temperature - # is exactly 0, so we nudge it slightly to work around - # the issue - self.temperature += tolerance - self.top_p = top_p - self.max_tokens = max_tokens - + if self.model_name is None: + raise Exception( + "Model name must be provided as model config or environment variable (`VLLM_MODEL`)" + ) + # checks for valid config settings) super(VLLMModel, self).__init__( - retry_exceptions=(TimeoutError, VLLMFailure), **kwargs + api_base=self.api_base, + api_key=self.api_key, + model_name=self.model_name, + **kwargs, ) - - def summarize_response(self, response): - """Returns the "outputs" key's value, if available""" - if "messages" in response: - return response["messages"] - - return response - - def prompt(self, processed_input): - """ - VLLM API Implementation - - Arguments - --------- - processed_input : dictionary - Must be a dictionary with one key "prompt", the value of which - must be a string. - - Returns - ------- - response : VLLM API response - Response from the VLLM server - - Raises - ------ - VLLMFailure : Exception - This method raises this exception if the server responded with a non-ok - response - """ - headers = {"Content-Type": "application/json"} - params = { - "messages": processed_input, - "max_tokens": self.max_tokens, - "temperature": self.temperature, - "user_session_id": self.user_session_id, - } - try: - response = requests.post( - self.api_url, headers=headers, json=params, timeout=self.api_timeout - ) - if response.status_code != 200: - raise VLLMFailure( - "processing", - "Processing failed with status: {}".format(response.status_code), - ) - # Parse the final response - response_data = response.json() - except VLLMFailure as e: - print("Error occurred:", e) - return None - - return response_data From 4e1f2813491f10dff08c3958b01615911422c408 Mon Sep 17 00:00:00 2001 From: Fahim Imaduddin Dalvi Date: Sun, 21 Jul 2024 13:30:58 +0300 Subject: [PATCH 3/4] Update tests with new `base_url` replacement for `api_base` in openai package --- tests/models/test_FastChatModel.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/models/test_FastChatModel.py b/tests/models/test_FastChatModel.py index bde0f2fe..08e75072 100644 --- a/tests/models/test_FastChatModel.py +++ b/tests/models/test_FastChatModel.py @@ -37,7 +37,9 @@ def test_fastchat_config(self): ) self.assertEqual(openai.api_type, "openai") - self.assertEqual(openai.api_base, "llmebench.qcri.org") + self.assertEqual( + model.client.base_url.raw_path.decode("utf-8"), "llmebench.qcri.org/" + ) self.assertEqual(openai.api_key, "secret-key") self.assertEqual(model.model_params["model"], "private-model") @@ -54,7 +56,9 @@ def test_fastchat_config_env_var(self): model = FastChatModel() self.assertEqual(openai.api_type, "openai") - self.assertEqual(openai.api_base, "llmebench.qcri.org") + 
+                model.client.base_url.raw_path.decode("utf-8"), "llmebench.qcri.org/"
+            )
             self.assertEqual(openai.api_key, "secret-key")
             self.assertEqual(model.model_params["model"], "private-model")
 
@@ -71,6 +75,8 @@ def test_fastchat_config_priority(self):
             model = FastChatModel(model_name="another-model")
 
             self.assertEqual(openai.api_type, "openai")
-            self.assertEqual(openai.api_base, "llmebench.qcri.org")
+            self.assertEqual(
+                model.client.base_url.raw_path.decode("utf-8"), "llmebench.qcri.org/"
+            )
             self.assertEqual(openai.api_key, "secret-key")
             self.assertEqual(model.model_params["model"], "another-model")

From 11712a989629690d23829257c0b7a3005ea5a96a Mon Sep 17 00:00:00 2001
From: Fahim Imaduddin Dalvi
Date: Sun, 21 Jul 2024 13:44:07 +0300
Subject: [PATCH 4/4] Remove missing import

---
 llmebench/models/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llmebench/models/__init__.py b/llmebench/models/__init__.py
index 990979e0..92e44a30 100644
--- a/llmebench/models/__init__.py
+++ b/llmebench/models/__init__.py
@@ -4,4 +4,4 @@
 from .OpenAI import LegacyOpenAIModel, OpenAIModel
 from .Petals import PetalsModel
 from .Random import RandomModel
-from .VLLM import VLLMFanarModel, VLLMModel
+from .VLLM import VLLMModel
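
Note (not part of the patch series above): the snippet below is a minimal usage sketch of the refactored VLLMModel once the series is applied. The endpoint URL, API key, and model name are placeholder values, and the call pattern assumes the OpenAI-style chat-message interface that VLLMModel now inherits from OpenAIModel.

import os

# Placeholder values -- point these at an actual vLLM deployment
# (assumed endpoint/key/model, not taken from the patches above).
os.environ["VLLM_API_URL"] = "http://localhost:8000/v1"
os.environ["VLLM_API_KEY"] = "EMPTY"
os.environ["VLLM_MODEL"] = "your-org/your-model"

from llmebench.models import VLLMModel

# Reads VLLM_API_URL / VLLM_API_KEY / VLLM_MODEL and delegates the rest
# of the configuration checks to OpenAIModel.__init__
model = VLLMModel()

# prompt() is inherited from OpenAIModel and expects an OpenAI-style
# list of chat messages.
response = model.prompt([{"role": "user", "content": "Say hello."}])
print(response)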