@@ -719,6 +719,8 @@ def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
        self.quant_dtype = torch.float8_e4m3fn
        self.quant_fp8 = QuantFP8(static=True,
                                  group_shape=GroupShape.PER_TENSOR)
+        # TODO HACK
+        self.quant_fp8._forward_method = self.quant_fp8.forward_native

    def register(self, pm_pass: PatternMatcherPass):

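Note on the "# TODO HACK" lines above (the same override is applied to the second
pattern class further down): they pin QuantFP8 to its pure-PyTorch forward_native
path, presumably so the traced pattern consists of the same plain aten ops that
appear in the target graph when the native implementation is used, rather than an
opaque custom kernel call. A minimal, self-contained sketch of the idea, assuming
QuantFP8 follows vLLM's CustomOp convention of dispatching through a
_forward_method attribute chosen at construction time; the toy class, its op body,
and the clamp bound are illustrative, not the real implementation:

import torch

class ToyQuant:
    def __init__(self):
        # A CustomOp-style class normally dispatches to an opaque fused
        # kernel on CUDA builds.
        self._forward_method = self.forward_cuda

    def forward_native(self, x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch path: traces into plain aten ops, so a pattern built
        # from it can match a graph op-for-op. 448 is the max finite value
        # of float8_e4m3fn.
        return (x / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)

    def forward_cuda(self, x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("opaque custom kernel in the real op")

    def __call__(self, x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        return self._forward_method(x, scale)

q = ToyQuant()
q._forward_method = q.forward_native  # same pinning as the diff's TODO HACK
print(q(torch.randn(1, 8, 4), torch.tensor(1.0)).dtype)  # torch.float8_e4m3fn
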
@@ -729,9 +731,9 @@ def get_inputs():
            rmsnorm_result = torch.empty([1, 8, 4],
                                         device=self.device,
                                         dtype=self.dtype)
-            quant_result = torch.empty([1, 8, 4],
-                                       device=self.device,
-                                       dtype=self.quant_dtype)
+            # quant_result = torch.empty([1, 8, 4],
+            #                            device=self.device,
+            #                            dtype=self.quant_dtype)
            weight = torch.empty([4], device=self.device, dtype=self.dtype)
            scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
            return [
@@ -807,6 +809,8 @@ def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
        self.quant_dtype = torch.float8_e4m3fn
        self.quant_fp8 = QuantFP8(static=True,
                                  group_shape=GroupShape.PER_TENSOR)
+        # TODO HACK
+        self.quant_fp8._forward_method = self.quant_fp8.forward_native

    def register(self, pm_pass: PatternMatcherPass):

@@ -817,9 +821,9 @@ def get_inputs():
                                   device=self.device,
                                   dtype=self.dtype)
            weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
-            quant_result = torch.empty([4, 4],
-                                       device=self.device,
-                                       dtype=self.quant_dtype)
+            # quant_result = torch.empty([4, 4],
+            #                            device=self.device,
+            #                            dtype=self.quant_dtype)
            scale = torch.empty([1, 1],
                                device=self.device,
                                dtype=torch.float32)
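In both get_inputs() hunks the quant_result buffer is commented out rather than
deleted, presumably because the forward_native path returns a freshly allocated
output instead of consuming a caller-provided buffer the way an out-variant op
does, so the preallocated tensor is no longer one of the pattern inputs. A hedged
contrast of the two calling styles (runnable; the out-variant here is emulated
with copy_, not vLLM's real op):

import torch

x = torch.randn(1, 8, 4)
scale = torch.tensor(1.0)

# Out-variant style: the caller preallocates the result, as quant_result did.
quant_result = torch.empty_like(x, dtype=torch.float8_e4m3fn)
quant_result.copy_((x / scale).to(torch.float8_e4m3fn))  # stand-in for an out= op

# Functional style, as forward_native uses: output allocated by the callee.
y = (x / scale).to(torch.float8_e4m3fn)
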
@@ -1166,6 +1170,9 @@ def __init__(self, config: VllmConfig):
            # and allow multiple values of epsilon.
            torch._inductor.pattern_matcher._seen_patterns.clear()

+        if path := config.compilation_config.debug_dump_path:
+            with open(f"{path}…", 'w') as f:
+                print(self.patterns.patterns, file=f)
        self.disabled = False

    def __call__(self, graph: fx.Graph):
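The last hunk dumps the registered pattern set whenever
compilation_config.debug_dump_path is configured; the filename appended to the
path is elided in this hunk, so the sketch below labels its own choice as a
guess. A standalone version of the same walrus-guarded dump idiom:

import os
from typing import Optional

def dump_patterns(patterns: list, debug_dump_path: Optional[str]) -> None:
    # Only write when a dump location is configured, mirroring the diff.
    if path := debug_dump_path:
        # The real filename is elided in the diff; "patterns.txt" is a guess.
        with open(os.path.join(path, "patterns.txt"), "w") as f:
            print(patterns, file=f)

Dumping the registered patterns this way is handy when an expected fusion
silently fails to match: the file shows exactly what the matcher registered.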