VLLM/FastChat updates #321

Merged (4 commits) on Jul 21, 2024
llmebench/models/OpenAI.py: 7 changes (3 additions, 4 deletions)
@@ -88,9 +88,6 @@ def __init__(

openai.api_type = api_type

if api_base:
openai.api_base = api_base

if api_type == "azure" and api_version is None:
raise Exception(
"API version must be provided as model config or environment variable (`AZURE_API_VERSION`)"
@@ -132,7 +129,9 @@ def __init__(
base_url=f"{api_base}/openai/deployments/{model_name}/",
)
elif api_type == "openai":
self.client = OpenAI(api_key=api_key)
if not api_base:
api_base = "https://api.openai.com/v1"
self.client = OpenAI(base_url=api_base, api_key=api_key)
else:
raise Exception('API type must be one of "azure" or "openai"')

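In practice, this change means the OpenAI client is no longer tied to the official endpoint: when `api_type` is "openai" and no `api_base` is supplied, the default `https://api.openai.com/v1` is used, and any other value is passed through as `base_url`. A minimal sketch of the resulting behaviour, assuming the openai>=1.0 Python SDK and placeholder endpoint, key, and model values:

```python
from openai import OpenAI

# Default behaviour: no api_base configured, so the official endpoint is used.
official = OpenAI(base_url="https://api.openai.com/v1", api_key="sk-...")

# Same code path, pointed at an OpenAI-compatible server (e.g. FastChat or vLLM).
# Host, key, and model name below are placeholders, not values from this PR.
hosted = OpenAI(base_url="http://localhost:8000/v1", api_key="secret-key")

response = hosted.chat.completions.create(
    model="private-model",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```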
llmebench/models/VLLM.py: 114 changes (32 additions, 82 deletions)
@@ -2,31 +2,19 @@
import logging
import os

import requests
from llmebench.models.OpenAI import OpenAIModel

from llmebench.models.model_base import ModelBase

class VLLMModel(OpenAIModel):
"""
VLLM Model interface. Can be used for models hosted using https://github.com/vllm-project/vllm.

class VLLMFailure(Exception):
"""Exception class to map various failure types from the VLLM server"""

def __init__(self, failure_type, failure_message):
self.type_mapping = {
"processing": "Model Inference failure",
"connection": "Failed to connect to BLOOM Petal server",
}
self.type = failure_type
self.failure_message = failure_message

def __str__(self):
return (
f"{self.type_mapping.get(self.type, self.type)}: \n {self.failure_message}"
)

Accepts all arguments used by `OpenAIModel`, and overrides the arguments listed
below with VLLM variables.

class VLLMModel(ModelBase):
"""
VLLM Model interface.
See the [https://docs.vllm.ai/en/latest/models/supported_models.html](model_support)
page in VLLM's documentation for supported models and instructions on extending
to custom models.

Arguments
---------
@@ -45,20 +33,19 @@ class VLLMModel(ModelBase):

def __init__(
self,
api_url=None,
api_base=None,
api_key=None,
model_name=None,
timeout=20,
temperature=0,
top_p=0.95,
max_tokens=1512,
**kwargs,
):
# API parameters
self.api_url = api_url or os.getenv("VLLM_API_URL")
self.user_session_id = os.getenv("USER_SESSION_ID")
if self.api_url is None:
raise Exception(
"API url must be provided as model config or environment variable (`VLLM_API_URL`)"
)
self.api_base = api_base or os.getenv("VLLM_API_URL")
self.api_key = api_key or os.getenv("VLLM_API_KEY")
self.model_name = model_name or os.getenv("VLLM_MODEL")
self.api_timeout = timeout
# Parameters
tolerance = 1e-7
@@ -71,59 +58,22 @@ def __init__(
self.top_p = top_p
self.max_tokens = max_tokens

if self.api_base is None:
raise Exception(
"API url must be provided as model config or environment variable (`VLLM_API_BASE`)"
)
if self.api_key is None:
raise Exception(
"API key must be provided as model config or environment variable (`VLLM_API_KEY`)"
)
if self.model_name is None:
raise Exception(
"Model name must be provided as model config or environment variable (`VLLM_MODEL`)"
)
# checks for valid config settings)
super(VLLMModel, self).__init__(
retry_exceptions=(TimeoutError, VLLMFailure), **kwargs
api_base=self.api_base,
api_key=self.api_key,
model_name=self.model_name,
**kwargs,
)

def summarize_response(self, response):
"""Returns the "outputs" key's value, if available"""
if "messages" in response:
return response["messages"]

return response

def prompt(self, processed_input):
"""
VLLM API Implementation

Arguments
---------
processed_input : dictionary
Must be a dictionary with one key "prompt", the value of which
must be a string.

Returns
-------
response : VLLM API response
Response from the VLLM server

Raises
------
VLLMFailure : Exception
This method raises this exception if the server responded with a non-ok
response
"""
headers = {"Content-Type": "application/json"}
params = {
"messages": processed_input,
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"user_session_id": self.user_session_id,
}
try:
response = requests.post(
self.api_url, headers=headers, json=params, timeout=self.api_timeout
)
if response.status_code != 200:
raise VLLMFailure(
"processing",
"Processing failed with status: {}".format(response.status_code),
)

# Parse the final response
response_data = response.json()
logging.info(f"initial_response: {response_data}")
except VLLMFailure as e:
print("Error occurred:", e)

return response_data
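The net effect of the rewrite is that `VLLMModel` no longer talks to the server through raw `requests.post` calls guarded by a custom `VLLMFailure` exception: it only resolves its configuration (from constructor arguments or the `VLLM_API_URL`, `VLLM_API_KEY`, and `VLLM_MODEL` environment variables) and delegates everything else to `OpenAIModel`, which uses the server's OpenAI-compatible API. A rough configuration sketch, with placeholder values for a locally hosted vLLM server:

```python
import os

from llmebench.models.VLLM import VLLMModel

# Placeholder values; point these at a running vLLM OpenAI-compatible server.
os.environ["VLLM_API_URL"] = "http://localhost:8000/v1"
os.environ["VLLM_API_KEY"] = "EMPTY"  # vLLM typically ignores the key unless one is configured
os.environ["VLLM_MODEL"] = "meta-llama/Meta-Llama-3-8B-Instruct"

# Equivalent to passing api_base/api_key/model_name explicitly; a missing value
# now fails fast with a clear exception instead of surfacing later as an HTTP error.
model = VLLMModel()
```

Delegating to `OpenAIModel` also removes the duplicated request, retry, and response-parsing logic that the old `prompt` and `summarize_response` implementations carried.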
tests/models/test_FastChatModel.py: 12 changes (9 additions, 3 deletions)
@@ -37,7 +37,9 @@ def test_fastchat_config(self):
)

self.assertEqual(openai.api_type, "openai")
self.assertEqual(openai.api_base, "llmebench.qcri.org")
self.assertEqual(
model.client.base_url.raw_path.decode("utf-8"), "llmebench.qcri.org/"
)
self.assertEqual(openai.api_key, "secret-key")
self.assertEqual(model.model_params["model"], "private-model")

@@ -54,7 +56,9 @@ def test_fastchat_config_env_var(self):
model = FastChatModel()

self.assertEqual(openai.api_type, "openai")
self.assertEqual(openai.api_base, "llmebench.qcri.org")
self.assertEqual(
model.client.base_url.raw_path.decode("utf-8"), "llmebench.qcri.org/"
)
self.assertEqual(openai.api_key, "secret-key")
self.assertEqual(model.model_params["model"], "private-model")

@@ -71,6 +75,8 @@ def test_fastchat_config_priority(self):
model = FastChatModel(model_name="another-model")

self.assertEqual(openai.api_type, "openai")
self.assertEqual(openai.api_base, "llmebench.qcri.org")
self.assertEqual(
model.client.base_url.raw_path.decode("utf-8"), "llmebench.qcri.org/"
)
self.assertEqual(openai.api_key, "secret-key")
self.assertEqual(model.model_params["model"], "another-model")
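The updated assertions follow from the client change in `OpenAI.py`: the new SDK does not mirror the URL into the module-level `openai.api_base`; it stores it on the client as an `httpx.URL` and normalizes it with a trailing slash, hence the expected `llmebench.qcri.org/`. A small sketch of that behaviour, assuming the openai>=1.0 SDK (which pulls in httpx):

```python
from openai import OpenAI

client = OpenAI(base_url="llmebench.qcri.org", api_key="secret-key")

# client.base_url is an httpx.URL; the SDK appends a trailing slash so that
# relative request paths resolve underneath it.
print(type(client.base_url))                     # <class 'httpx.URL'>
print(client.base_url.raw_path.decode("utf-8"))  # llmebench.qcri.org/
```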