
Commit 86d1f14

Move chat and complete to vllm.cmd.openai

Signed-off-by: Russell Bryant <rbryant@redhat.com>

1 parent: e32f896

4 files changed: 183 additions, 117 deletions


vllm/cmd/main.py

Lines changed: 12 additions & 114 deletions
@@ -1,22 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # The CLI entrypoint to vLLM.
-import argparse
 import os
 import signal
 import sys
-from typing import List, Optional
-
-from openai import OpenAI
-from openai.types.chat import ChatCompletionMessageParam
 
+import vllm.cmd.openai
 import vllm.cmd.serve
 import vllm.version
 from vllm.logger import init_logger
 from vllm.utils import FlexibleArgumentParser
 
 logger = init_logger(__name__)
 
+CMD_MODULES = [
+    vllm.cmd.openai,
+    vllm.cmd.serve,
+]
+
 
 def register_signal_handlers():
 
@@ -27,83 +28,6 @@ def signal_handler(sig, frame):
     signal.signal(signal.SIGTSTP, signal_handler)
 
 
-def interactive_cli(args: argparse.Namespace) -> None:
-    register_signal_handlers()
-
-    base_url = args.url
-    api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY")
-    openai_client = OpenAI(api_key=api_key, base_url=base_url)
-
-    if args.model_name:
-        model_name = args.model_name
-    else:
-        available_models = openai_client.models.list()
-        model_name = available_models.data[0].id
-
-    print(f"Using model: {model_name}")
-
-    if args.command == "complete":
-        complete(model_name, openai_client)
-    elif args.command == "chat":
-        chat(args.system_prompt, model_name, openai_client)
-
-
-def complete(model_name: str, client: OpenAI) -> None:
-    print("Please enter prompt to complete:")
-    while True:
-        input_prompt = input("> ")
-
-        completion = client.completions.create(model=model_name,
-                                               prompt=input_prompt)
-        output = completion.choices[0].text
-        print(output)
-
-
-def chat(system_prompt: Optional[str], model_name: str,
-         client: OpenAI) -> None:
-    conversation: List[ChatCompletionMessageParam] = []
-    if system_prompt is not None:
-        conversation.append({"role": "system", "content": system_prompt})
-
-    print("Please enter a message for the chat model:")
-    while True:
-        input_message = input("> ")
-        conversation.append({"role": "user", "content": input_message})
-
-        chat_completion = client.chat.completions.create(model=model_name,
-                                                         messages=conversation)
-
-        response_message = chat_completion.choices[0].message
-        output = response_message.content
-
-        conversation.append(response_message)  # type: ignore
-        print(output)
-
-
-def _add_query_options(
-        parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
-    parser.add_argument(
-        "--url",
-        type=str,
-        default="http://localhost:8000/v1",
-        help="url of the running OpenAI-Compatible RESTful API server")
-    parser.add_argument(
-        "--model-name",
-        type=str,
-        default=None,
-        help=("The model name used in prompt completion, default to "
-              "the first model in list models API call."))
-    parser.add_argument(
-        "--api-key",
-        type=str,
-        default=None,
-        help=(
-            "API key for OpenAI services. If provided, this api key "
-            "will overwrite the api key obtained through environment variables."
-        ))
-    return parser
-
-
 def env_setup():
     # The safest multiprocessing method is `spawn`, as the default `fork` method
     # is not compatible with some accelerators. The default method will be
@@ -134,43 +58,17 @@ def main():
                         action='version',
                         version=vllm.version.__version__)
     subparsers = parser.add_subparsers(required=True, dest="subparser")
-
-    cmd_modules = [
-        vllm.cmd.serve,
-    ]
     cmds = {}
-    for cmd_module in cmd_modules:
-        cmd = cmd_module.cmd_init()
-        cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd)
-        cmds[cmd.name] = cmd
-
-    complete_parser = subparsers.add_parser(
-        "complete",
-        help=("Generate text completions based on the given prompt "
-              "via the running API server"),
-        usage="vllm complete [options]")
-    _add_query_options(complete_parser)
-    complete_parser.set_defaults(dispatch_function=interactive_cli,
-                                 command="complete")
-
-    chat_parser = subparsers.add_parser(
-        "chat",
-        help="Generate chat completions via the running API server",
-        usage="vllm chat [options]")
-    _add_query_options(chat_parser)
-    chat_parser.add_argument(
-        "--system-prompt",
-        type=str,
-        default=None,
-        help=("The system prompt to be added to the chat template, "
-              "used for models that support system prompts."))
-    chat_parser.set_defaults(dispatch_function=interactive_cli, command="chat")
-
+    for cmd_module in CMD_MODULES:
+        new_cmds = cmd_module.cmd_init()
+        for cmd in new_cmds:
+            cmd.subparser_init(subparsers).set_defaults(
+                dispatch_function=cmd.cmd)
+            cmds[cmd.name] = cmd
     args = parser.parse_args()
    if args.subparser in cmds:
         cmds[args.subparser].validate(args)
 
-    # One of the sub commands should be executed.
     if hasattr(args, "dispatch_function"):
         args.dispatch_function(args)
     else:
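The net effect in main() is a generic registration loop: every module listed in CMD_MODULES exposes cmd_init(), which returns a list of CLISubcommand objects, and each object installs its own subparser and dispatch function. As a rough sketch of the extension point (the module and the `ping` command below are invented for illustration and are not part of this commit), a new command module would look like:

# Hypothetical example only: a `vllm/cmd/ping.py` with a `ping` command does
# not exist in this commit; it illustrates the cmd_init()/CLISubcommand
# contract that main.py consumes.
import argparse
from typing import List

from vllm.cmd.types import CLISubcommand
from vllm.utils import FlexibleArgumentParser


class PingCommand(CLISubcommand):
    """An invented `ping` subcommand, shown only to illustrate the pattern."""

    def __init__(self):
        self.name = "ping"
        super().__init__()

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        # dispatch_function lands here when `vllm ping` is run.
        print("pong")

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        return subparsers.add_parser("ping", help="Reply with pong")


def cmd_init() -> List[CLISubcommand]:
    # main() iterates this list and registers each command's subparser.
    return [PingCommand()]

Registering the module would then just mean appending it to CMD_MODULES in main.py.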

vllm/cmd/openai.py

Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
+# SPDX-License-Identifier: Apache-2.0
+# Commands that act as an interactive OpenAI API client
+
+import argparse
+import os
+import signal
+import sys
+from typing import List, Optional, Tuple
+
+from openai import OpenAI
+from openai.types.chat import ChatCompletionMessageParam
+
+from vllm.cmd.types import CLISubcommand
+from vllm.utils import FlexibleArgumentParser
+
+
+def _register_signal_handlers():
+
+    def signal_handler(sig, frame):
+        sys.exit(0)
+
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTSTP, signal_handler)
+
+
+def _interactive_cli(args: argparse.Namespace) -> Tuple[str, OpenAI]:
+    _register_signal_handlers()
+
+    base_url = args.url
+    api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY")
+    openai_client = OpenAI(api_key=api_key, base_url=base_url)
+
+    if args.model_name:
+        model_name = args.model_name
+    else:
+        available_models = openai_client.models.list()
+        model_name = available_models.data[0].id
+
+    print(f"Using model: {model_name}")
+
+    return model_name, openai_client
+
+
+def chat(system_prompt: Optional[str], model_name: str,
+         client: OpenAI) -> None:
+    conversation: List[ChatCompletionMessageParam] = []
+    if system_prompt is not None:
+        conversation.append({"role": "system", "content": system_prompt})
+
+    print("Please enter a message for the chat model:")
+    while True:
+        input_message = input("> ")
+        conversation.append({"role": "user", "content": input_message})
+
+        chat_completion = client.chat.completions.create(model=model_name,
+                                                         messages=conversation)
+
+        response_message = chat_completion.choices[0].message
+        output = response_message.content
+
+        conversation.append(response_message)  # type: ignore
+        print(output)
+
+
+def _add_query_options(
+        parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
+    parser.add_argument(
+        "--url",
+        type=str,
+        default="http://localhost:8000/v1",
+        help="url of the running OpenAI-Compatible RESTful API server")
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default=None,
+        help=("The model name used in prompt completion, default to "
+              "the first model in list models API call."))
+    parser.add_argument(
+        "--api-key",
+        type=str,
+        default=None,
+        help=(
+            "API key for OpenAI services. If provided, this api key "
+            "will overwrite the api key obtained through environment variables."
+        ))
+    return parser
+
+
+class ChatCommand(CLISubcommand):
+    """The `chat` subcommand for the vLLM CLI. """
+
+    def __init__(self):
+        self.name = "chat"
+        super().__init__()
+
+    @staticmethod
+    def cmd(args: argparse.Namespace) -> None:
+        model_name, client = _interactive_cli(args)
+        system_prompt = args.system_prompt
+        conversation: List[ChatCompletionMessageParam] = []
+        if system_prompt is not None:
+            conversation.append({"role": "system", "content": system_prompt})
+
+        print("Please enter a message for the chat model:")
+        while True:
+            input_message = input("> ")
+            conversation.append({"role": "user", "content": input_message})
+
+            chat_completion = client.chat.completions.create(
+                model=model_name, messages=conversation)
+
+            response_message = chat_completion.choices[0].message
+            output = response_message.content
+
+            conversation.append(response_message)  # type: ignore
+            print(output)
+
+    def subparser_init(
+            self,
+            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
+        chat_parser = subparsers.add_parser(
+            "chat",
+            help="Generate chat completions via the running API server",
+            usage="vllm chat [options]")
+        _add_query_options(chat_parser)
+        chat_parser.add_argument(
+            "--system-prompt",
+            type=str,
+            default=None,
+            help=("The system prompt to be added to the chat template, "
+                  "used for models that support system prompts."))
+        return chat_parser
+
+
+class CompleteCommand(CLISubcommand):
+    """The `complete` subcommand for the vLLM CLI. """
+
+    def __init__(self):
+        self.name = "complete"
+        super().__init__()
+
+    @staticmethod
+    def cmd(args: argparse.Namespace) -> None:
+        model_name, client = _interactive_cli(args)
+        print("Please enter prompt to complete:")
+        while True:
+            input_prompt = input("> ")
+            completion = client.completions.create(model=model_name,
+                                                   prompt=input_prompt)
+            output = completion.choices[0].text
+            print(output)
+
+    def subparser_init(
+            self,
+            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
+        complete_parser = subparsers.add_parser(
+            "complete",
+            help=("Generate text completions based on the given prompt "
+                  "via the running API server"),
+            usage="vllm complete [options]")
+        _add_query_options(complete_parser)
+        return complete_parser
+
+
+def cmd_init() -> List[CLISubcommand]:
+    return [ChatCommand(), CompleteCommand()]
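With both commands living in vllm.cmd.openai, typical usage against a running server looks like the following (flags as defined by _add_query_options and the chat parser; the URL shown is the default, and the model name is a placeholder):

    vllm chat --url http://localhost:8000/v1 --system-prompt "You are a helpful assistant."
    vllm complete --model-name <model-name>

If --model-name is omitted, _interactive_cli falls back to the first model reported by the server's models.list() call, and the API key comes from --api-key or the OPENAI_API_KEY environment variable (defaulting to "EMPTY").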

vllm/cmd/serve.py

Lines changed: 4 additions & 2 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import argparse
+from typing import List
 
 import uvloop
 
@@ -17,6 +18,7 @@ class ServeSubcommand(CLISubcommand):
 
     def __init__(self):
         self.name = "serve"
+        super().__init__()
 
     @staticmethod
     def cmd(args: argparse.Namespace) -> None:
@@ -57,5 +59,5 @@ def subparser_init(
         return make_arg_parser(serve_parser)
 
 
-def cmd_init() -> CLISubcommand:
-    return ServeSubcommand()
+def cmd_init() -> List[CLISubcommand]:
+    return [ServeSubcommand()]
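Returning a list from cmd_init() here mirrors the new contract in vllm.cmd.openai, which exports two commands from a single module, so main() can iterate every command module the same way.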

vllm/cmd/types.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ def cmd(args: argparse.Namespace) -> None:
         raise NotImplementedError("Subclasses should implement this method")
 
     def validate(self, args: argparse.Namespace) -> None:
-        # No validation by deafult
+        # No validation by default
         pass
 
     def subparser_init(
