Merge pull request #10 from anandhu-eng/vllm_enhancement
Verification check for model - VLLM
arjunsuresh authored Jul 17, 2024
2 parents 7607097 + 44ae1d9 commit 93b5d64
Showing 1 changed file with 22 additions and 0 deletions.
language/llama2-70b/main.py (22 additions, 0 deletions)
@@ -4,12 +4,27 @@
 import os
 import logging
 import sys
+import requests
+import json
 
 sys.path.insert(0, os.getcwd())
 
 logging.basicConfig(level=logging.INFO)
 log = logging.getLogger("Llama-70B-MAIN")
 
+# Check that the model name served by the endpoint matches the user-specified one.
+def verify_model_name(user_specified_name, url):
+    response = requests.get(url)
+    if response.status_code == 200:
+        response_dict = response.json()
+        server_model_name = response_dict["data"][0]["id"]
+        if user_specified_name == server_model_name:
+            return {"matched": True, "error": False}
+        else:
+            return {"matched": False, "error": f"User-specified model name {user_specified_name} and server model name {server_model_name} do not match!"}
+    else:
+        return {"matched": False, "error": f"Failed to get a valid response. Status code: {response.status_code}"}
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--scenario", type=str, choices=["Offline", "Server"], default="Offline", help="Scenario")
@@ -41,6 +56,13 @@ def get_args():

 def main():
     args = get_args()
+
+    if args.vllm:
+        resp = verify_model_name(args.api_model_name, args.api_server + "/v1/models")
+        if resp["error"]:
+            print(f"\n\n\033[91mError:\033[0m", end=" ")
+            print(resp["error"])
+            sys.exit(1)
 
     settings = lg.TestSettings()
     settings.scenario = scenario_map[args.scenario.lower()]
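A hypothetical standalone invocation of the new check, assuming a vLLM server listening on localhost:8000; the URL and model name are placeholders rather than project defaults.

    # Usage sketch of verify_model_name(); all values are illustrative.
    result = verify_model_name(
        "meta-llama/Llama-2-70b-chat-hf",
        "http://localhost:8000/v1/models",
    )
    if result["error"]:  # an error string is truthy; False means the names matched
        print(result["error"])
    else:
        print("Model name matches the server.")

In main(), the same pattern gates the benchmark run: on a name mismatch or a non-200 status, the script prints the error in red and exits with status 1 before any LoadGen settings are built.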
