
Commit c45f3c3

Optimize tensor parallel execution speed (#17)
1 parent 7a7929a commit c45f3c3

File tree

3 files changed (+103, -287 lines)

benchmark/benchmark_latency.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
import argparse
import time
from typing import List

from tqdm import tqdm
import numpy as np
import torch

from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
                                     initialize_ray_cluster)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory


def main(args: argparse.Namespace):
    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')

    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = (
        initialize_ray_cluster(
            address='local',
            pipeline_parallel_size=args.pipeline_parallel_size,
            tensor_parallel_size=args.tensor_parallel_size))

    # Create a server.
    server = Server(
        model=args.model,
        model_path=args.model_path,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
        block_size=args.block_size,
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        max_batch_size=args.max_batch_size,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
        all_stage_devices=all_stage_devices,
        gpu_memory=get_gpu_memory(),
        cpu_memory=get_cpu_memory(),
    )

    # Create a frontend.
    frontend = SimpleFrontend(
        model_name=args.model,
        block_size=args.block_size,
    )
    sampling_params_dict = {
        'n': 1,
        'temperature': 0.0,
        'top_p': 1.0,
        'use_beam_search': False,
        'stop_token_ids': set(),
        'max_num_steps': args.output_len,
    }
    sampling_params = SamplingParams.from_dict(sampling_params_dict)
    input_token_ids = [0] * args.input_len

    def profile_step(profile=False):
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        for _ in range(args.batch_size):
            frontend._add_query(input_token_ids, sampling_params)
        server.add_sequence_groups(frontend.get_inputs())
        start_time = time.time()
        while True:
            server.step()
            if not server.has_unfinished_requests():
                break
        end_time = time.time()
        latency = end_time - start_time
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return latency

    print("Warm up step")
    profile_step()

    # Benchmark.
    latencies = []
    for _ in tqdm(range(3), desc="Profile step"):
        latencies.append(profile_step())
    print(f'Avg latency: {np.mean(latencies)} seconds')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CacheFlow simple server.')
    parser = add_server_arguments(parser)
    parser.add_argument('--input-len', type=int, default=32)
    parser.add_argument('--output-len', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    args = parser.parse_args()
    args.max_batch_size = max(args.max_batch_size, args.batch_size * args.input_len)
    print(args)
    main(args)
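
Each call to profile_step above runs args.batch_size identical requests to completion, so the measured latency covers args.batch_size sequences that each generate args.output_len tokens. The script would be launched as python benchmark/benchmark_latency.py --input-len 32 --output-len 128 --batch-size 8, plus whatever model and parallelism flags add_server_arguments defines (those are not part of this diff). Below is a minimal sketch of how the reported mean latency could be turned into an approximate output-token throughput; the helper name summarize_latencies is hypothetical and not part of the commit.

import numpy as np

def summarize_latencies(latencies, batch_size, output_len):
    # Hypothetical helper, not in the commit: each latency covers
    # `batch_size` sequences generating `output_len` tokens each, so
    # output throughput is roughly batch_size * output_len / latency.
    avg = float(np.mean(latencies))
    return {
        'avg_latency_s': avg,
        'approx_output_tokens_per_s': batch_size * output_len / avg,
    }

# Example call using the script's arguments:
# summarize_latencies(latencies, batch_size=args.batch_size, output_len=args.output_len)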

cacheflow/parallel_utils/tensor_parallel/__init__.py

Lines changed: 0 additions & 3 deletions
@@ -6,8 +6,6 @@
     set_defaults_if_not_set_tensor_model_parallel_attributes,
     copy_tensor_model_parallel_attributes,
     param_is_not_tensor_parallel_duplicate,
-    linear_with_grad_accumulation_and_async_allreduce
-
 )
 
 from .mappings import (
@@ -39,7 +37,6 @@
     "set_defaults_if_not_set_tensor_model_parallel_attributes",
     "copy_tensor_model_parallel_attributes",
     "param_is_not_tensor_parallel_duplicate",
-    "linear_with_grad_accumulation_and_async_allreduce",
     # mappings.py
     "copy_to_tensor_model_parallel_region",
     "gather_from_tensor_model_parallel_region",
