2 files changed (+21 −4 lines)

vllm/model_executor/layers/quantization/fp8.py:

```diff
@@ -369,12 +369,9 @@ def apply(self,
                 size_k=layer.input_size_per_partition,
                 bias=bias)
 
-        # Note: lazy import to avoid triton import error.
-        from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-            apply_w8a8_block_fp8_linear)
         if self.block_quant:
             assert self.quant_config.weight_block_size is not None
-            return apply_w8a8_block_fp8_linear(
+            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                 input=x,
                 weight=layer.weight,
                 block_size=self.quant_config.weight_block_size,
```
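The call site now routes through `torch.ops.vllm.apply_w8a8_block_fp8_linear`, i.e. through the PyTorch dispatcher, so the lazy in-function import (previously needed to avoid a Triton import error at module load) can be dropped: the op only has to be registered before the first call, which `fp8_utils.py` (shown next) does when it is imported. A minimal self-contained sketch of that dispatch pattern, assuming nothing from vLLM (the `demo` namespace and toy op are illustrative):

```python
import torch

# Registration normally lives in the module that owns the kernel
# (fp8_utils.py in this PR); here a toy op is defined inline.
torch.library.define("demo::scale", "(Tensor x, float s) -> Tensor")

@torch.library.impl("demo::scale", "cpu")
def _scale(x: torch.Tensor, s: float) -> torch.Tensor:
    return x * s

# The caller refers to the op purely by name; resolution happens in the
# dispatcher at call time, not when the calling module is imported.
print(torch.ops.demo.scale(torch.ones(3), 2.0))  # tensor([2., 2., 2.])
```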
vllm/model_executor/layers/quantization/utils/fp8_utils.py:

```diff
@@ -17,6 +17,7 @@
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 logger = init_logger(__name__)
 
@@ -81,6 +82,25 @@ def apply_w8a8_block_fp8_linear(
     return output.to(dtype=input.dtype).view(*output_shape)
 
 
+def apply_w8a8_block_fp8_linear_fake(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+    return torch.empty(output_shape, dtype=input.dtype, device=input.device)
+
+
+direct_register_custom_op(
+    op_name="apply_w8a8_block_fp8_linear",
+    op_func=apply_w8a8_block_fp8_linear,
+    mutates_args=[],
+    fake_impl=apply_w8a8_block_fp8_linear_fake,
+)
+
+
 # Unify the interface between `apply_w8a8_block_fp8_linear` and
 # `apply_fp8_linear`
 # NOTE(lucas): this is quite messy, we should think through this more formally
```
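The fake implementation is what lets `torch.compile`/FakeTensor tracing propagate output shapes and dtypes through the op without executing the real Triton kernel, and `mutates_args=[]` declares the op functional. Below is a minimal sketch of the same register-real-plus-fake pattern using PyTorch's public custom-op API (`torch.library.custom_op`, torch >= 2.4) rather than vLLM's `direct_register_custom_op` helper; the `demo` namespace and the stand-in matmul body are illustrative, not vLLM code:

```python
from typing import List, Optional

import torch


@torch.library.custom_op("demo::block_fp8_linear", mutates_args=())
def block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Real impl: a plain matmul stands in for the block-quantized fp8 GEMM.
    return input @ weight.t()


@block_fp8_linear.register_fake
def _(input, weight, block_size, weight_scale, input_scale=None):
    # Fake (meta) impl: shape/dtype propagation only, mirroring
    # apply_w8a8_block_fp8_linear_fake above; no computation.
    output_shape = [*input.shape[:-1], weight.shape[0]]
    return torch.empty(output_shape, dtype=input.dtype, device=input.device)


x, w = torch.randn(2, 8), torch.randn(4, 8)
out = torch.ops.demo.block_fp8_linear(x, w, [128, 128], torch.ones(1))
assert out.shape == (2, 4)
```

With the fake impl registered, a compiler can capture the call as a single graph node and infer its output shape without running the kernel, which is presumably the motivation for moving the call behind `torch.ops.vllm` here.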