
Commit 721fa3d

FastAPI-based working frontend (#10)
1 parent d359cda commit 721fa3d

15 files changed, +536 -146 lines

README.md

Lines changed: 39 additions & 2 deletions

@@ -8,9 +8,46 @@ pip install flash-attn # This may take up to 10 mins.
 pip install -e .
 ```
 
-## Run
+## Test simple server
 
 ```bash
 ray start --head
-python server.py [--tensor-parallel-size <N>]
+python simple_server.py
+```
+
+The detailed arguments for `simple_server.py` can be found by:
+```bash
+python simple_server.py --help
+```
+
+## FastAPI server
+
+Install the following additional dependencies:
+```bash
+pip install fastapi uvicorn
+```
+
+To start the server:
+```bash
+ray start --head
+python -m cacheflow.http_frontend.fastapi_frontend
+```
+
+To test the server:
+```bash
+python -m cacheflow.http_frontend.test_cli_client
+```
+
+## Gradio web server
+
+Install the following additional dependencies:
+```bash
+pip install gradio
+```
+
+Start the server:
+```bash
+python -m cacheflow.http_frontend.fastapi_frontend
+# At another terminal
+python -m cacheflow.http_frontend.gradio_webserver
 ```
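
Besides the bundled test client, the `/generate` endpoint can be exercised with a few lines of hand-rolled Python. The sketch below assumes the FastAPI frontend is running on its default `localhost:10002`; the request fields mirror the two clients included in this commit, and exactly which other fields are accepted is decided by `SamplingParams.from_dict`, which is not part of this diff. The server replies with a stream of null-byte-delimited JSON chunks whose `"text"` field holds everything decoded so far.

```python
# Minimal hand-rolled client for the /generate endpoint (sketch; assumes the
# FastAPI frontend from this commit is running on its default localhost:10002).
import json

import requests

payload = {
    "prompt": "Ion Stoica is a",
    "n": 4,
    "use_beam_search": True,
    "temperature": 0.0,
}
response = requests.post("http://localhost:10002/generate",
                         json=payload, stream=True)

# The server streams JSON objects separated by null bytes; each one carries
# the full text generated so far for every sequence in the group.
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                 delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode("utf-8"))
        print(data["text"])
```
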
cacheflow/http_frontend/fastapi_frontend.py

Lines changed: 152 additions & 0 deletions
import argparse
import asyncio
import time
from typing import List, Dict
import json

import ray
from transformers import AutoTokenizer
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import uvicorn

from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.master.server import (Server, add_server_arguments,
                                     initialize_ray_cluster)
from cacheflow.worker.controller import DeviceID
from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory

app = FastAPI()


class FastAPIFrontend:
    def __init__(
        self,
        model: str,
        model_path: str,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        block_size: int,
        dtype: str,
        seed: int,
        swap_space: int,
        max_batch_size: int,
        num_nodes: int,
        num_devices_per_node: int,
        distributed_init_method: str,
        all_stage_devices: List[List[DeviceID]],
    ):
        self.block_size = block_size

        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.seq_group_counter = Counter()
        self.seq_counter = Counter()
        remote_server_class = ray.remote(num_cpus=0)(Server)
        self.server = remote_server_class.remote(
            model=model,
            model_path=model_path,
            pipeline_parallel_size=pipeline_parallel_size,
            tensor_parallel_size=tensor_parallel_size,
            block_size=block_size,
            dtype=dtype,
            seed=seed,
            swap_space=swap_space,
            max_batch_size=max_batch_size,
            num_nodes=num_nodes,
            num_devices_per_node=num_devices_per_node,
            distributed_init_method=distributed_init_method,
            all_stage_devices=all_stage_devices,
            gpu_memory=get_gpu_memory(),
            cpu_memory=get_cpu_memory(),
        )

        self.running_seq_groups: Dict[int, SequenceGroup] = {}
        self.sequence_group_events: Dict[int, asyncio.Event] = {}
        self.is_server_running = False

    async def server_step(self):
        self.is_server_running = True
        updated_seq_groups = await self.server.step.remote()
        self.is_server_running = False
        for seq_group in updated_seq_groups:
            group_id = seq_group.group_id
            self.running_seq_groups[group_id] = seq_group
            self.sequence_group_events[group_id].set()

    async def generate(self, request_dict: Dict):
        prompt = request_dict["prompt"]
        sampling_params = SamplingParams.from_dict(request_dict)
        sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
        token_ids = self.tokenizer.encode(prompt)
        seqs: List[Sequence] = []
        for _ in range(sampling_params.n):
            seq_id = next(self.seq_counter)
            seq = Sequence(seq_id, token_ids, block_size=self.block_size)
            seqs.append(seq)

        group_id = next(self.seq_group_counter)
        seq_group = SequenceGroup(group_id, seqs)
        group_event = asyncio.Event()
        self.sequence_group_events[group_id] = group_event
        await self.server.add_sequence_groups.remote([(seq_group, sampling_params)])
        while True:
            if not self.is_server_running:
                await self.server_step()
            # Wait for new output. Add a 1s timeout to prevent dead lock.
            await asyncio.wait_for(group_event.wait(), timeout=1)
            group_event.clear()
            seq_group = self.running_seq_groups[group_id]
            all_outputs = []
            for seq in seq_group.seqs:
                token_ids = seq.get_token_ids()
                output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
                all_outputs.append(output)
            ret = {
                "text": all_outputs,
                "error": 0,
            }
            yield (json.dumps(ret) + "\0").encode("utf-8")
            if seq_group.is_finished():
                break


@app.post("/generate")
async def generate_stream(request: Request):
    request_dict = await request.json()
    return StreamingResponse(frontend.generate(request_dict))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=10002)
    parser = add_server_arguments(parser)
    args = parser.parse_args()

    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')

    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = (
        initialize_ray_cluster(
            pipeline_parallel_size=args.pipeline_parallel_size,
            tensor_parallel_size=args.tensor_parallel_size))

    frontend = FastAPIFrontend(
        model=args.model,
        model_path=args.model_path,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
        block_size=args.block_size,
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        max_batch_size=args.max_batch_size,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
        all_stage_devices=all_stage_devices,
    )

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
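
The interesting part of `fastapi_frontend.py` is how a single engine is shared across many concurrent `/generate` requests: whichever request coroutine finds the engine idle calls `server_step()`, and each request waits on its own `asyncio.Event` to learn that new tokens are available. Below is a minimal, self-contained sketch of that pattern; `ToyEngine` and `ToyFrontend` are illustrative stand-ins, not part of this commit, and no Ray, model, or tokenizer is involved.

```python
import asyncio
from typing import Dict, List


class ToyEngine:
    """Stand-in for the Ray-backed Server: each step advances every request."""

    def __init__(self) -> None:
        self.progress: Dict[int, int] = {}

    async def step(self) -> List[int]:
        await asyncio.sleep(0.1)  # pretend to run one model iteration
        for request_id in self.progress:
            self.progress[request_id] += 1
        return list(self.progress)


class ToyFrontend:
    def __init__(self) -> None:
        self.engine = ToyEngine()
        self.events: Dict[int, asyncio.Event] = {}
        self.is_engine_running = False

    async def engine_step(self) -> None:
        # Only the coroutine that finds the engine idle drives it;
        # the others just wait on their events in generate() below.
        self.is_engine_running = True
        updated = await self.engine.step()
        self.is_engine_running = False
        for request_id in updated:
            self.events[request_id].set()  # wake every waiting request

    async def generate(self, request_id: int, num_steps: int):
        self.engine.progress[request_id] = 0
        self.events[request_id] = asyncio.Event()
        while True:
            if not self.is_engine_running:
                await self.engine_step()
            await asyncio.wait_for(self.events[request_id].wait(), timeout=1)
            self.events[request_id].clear()
            yield f"request {request_id}: step {self.engine.progress[request_id]}"
            if self.engine.progress[request_id] >= num_steps:
                del self.engine.progress[request_id]
                break


async def main() -> None:
    frontend = ToyFrontend()

    async def consume(request_id: int) -> None:
        async for text in frontend.generate(request_id, num_steps=3):
            print(text)

    # Two concurrent "requests" share one engine, interleaving their steps.
    await asyncio.gather(consume(1), consume(2))


asyncio.run(main())
```
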
cacheflow/http_frontend/gradio_webserver.py

Lines changed: 43 additions & 0 deletions
import argparse
import json
import time

import gradio as gr
import requests


def http_bot(prompt):
    headers = {"User-Agent": "Cacheflow Client"}
    pload = {
        "prompt": prompt,
        "max_num_steps": 128,
    }
    response = requests.post(args.model_url, headers=headers, json=pload, stream=True)

    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"][0]
            yield output


def build_demo():
    with gr.Blocks() as demo:
        gr.Markdown(
            "# Cacheflow demo\n"
        )
        inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")  # .style(container=False)
        outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model")
        inputbox.submit(http_bot, [inputbox], [outputbox])
    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=10003)
    parser.add_argument("--model-url", type=str, default="http://localhost:10002/generate")
    args = parser.parse_args()

    demo = build_demo()
    demo.queue(concurrency_count=100).launch(server_name=args.host, server_port=args.port)
cacheflow/http_frontend/test_cli_client.py

Lines changed: 23 additions & 0 deletions
import requests
import json

def http_request():
    prompt = "Ion Stoica is a"

    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "n": 4,
        "use_beam_search": True,
        "temperature": 0.0,
    }
    response = requests.post("http://localhost:10002/generate", headers=headers, json=pload, stream=True)

    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output

for h in http_request():
    print(h, flush=True)

cacheflow/master/scheduler.py

Lines changed: 14 additions & 12 deletions
@@ -1,7 +1,6 @@
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 from cacheflow.master.block_manager import BlockSpaceManager
-from cacheflow.master.frontend import Frontend
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence
 from cacheflow.sequence import SequenceGroup
@@ -14,14 +13,12 @@ class Scheduler:
 
     def __init__(
         self,
-        frontend: Frontend,
         controllers: List,
         block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
         max_num_batched_tokens: int,
     ) -> None:
-        self.frontend = frontend
         self.controllers = controllers
         self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
@@ -47,9 +44,12 @@ def __init__(
         # Pending sequence groups (FIFO).
         self.pending: List[SequenceGroup] = []
 
-    def _fetch_inputs(self) -> None:
-        inputs = self.frontend.get_inputs()
-        for seq_group, sampling_params in inputs:
+    def add_sequence_groups(
+        self,
+        sequence_groups: List[Tuple[SequenceGroup, SamplingParams]],
+    ) -> None:
+        # Add sequence groups to the pending queue.
+        for seq_group, sampling_params in sequence_groups:
             self.pending.append(seq_group)
             self.sampling_params[seq_group.group_id] = sampling_params
 
@@ -104,7 +104,7 @@ def _swap_out(
             seq.status = SequenceStatus.SWAPPED
         self.swapped.append(seq_group)
 
-    def step(self) -> None:
+    def step(self) -> List[SequenceGroup]:
         # Blocks that need to be swaped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
@@ -158,7 +158,6 @@ def step(self) -> None:
         # 3. Join new sequences if possible.
         # NOTE: Here we implicitly assume FCFS scheduling.
         # TODO(woosuk): Add a batching policy to control the batch size.
-        self._fetch_inputs()
         if not self.swapped:
             for i, seq_group in enumerate(self.pending):
                 num_prompt_tokens = seq_group.seqs[0].get_len()
@@ -176,6 +175,8 @@ def step(self) -> None:
 
         # 4. Create input data structures.
         input_seq_groups: List[SequenceGroupInputs] = []
+        updated_seq_groups: List[SequenceGroup] = self.running.copy()
+
         for seq_group in self.running:
             group_id = seq_group.group_id
             num_steps = self.num_steps[group_id]
@@ -219,6 +220,8 @@ def step(self) -> None:
             blocks_to_copy,
         )
 
+        return updated_seq_groups
+
     def post_step(
         self,
         seq_outputs: Dict[int, SequenceOutputs],
@@ -268,13 +271,12 @@ def post_step(
         running: List[SequenceGroup] = []
         for seq_group in self.running:
             if seq_group.is_finished():
-                self._return(seq_group)
+                self._free_seq_group(seq_group)
             else:
                 running.append(seq_group)
         self.running = running
 
-    def _return(self, seq_group: SequenceGroup) -> None:
+    def _free_seq_group(self, seq_group: SequenceGroup) -> None:
         group_id = seq_group.group_id
         del self.num_steps[group_id]
         del self.sampling_params[group_id]
-        self.frontend.print_response(seq_group)
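
With this change the scheduler no longer pulls work from a `Frontend` object or prints finished responses itself: callers push `(SequenceGroup, SamplingParams)` pairs in through `add_sequence_groups()`, and `step()` hands back the sequence groups it scheduled so the caller can stream partial output. A rough, hypothetical driver is sketched below; the real wiring lives in the `Server` class in `cacheflow.master.server`, which this diff does not show, so treat the loop purely as an illustration of the new interface.

```python
# Hypothetical driver for the new Scheduler interface (illustration only; the
# real caller is cacheflow.master.server.Server, not shown in this diff).
from typing import List, Tuple

from cacheflow.master.scheduler import Scheduler
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceGroup


def drive(scheduler: Scheduler,
          requests: List[Tuple[SequenceGroup, SamplingParams]]) -> None:
    # Work is now pushed in explicitly instead of the scheduler polling a
    # Frontend via the removed _fetch_inputs().
    scheduler.add_sequence_groups(requests)

    groups = [group for group, _ in requests]
    while not all(group.is_finished() for group in groups):
        # step() returns the SequenceGroups it scheduled this iteration,
        # which is what lets the FastAPI frontend stream partial outputs.
        for group in scheduler.step():
            print(f"group {group.group_id}: "
                  f"{[seq.get_len() for seq in group.seqs]} tokens so far")
```
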

0 commit comments