@@ -51,20 +51,31 @@ def check_vllm_server(url: str, timeout=5, retries=3) -> bool:
     return False
 
 
-def run_simple_prompt(base_url: str, model_name: str,
-                      input_prompt: str) -> str:
+def run_simple_prompt(base_url: str, model_name: str, input_prompt: str,
+                      use_chat_endpoint: bool) -> str:
     client = openai.OpenAI(api_key="EMPTY", base_url=base_url)
-    completion = client.completions.create(model=model_name,
-                                           prompt=input_prompt,
-                                           max_tokens=MAX_OUTPUT_LEN,
-                                           temperature=0.0,
-                                           seed=42)
+    if use_chat_endpoint:
+        completion = client.chat.completions.create(
+            model=model_name,
+            messages=[{
+                "role": "user",
+                "content": [{
+                    "type": "text",
+                    "text": input_prompt
+                }]
+            }],
+            max_completion_tokens=MAX_OUTPUT_LEN,
+            temperature=0.0,
+            seed=42)
+        return completion.choices[0].message.content
+    else:
+        completion = client.completions.create(model=model_name,
+                                               prompt=input_prompt,
+                                               max_tokens=MAX_OUTPUT_LEN,
+                                               temperature=0.0,
+                                               seed=42)
 
-    # print("-" * 50)
-    # print(f"Completion results for {model_name}:")
-    # print(completion)
-    # print("-" * 50)
-    return completion.choices[0].text
+    return completion.choices[0].text
 
 
 def main():
@@ -125,10 +136,12 @@ def main():
             f"vllm server: {args.service_url} is not ready yet!")
 
     output_strs = dict()
-    for prompt in SAMPLE_PROMPTS:
+    for i, prompt in enumerate(SAMPLE_PROMPTS):
+        use_chat_endpoint = (i % 2 == 1)
         output_str = run_simple_prompt(base_url=service_url,
                                        model_name=args.model_name,
-                                       input_prompt=prompt)
+                                       input_prompt=prompt,
+                                       use_chat_endpoint=use_chat_endpoint)
         print(f"Prompt: {prompt}, output: {output_str}")
         output_strs[prompt] = output_str
 
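For context, a minimal standalone sketch of the two request shapes that run_simple_prompt() now switches between, written against vLLM's OpenAI-compatible server. The base URL, model name, and MAX_OUTPUT_LEN value below are placeholders introduced for illustration, not values taken from this change; the client calls themselves mirror those in the diff above.

# Sketch only: placeholder server URL, model name, and output length.
import openai

BASE_URL = "http://localhost:8000/v1"  # assumed local vLLM OpenAI-compatible server
MODEL = "Qwen/Qwen2.5-1.5B-Instruct"   # placeholder model name
MAX_OUTPUT_LEN = 128                   # placeholder; the test defines its own constant

client = openai.OpenAI(api_key="EMPTY", base_url=BASE_URL)

# Text-completions endpoint: plain prompt in, output read from choices[0].text.
text_out = client.completions.create(model=MODEL,
                                     prompt="Say hello.",
                                     max_tokens=MAX_OUTPUT_LEN,
                                     temperature=0.0,
                                     seed=42).choices[0].text

# Chat endpoint: message list in, output read from choices[0].message.content.
chat_out = client.chat.completions.create(
    model=MODEL,
    messages=[{
        "role": "user",
        "content": [{"type": "text", "text": "Say hello."}]
    }],
    max_completion_tokens=MAX_OUTPUT_LEN,
    temperature=0.0,
    seed=42).choices[0].message.content

print(text_out)
print(chat_out)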