Commit 8e4a56f

RMS norm works fully now; had to remove more conversions (and add them back in the replacements). TODO: add a pass to remove unnecessary conversions?
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
1 parent cdad3c0 commit 8e4a56f

File tree

4 files changed: +61 −32 lines changed


csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu

Lines changed: 4 additions & 0 deletions
@@ -145,7 +145,11 @@ void rms_norm_dynamic_per_token_quant(
   if (scale_ub.has_value()) {
     TORCH_CHECK(out.dtype() == kFp8Type);
   }
+  TORCH_CHECK(weight.dtype() == input.dtype());
   TORCH_CHECK(scales.dtype() == torch::kFloat32);
+  if (residual) {
+    TORCH_CHECK(residual->scalar_type() == input.scalar_type());
+  }
 
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] {
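
The new checks tighten the dtype contract of the fused kernel: the weight must match the input dtype, and so must the residual when one is passed. A rough Python-side illustration of that contract (a hypothetical helper, not the actual op binding; kFp8Type is assumed to map to torch.float8_e4m3fn as on CUDA builds):

import torch
from typing import Optional

def check_rms_quant_dtypes(out: torch.Tensor, input: torch.Tensor,
                           weight: torch.Tensor, scales: torch.Tensor,
                           residual: Optional[torch.Tensor] = None,
                           scale_ub: Optional[torch.Tensor] = None) -> None:
    # Mirrors the TORCH_CHECKs in rms_norm_dynamic_per_token_quant.
    if scale_ub is not None:
        assert out.dtype == torch.float8_e4m3fn   # kFp8Type (CUDA)
    assert weight.dtype == input.dtype            # added in this commit
    assert scales.dtype == torch.float32
    if residual is not None:
        assert residual.dtype == input.dtype      # added in this commit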

tests/compile/test_fusion.py

Lines changed: 12 additions & 9 deletions
@@ -9,8 +9,8 @@
                                      FusedRMSQuantKey, RMSNormQuantFusionPass)
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
-from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
-                         VllmConfig)
+from vllm.config import (CompilationConfig, CompilationLevel, ModelConfig,
+                         PassConfig, VllmConfig)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape, QuantKey, ScaleDesc)
@@ -119,13 +119,16 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
         custom_ops.append("+rms_norm")
     if enable_quant_fp8:
         custom_ops.append("+quant_fp8")
-    vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        debug_dump_path=f"/home/luka/git/vllm/._workspace/"
-        f"debug_dump_{enable_rms_norm}_{enable_quant_fp8}",
-        level=CompilationLevel.PIECEWISE,
-        custom_ops=custom_ops,
-        pass_config=PassConfig(enable_fusion=True, enable_noop=True),
-    ))
+    vllm_config = VllmConfig(
+        model_config=ModelConfig(dtype=dtype),
+        compilation_config=CompilationConfig(
+            debug_dump_path=f"/home/luka/git/vllm/._workspace/"
+            f"debug_dump_{enable_rms_norm}_{enable_quant_fp8}",
+            level=CompilationLevel.PIECEWISE,
+            custom_ops=custom_ops,
+            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+        ),
+    )
     with vllm.config.set_current_vllm_config(vllm_config):
         # Reshape pass is needed for the fusion pass to work
         noop_pass = NoOpEliminationPass(vllm_config)
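
The test now builds the VllmConfig with a ModelConfig so that passes and matchers constructed under set_current_vllm_config can resolve the model dtype from the global config (see matcher_utils.py below). A minimal sketch of that interaction, using bfloat16 as an assumed dtype:

import torch
from vllm.config import (ModelConfig, VllmConfig, get_current_vllm_config,
                         set_current_vllm_config)

vllm_config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16))
with set_current_vllm_config(vllm_config):
    # Anything constructed inside the context (e.g. MatcherRMSNorm) can read:
    model_dtype = get_current_vllm_config().model_config.dtype
    assert model_dtype == torch.bfloat16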

vllm/compilation/fusion.py

Lines changed: 37 additions & 16 deletions
@@ -9,7 +9,7 @@
 from torch._inductor.pattern_matcher import PatternMatcherPass
 from torch._ops import OpOverload
 
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym,
@@ -117,6 +117,10 @@ def pattern(input: torch.Tensor, weight: torch.Tensor,
 
         def replacement(input: torch.Tensor, weight: torch.Tensor,
                         scale: torch.Tensor):
+            # In case we're matching native rms-norm, conversions might be
+            # optimized out. We convert here just to be safe.
+            input = input.to(dtype=torch.float16)  # TODO model dtype
+
             result = torch.empty_like(input, dtype=self.quant_dtype)
             at = auto_functionalized(self.FUSED_OP,
                                      result=result,
@@ -130,7 +134,7 @@ def replacement(input: torch.Tensor, weight: torch.Tensor,
 
         inputs = [
            empty_bf16(5, 4),  # input
-            empty_bf16(4,),  # weight
+            empty_bf16(4, ),  # weight
            empty_fp32(1, 1)  # scale
        ]
        pattern(*inputs)
@@ -163,6 +167,11 @@ def pattern(input: torch.Tensor, residual: torch.Tensor,
 
         def replacement(input: torch.Tensor, residual: torch.Tensor,
                         weight: torch.Tensor, scale: torch.Tensor):
+            # In case we're matching native rms-norm, conversions might be
+            # optimized out. We convert here just to be safe.
+            input = input.to(dtype=torch.float16)  # TODO model dtype
+            residual = residual.to(dtype=torch.float16)
+
             result = torch.empty_like(input, dtype=self.quant_dtype)
             at = auto_functionalized(self.FUSED_OP,
                                      result=result,
@@ -176,9 +185,11 @@ def replacement(input: torch.Tensor, residual: torch.Tensor,
             return at[1], at[2]
 
         inputs = [
+            # TODO: maybe 32bit for torch impl?
+            # TODO dtype doesn't seem to matter?
             empty_bf16(5, 4),  # input
             empty_bf16(5, 4),  # residual
-            empty_bf16(4,),  # weight
+            empty_bf16(4, ),  # weight
             empty_fp32(1, 1)  # scale
         ]
 
@@ -213,6 +224,10 @@ def pattern(input: torch.Tensor, weight: torch.Tensor):
             return self.quant_matcher(result_rms)
 
         def replacement(input: torch.Tensor, weight: torch.Tensor):
+            # In case we're matching native rms-norm, conversions might be
+            # optimized out. We convert here just to be safe.
+            input = input.to(dtype=torch.float16)  # TODO model dtype
+
             result = torch.empty_like(input, dtype=self.quant_dtype)
             scale = self.quant_matcher.make_scale(input)
             at = auto_functionalized(self.FUSED_OP,
@@ -267,6 +282,11 @@ def pattern(input: torch.Tensor, residual: torch.Tensor,
 
         def replacement(input: torch.Tensor, residual: torch.Tensor,
                         weight: torch.Tensor):
+            # In case we're matching native rms-norm, conversions might be
+            # optimized out. We convert here just to be safe.
+            input = input.to(dtype=torch.float16)  # TODO model dtype
+            residual = residual.to(dtype=torch.float16)
+
             result = torch.empty_like(input, dtype=self.quant_dtype)
             scale = self.quant_matcher.make_scale(input)
             at = auto_functionalized(self.FUSED_OP,
@@ -309,22 +329,23 @@ def __init__(self, config: VllmConfig):
         self.patterns: PatternMatcherPass = PatternMatcherPass(
             pass_name="rmsnorm_quant_fusion_pass")
 
-        for epsilon in [1e-5, 1e-6]:
-            # Fuse rms_norm + static fp8 quant
-            RMSNormStaticQuantPattern(epsilon,
-                                      FP8_DTYPE).register(self.patterns)
+        with set_current_vllm_config(config, check_compile=False):
+            for epsilon in [1e-5, 1e-6]:
+                # Fuse rms_norm + static fp8 quant
+                RMSNormStaticQuantPattern(epsilon,
+                                          FP8_DTYPE).register(self.patterns)
 
-            # Fuse fused_add_rms_norm + static fp8 quant
-            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
-                self.patterns)
+                # Fuse fused_add_rms_norm + static fp8 quant
+                FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
+                    self.patterns)
 
-            # Fuse rms_norm + dynamic per-token fp8 quant
-            RMSNormDynamicQuantPattern(epsilon,
-                                       FP8_DTYPE).register(self.patterns)
+                # Fuse rms_norm + dynamic per-token fp8 quant
+                RMSNormDynamicQuantPattern(epsilon,
+                                           FP8_DTYPE).register(self.patterns)
 
-            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
-            FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
-                self.patterns)
+                # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
+                FusedAddRMSNormDynamicQuantPattern(
+                    epsilon, FP8_DTYPE).register(self.patterns)
 
         self.dump_patterns(config, self.patterns)
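
Each replacement now casts its inputs explicitly (hard-coded to torch.float16 for now, see the TODO about using the model dtype), because when the native rms-norm is matched, the conversions that used to appear in the pattern may already have been optimized out of the graph. A small illustration of why the extra cast is harmless in eager mode (my own example, not part of the commit): Tensor.to returns the tensor unchanged when it already has the target dtype, so the cast only does work when the dtypes actually differ. In a traced graph it still leaves a convert node behind, which is what the commit message's TODO about a conversion-removal pass refers to.

import torch

x = torch.randn(5, 4, dtype=torch.float16)
assert x.to(dtype=torch.float16) is x                     # already fp16: no copy, same tensor
y = torch.randn(5, 4, dtype=torch.float32)
assert y.to(dtype=torch.float16).dtype == torch.float16   # real conversion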

vllm/compilation/matcher_utils.py

Lines changed: 8 additions & 7 deletions
@@ -6,6 +6,7 @@
 from torch._higher_order_ops import auto_functionalized
 from torch._ops import OpOverload
 
+from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey, _normalize_quant_group_shape, kFp8DynamicTensorSym,
@@ -29,16 +30,18 @@
 # kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default  # noqa: E501
 
 
-class MatcherRMSNorm: # TODO separate residual and not residual
+class MatcherRMSNorm:  # TODO separate residual and not residual
 
     def __init__(self, epsilon: float, enabled: Optional[bool] = None):
         self.epsilon = epsilon
 
         if enabled is None:
-            # TODO either pass config to enabled or set it globally (global during pass init seems reasonable)
+            # TODO either pass config to enabled or set it globally
+            # (global during pass init seems reasonable)
             enabled = RMSNorm.enabled()
 
         self.forward = self.forward_custom if enabled else self.forward_native
+        self.model_dtype = get_current_vllm_config().model_config.dtype
 
     def forward_custom(
         self,
@@ -72,22 +75,20 @@ def forward_native(
         weight: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        orig_dtype = input.dtype
-        x = input.to(torch.float32)
+        x = input  # .to(torch.float32)
         if residual is not None:
             x = x + residual.to(torch.float32)
-            residual = x
+            residual = x  # conversion to 16-bit is eliminated in full graph
 
         variance = x.pow(2).mean(dim=-1, keepdim=True)
 
         x = x * torch.rsqrt(variance + self.epsilon)
-        x = x.to(orig_dtype)
+        x = x.to(self.model_dtype)
         if weight is not None:
             x = x * weight
 
         return x if residual is None else (x, residual)
 
-
     def __call__(
         self,
         input: torch.Tensor,
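
With the up-front float32 cast removed and the output cast now targeting the model dtype, the arithmetic that forward_native traces looks roughly like the standalone sketch below (a rewrite for illustration only; epsilon and model_dtype are taken as plain parameters rather than instance attributes):

import torch
from typing import Optional, Union

def rms_norm_native(
    input: torch.Tensor,
    weight: Optional[torch.Tensor],
    residual: Optional[torch.Tensor] = None,
    epsilon: float = 1e-6,
    model_dtype: torch.dtype = torch.bfloat16,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    # No up-front cast of `input`; only the residual add is forced to float32.
    x = input
    if residual is not None:
        x = x + residual.to(torch.float32)
        residual = x  # stays in float32; the 16-bit conversion is elided in the full graph
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + epsilon)
    x = x.to(model_dtype)  # cast to the model dtype rather than the input dtype
    if weight is not None:
        x = x * weight
    return x if residual is None else (x, residual)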

0 commit comments
