Commit 46ee626

add more comprehensive testing for quantfp8 (-rmsnorm+-quant still failing)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
1 parent 32989d8 commit 46ee626

File tree: 1 file changed, +62 -27 lines


tests/compile/test_fusion_all_reduce.py

Lines changed: 62 additions & 27 deletions
@@ -26,8 +26,8 @@
 )
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    Fp8LinearOp,
     GroupShape,
-    QuantFP8,
 )
 from vllm.platforms import current_platform
 from vllm.utils import update_environment_variables
@@ -43,9 +43,9 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         self.eps = eps
         self.norm = RMSNorm(hidden_size, eps)

-    def forward(self, hidden_states, residual):
-        view = hidden_states.reshape(-1, self.hidden_size)
-        all_reduce = tensor_model_parallel_all_reduce(view)
+    def forward(self, x):
+        z = torch.relu(x)
+        all_reduce = tensor_model_parallel_all_reduce(z)
         norm = self.norm(all_reduce)
         return norm

@@ -63,9 +63,9 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         self.eps = eps
         self.norm = RMSNorm(hidden_size, eps)

-    def forward(self, hidden_states, residual):
-        view = hidden_states.reshape(-1, self.hidden_size)
-        all_reduce = tensor_model_parallel_all_reduce(view)
+    def forward(self, hidden_states):
+        z = residual = torch.relu(hidden_states)
+        all_reduce = tensor_model_parallel_all_reduce(z)
         norm, res = self.norm(all_reduce, residual)

         return norm, res
@@ -77,21 +77,53 @@ def ops_in_model_after(self):
         return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]


-class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
+class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
     def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         super().__init__()
         self.hidden_size = hidden_size
         self.eps = eps
-        self.norm = RMSNorm(hidden_size, eps)
-        self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
-        self.scale = torch.rand(1, dtype=torch.float32)
+        self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
+        self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
+        self.w = [
+            torch.rand(hidden_size, hidden_size)
+            .to(dtype=current_platform.fp8_dtype())
+            .t()
+            for _ in range(3)
+        ]

-    def forward(self, hidden_states, residual):
-        view = hidden_states.reshape(-1, self.hidden_size)
-        all_reduce = tensor_model_parallel_all_reduce(view)
-        norm_output, residual_output = self.norm(all_reduce, residual)
-        quant_out, _ = self.quant_fp8(norm_output, self.scale)
-        return quant_out, residual_output
+        self.fp8_linear = Fp8LinearOp(
+            act_quant_static=True,
+            act_quant_group_shape=GroupShape.PER_TENSOR,
+        )
+
+        self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
+
+    def forward(self, hidden_states):
+        # avoid having graph input be an arg to a pattern directly
+        z = torch.relu(hidden_states)
+        x = resid = tensor_model_parallel_all_reduce(z)
+        y = self.norm[0](x)
+
+        z2 = self.fp8_linear.apply(
+            y, self.w[0], self.wscale[0], input_scale=self.scale[0]
+        )
+
+        x2 = tensor_model_parallel_all_reduce(z2)
+        y2, resid = self.norm[1](x2, resid)
+
+        z3 = self.fp8_linear.apply(
+            y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
+        )
+
+        x3 = tensor_model_parallel_all_reduce(z3)
+        y3, resid = self.norm[2](x3, resid)  # use resid here
+
+        z4 = self.fp8_linear.apply(
+            y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
+        )
+        x4 = tensor_model_parallel_all_reduce(z4)
+        y4, resid = self.norm[3](x4, resid)  # use resid here
+        return y4

     def ops_in_model_after(self):
         return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
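Note (not part of the diff): the rewritten TestAllReduceRMSNormStaticQuantFP8Model chains the allreduce -> RMSNorm -> static FP8 quant pattern four times, with an FP8 GEMM between links, so the fusion pass is expected to report four match sites instead of one. Below is a minimal sketch of a single link, reusing the Fp8LinearOp constructor and apply() call from the hunk above; it is illustrative only, the vllm.distributed import path for the all-reduce helper is an assumption, and it still needs an initialized tensor-parallel group to actually run.

# Illustrative sketch of one link in the chained pattern exercised by the new model.
# Assumes a tensor-parallel group is already initialized; the all-reduce import path
# is assumed, the remaining calls mirror the test code in the hunk above.
import torch

from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp,
    GroupShape,
)
from vllm.platforms import current_platform

hidden_size = 16
norm = RMSNorm(hidden_size, eps=1e-6)
fp8_linear = Fp8LinearOp(
    act_quant_static=True,
    act_quant_group_shape=GroupShape.PER_TENSOR,
)
w = torch.rand(hidden_size, hidden_size).to(dtype=current_platform.fp8_dtype()).t()
wscale = torch.rand(1, dtype=torch.float32)
ascale = torch.rand(1, dtype=torch.float32)


def one_link(x: torch.Tensor, resid: torch.Tensor):
    x = tensor_model_parallel_all_reduce(x)  # collective the pass wants to fuse
    y, resid = norm(x, resid)  # fused-add RMSNorm
    z = fp8_linear.apply(y, w, wscale, input_scale=ascale)  # static FP8 quant + GEMM
    return z, resid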
@@ -100,7 +132,7 @@ def ops_in_model_before(self):
         return [
             torch.ops.vllm.all_reduce.default,
             torch.ops._C.static_scaled_fp8_quant.default
-            if self.quant_fp8.enabled()
+            if self.fp8_linear.quant_fp8.enabled()
             else torch.ops.aten.reciprocal.default,
         ]

@@ -120,11 +152,10 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         rounded_n = round_up(scale_n, 4)
         self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32)

-    def forward(self, hidden_states, residual):
-        view = hidden_states.reshape(-1, self.hidden_size)
-        all_reduce = tensor_model_parallel_all_reduce(view)
+    def forward(self, hidden_states):
+        z = residual = torch.relu(hidden_states)
+        all_reduce = tensor_model_parallel_all_reduce(z)
         norm_output, residual_output = self.norm(all_reduce, residual)
-        norm_output = norm_output.reshape(-1, norm_output.shape[-1])
         torch.ops._C.scaled_fp4_quant(
             self.output, norm_output, self.output_scale, self.scale
         )
@@ -146,8 +177,8 @@ def ops_in_model_before(self):
     [
         (TestAllReduceRMSNormModel, False),
         (TestAllReduceFusedAddRMSNormModel, False),
-        (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, True),
-        (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, False),
+        (TestAllReduceRMSNormStaticQuantFP8Model, True),
+        (TestAllReduceRMSNormStaticQuantFP8Model, False),
         (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False),
     ],
 )
@@ -269,12 +300,16 @@ def all_reduce_fusion_pass_on_test_model(
     model = test_model_cls(hidden_size, token_num)

     hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
-    residual = torch.randn((token_num, hidden_size), requires_grad=False)

     compiled_model = torch.compile(model, backend=backend)
-    compiled_model(hidden_states, residual)
+    compiled_model(hidden_states)

-    assert all_reduce_fusion_pass.matched_count == 1
+    # TODO cleanup
+    expected = 4 if test_model_cls is TestAllReduceRMSNormStaticQuantFP8Model else 1
+
+    assert all_reduce_fusion_pass.matched_count == expected, (
+        f"{all_reduce_fusion_pass.matched_count=}, {expected=}"
+    )
     backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
     backend.check_after_ops(model.ops_in_model_after())
     del all_reduce_fusion_pass
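Note (not part of the diff): the driver now passes only hidden_states, since each test model derives its own residual inside forward() via torch.relu, and the expected match count is now per model. As read from the forward() above, each tensor_model_parallel_all_reduce feeding an RMSNorm is one fusion candidate, which gives four sites (x, x2, x3, x4) for the FP8 model and one for the other models; the commit message flags "-rmsnorm+-quant still failing" at this point. An illustrative summary of the counts implied by the new expected value, not test code:

# Fusion-match counts implied by the updated expected value (illustrative only):
# one match per all_reduce -> RMSNorm site in each model's forward().
EXPECTED_MATCHES = {
    "TestAllReduceRMSNormModel": 1,  # single allreduce + rmsnorm
    "TestAllReduceFusedAddRMSNormModel": 1,  # single allreduce + fused-add rmsnorm
    "TestAllReduceRMSNormStaticQuantFP8Model": 4,  # chained sites: x, x2, x3, x4
    "TestAllReduceFusedAddRMSNormStaticQuantFP4Model": 1,  # allreduce + rmsnorm + fp4 quant
}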
