Skip to content

Commit b3eb768

Browse files
ca1207, WyldeCat, and jeejeelee
authored and committed
[Model] New model support for Motif-1-Tiny (vllm-project#23414)
Signed-off-by: ca1207 <ca1207zzz@gmail.com> Signed-off-by: TaehyunKim <73943231+ca1207@users.noreply.github.com> Co-authored-by: WyldeCat <skan1543@gmail.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent 2015223 commit b3eb768

File tree

13 files changed

+871
-4
lines changed

13 files changed

+871
-4
lines changed
Lines changed: 155 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,155 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import itertools
5+
6+
import torch
7+
8+
from vllm import _custom_ops as vllm_ops
9+
from vllm.triton_utils import triton
10+
11+
12+
def polynorm_naive(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-6,
):
    """Reference PolyNorm built from plain PyTorch ops.

    Computes ``weight[0] * norm(x**3) + weight[1] * norm(x**2)
    + weight[2] * norm(x) + bias``, where ``norm`` divides its argument by
    the root mean square taken over the last dimension (``eps`` is added
    inside the square root).

    Args:
        x: Input tensor; normalization runs over its last dimension.
        weight: Indexed at 0, 1 and 2 for the cubic, quadratic and linear
            terms respectively.
        bias: Added to the weighted sum.
        eps: Stabilizer added to the mean square before the square root.

    Returns:
        A tensor with the shape of ``x`` and the dtype of ``weight``.
    """
    orig_shape = x.shape
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # transposed tensors, which view() rejects with a RuntimeError.
    x = x.reshape(-1, x.shape[-1])

    def norm(x, eps: float):
        return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)

    # Do the arithmetic in float32 and cast back only at the very end.
    x = x.float()
    return (
        (
            weight[0] * norm(x**3, eps)
            + weight[1] * norm(x**2, eps)
            + weight[2] * norm(x, eps)
            + bias
        )
        .to(weight.dtype)
        .reshape(orig_shape)
    )
35+
36+
37+
def polynorm_vllm(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-6,
):
    """PolyNorm computed by vLLM's ``poly_norm`` custom op.

    The input is flattened to 2D (rows x last dimension), a same-shaped
    output buffer is allocated, and both are handed to
    ``vllm_ops.poly_norm`` together with ``weight``, ``bias`` and ``eps``.

    Returns:
        The output buffer viewed with the original shape of ``x``.
    """
    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])

    # The op receives a preallocated buffer; presumably it writes the result
    # into ``out`` in place (its return value is not used here).
    out = torch.empty_like(x)
    vllm_ops.poly_norm(out, x, weight, bias, eps)

    return out.view(orig_shape)
52+
53+
54+
def calculate_diff(batch_size, seq_len, hidden_dim):
    """Check the vLLM PolyNorm kernel against the PyTorch reference.

    Builds a random bfloat16 input of shape
    ``(batch_size, seq_len, hidden_dim)`` on CUDA, with all-ones weight
    (3 elements) and bias (1 element), runs both implementations, and
    prints whether the outputs agree within ``atol=1e-2, rtol=1e-2``.
    """
    dtype = torch.bfloat16
    device = "cuda"

    x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device=device)
    weight = torch.ones(3, dtype=dtype, device=device)
    bias = torch.ones(1, dtype=dtype, device=device)

    output_naive = polynorm_naive(x, weight, bias)
    output_vllm = polynorm_vllm(x, weight, bias)

    if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
67+
68+
69+
# Benchmark sweep: every (dim, batch_size, seq_len) combination below.
batch_size_range = [2**i for i in range(0, 7, 2)]  # 1, 4, 16, 64
seq_length_range = [2**i for i in range(6, 11)]  # 64 ... 1024
dim_range = [2048, 4096]
configs = list(itertools.product(dim_range, batch_size_range, seq_length_range))
73+
74+
75+
def get_benchmark():
    """Build the Triton perf-report benchmark for the two PolyNorm providers.

    Returns:
        The ``benchmark`` function wrapped by ``triton.testing.perf_report``.
        It sweeps every entry of ``configs`` for the "naive" and "vllm"
        providers and reports timings in microseconds.
    """

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["dim", "batch_size", "seq_len"],
            x_vals=[list(cfg) for cfg in configs],
            line_arg="provider",
            line_vals=["naive", "vllm"],
            line_names=["Naive", "vLLM"],
            styles=[("blue", "-"), ("red", "-")],
            ylabel="us",
            plot_name="polynorm-perf",
            args={},
        )
    )
    def benchmark(dim, batch_size, seq_len, provider):
        dtype = torch.bfloat16
        # NOTE(review): presumably models an MLP intermediate size of
        # 4x the model dim -- confirm against the target model config.
        hidden_dim = dim * 4

        x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
        weight = torch.ones(3, dtype=dtype, device="cuda")
        bias = torch.ones(1, dtype=dtype, device="cuda")

        # Median first, then the 20th and 80th percentiles as the band.
        quantiles = [0.5, 0.2, 0.8]

        impl = polynorm_naive if provider == "naive" else polynorm_vllm
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: impl(x, weight, bias),
            quantiles=quantiles,
        )

        # perf_report unpacks the result as (value, min, max). The metric is
        # a time, so the low quantile is the min and the high one the max.
        # Convert milliseconds to microseconds to match ylabel.
        return 1000 * ms, 1000 * min_ms, 1000 * max_ms

    return benchmark
113+
114+
115+
if __name__ == "__main__":
    import argparse

    # CLI: the shape arguments feed the correctness check only; the
    # performance sweep uses the module-level ``configs`` grid.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Batch size",
    )
    parser.add_argument(
        "--seq-len",
        type=int,
        default=128,
        help="Sequence length",
    )
    parser.add_argument(
        "--hidden-dim",
        type=int,
        default=8192,
        help="Intermediate size of MLP",
    )
    parser.add_argument(
        "--save-path",
        type=str,
        # Default directory spelling ("polnorm") kept as-is so existing
        # result locations do not move.
        default="./configs/polnorm/",
        help="Path to save polynorm benchmark results",
    )

    args = parser.parse_args()

    # Run correctness test
    calculate_diff(
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        hidden_dim=args.hidden_dim,
    )

    benchmark = get_benchmark()
    # Run performance benchmark
    benchmark.run(print_data=True, save_path=args.save_path)

0 commit comments

Comments
 (0)