@@ -9,7 +9,7 @@
 from torch._inductor.pattern_matcher import PatternMatcherPass
 from torch._ops import OpOverload
 
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym,
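Note on the new import: `set_current_vllm_config` is the context manager used in the last hunk below so the patterns are registered with this pass's `VllmConfig` active. As a rough sketch only (a simplified assumption, not vLLM's actual implementation in `vllm/config.py`), such a context manager can be built on a `ContextVar`:

```python
# Illustrative sketch only -- not vLLM's real implementation.
from contextlib import contextmanager
from contextvars import ContextVar

_current_vllm_config: ContextVar = ContextVar("vllm_config", default=None)

@contextmanager
def set_current_vllm_config(config, check_compile: bool = False):
    # Expose `config` to code running inside the with-block (e.g. custom-op
    # dispatch during tracing), restoring the previous value afterwards.
    # `check_compile` mirrors the keyword used in this diff; unused here.
    token = _current_vllm_config.set(config)
    try:
        yield
    finally:
        _current_vllm_config.reset(token)

def get_current_vllm_config():
    # Hypothetical accessor paired with the setter above.
    return _current_vllm_config.get()
```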
@@ -117,6 +117,10 @@ def pattern(input: torch.Tensor, weight: torch.Tensor,
 
     def replacement(input: torch.Tensor, weight: torch.Tensor,
                     scale: torch.Tensor):
+        # In case we're matching native rms-norm, conversions might be
+        # optimized out. We convert here just to be safe.
+        input = input.to(dtype=torch.float16)  # TODO model dtype
+
         result = torch.empty_like(input, dtype=self.quant_dtype)
         at = auto_functionalized(self.FUSED_OP,
                                  result=result,
@@ -130,7 +134,7 @@ def replacement(input: torch.Tensor, weight: torch.Tensor,
 
     inputs = [
         empty_bf16(5, 4),  # input
-        empty_bf16(4,),  # weight
+        empty_bf16(4, ),  # weight
         empty_fp32(1, 1)  # scale
     ]
     pattern(*inputs)
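For context on the `pattern`/`replacement`/`inputs` triple above: this is the shape Inductor's pattern matcher expects, where both functions are traced with the example inputs and the rewrite is installed into the pass. A minimal self-contained sketch of that registration flow, using generic ops rather than vLLM's fused kernels, might look like:

```python
import torch
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                             register_replacement)

pm_pass = PatternMatcherPass(pass_name="example_fusion_pass")

def pattern(x: torch.Tensor, w: torch.Tensor):
    # Subgraph to search for: elementwise mul followed by add.
    return x * w + x

def replacement(x: torch.Tensor, w: torch.Tensor):
    # Fused form the matched subgraph is rewritten to.
    return torch.addcmul(x, x, w)

# Example inputs fix shapes/dtypes for tracing; their values are irrelevant.
inputs = [torch.empty(5, 4), torch.empty(4)]
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
```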
@@ -163,6 +167,11 @@ def pattern(input: torch.Tensor, residual: torch.Tensor,
 
     def replacement(input: torch.Tensor, residual: torch.Tensor,
                     weight: torch.Tensor, scale: torch.Tensor):
+        # In case we're matching native rms-norm, conversions might be
+        # optimized out. We convert here just to be safe.
+        input = input.to(dtype=torch.float16)  # TODO model dtype
+        residual = residual.to(dtype=torch.float16)
+
         result = torch.empty_like(input, dtype=self.quant_dtype)
         at = auto_functionalized(self.FUSED_OP,
                                  result=result,
@@ -176,9 +185,11 @@ def replacement(input: torch.Tensor, residual: torch.Tensor,
         return at[1], at[2]
 
     inputs = [
+        # TODO: maybe 32bit for torch impl?
+        # TODO dtype doesn't seem to matter?
         empty_bf16(5, 4),  # input
         empty_bf16(5, 4),  # residual
-        empty_bf16(4,),  # weight
+        empty_bf16(4, ),  # weight
         empty_fp32(1, 1)  # scale
     ]
 
@@ -213,6 +224,10 @@ def pattern(input: torch.Tensor, weight: torch.Tensor):
         return self.quant_matcher(result_rms)
 
     def replacement(input: torch.Tensor, weight: torch.Tensor):
+        # In case we're matching native rms-norm, conversions might be
+        # optimized out. We convert here just to be safe.
+        input = input.to(dtype=torch.float16)  # TODO model dtype
+
         result = torch.empty_like(input, dtype=self.quant_dtype)
         scale = self.quant_matcher.make_scale(input)
         at = auto_functionalized(self.FUSED_OP,
@@ -267,6 +282,11 @@ def pattern(input: torch.Tensor, residual: torch.Tensor,
 
     def replacement(input: torch.Tensor, residual: torch.Tensor,
                     weight: torch.Tensor):
+        # In case we're matching native rms-norm, conversions might be
+        # optimized out. We convert here just to be safe.
+        input = input.to(dtype=torch.float16)  # TODO model dtype
+        residual = residual.to(dtype=torch.float16)
+
         result = torch.empty_like(input, dtype=self.quant_dtype)
         scale = self.quant_matcher.make_scale(input)
         at = auto_functionalized(self.FUSED_OP,
@@ -309,22 +329,23 @@ def __init__(self, config: VllmConfig):
         self.patterns: PatternMatcherPass = PatternMatcherPass(
             pass_name="rmsnorm_quant_fusion_pass")
 
-        for epsilon in [1e-5, 1e-6]:
-            # Fuse rms_norm + static fp8 quant
-            RMSNormStaticQuantPattern(epsilon,
-                                      FP8_DTYPE).register(self.patterns)
+        with set_current_vllm_config(config, check_compile=False):
+            for epsilon in [1e-5, 1e-6]:
+                # Fuse rms_norm + static fp8 quant
+                RMSNormStaticQuantPattern(epsilon,
+                                          FP8_DTYPE).register(self.patterns)
 
-            # Fuse fused_add_rms_norm + static fp8 quant
-            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
-                self.patterns)
+                # Fuse fused_add_rms_norm + static fp8 quant
+                FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
+                    self.patterns)
 
-            # Fuse rms_norm + dynamic per-token fp8 quant
-            RMSNormDynamicQuantPattern(epsilon,
-                                       FP8_DTYPE).register(self.patterns)
+                # Fuse rms_norm + dynamic per-token fp8 quant
+                RMSNormDynamicQuantPattern(epsilon,
+                                           FP8_DTYPE).register(self.patterns)
 
-            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
-            FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
-                self.patterns)
+                # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
+                FusedAddRMSNormDynamicQuantPattern(
+                    epsilon, FP8_DTYPE).register(self.patterns)
 
         self.dump_patterns(config, self.patterns)
 
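On the repeated "conversions might be optimized out" comment: it leans on a standard PyTorch behavior. `Tensor.to` returns the input tensor itself when the target dtype (and device) already match, so the traced pattern for a native rms-norm implementation may contain no conversion node at all, while the fused custom op's replacement graph should keep the dtype explicit. A quick demonstration of the no-op case:

```python
import torch

x = torch.randn(5, 4, dtype=torch.float16)

# Same dtype and device: .to() returns the very same tensor, so a
# compiler may legitimately eliminate the conversion from the graph.
assert x.to(dtype=torch.float16) is x

# Different dtype: a real conversion (new tensor) is produced.
y = x.to(dtype=torch.bfloat16)
assert y is not x and y.dtype == torch.bfloat16
```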