@@ -2,13 +2,12 @@
from typing import Callable, List, Optional

import torch
-import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
-    dequantize_to_dtype)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+    apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)
@@ -31,6 +30,10 @@ def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition

        # Weight
        weight = ModelWeightParameter(data=torch.empty(
@@ -60,48 +63,30 @@ def create_weights(self, layer: torch.nn.Module,

        layer.register_parameter("weight_scale", weight_scale)

-    def swizzle_blockscale(self, scale: torch.tensor):
-        assert (scale.dtype == torch.float8_e4m3fn)
-        # Pad and blockwise interleave weight_scale
-        scale_ndim = scale.ndim
-        if scale.ndim == 2:
-            scale = scale.unsqueeze(0)
-        assert scale.ndim == 3
-        B, M, K = scale.shape
-        round_up_multiple = lambda x, m: (x + m - 1) // m * m
-        M_padded = round_up_multiple(M, 128)
-        K_padded = round_up_multiple(K, 4)
-        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
-        padded_scale[:B, :M, :K] = scale
-        batches, rows, cols = padded_scale.shape
-        assert rows % 128 == 0
-        assert cols % 4 == 0
-        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
-                                            cols // 4, 4)
-        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
-        swizzled_scale = swizzled_scale.contiguous().cuda()
-        return (swizzled_scale.reshape(M, K)
-                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
-
    def process_weights_after_loading(self, layer) -> None:
-        layer.weight_global_scale = Parameter(
-            layer.weight_global_scale.max().to(torch.float32),
+        # Process parameters for marlin repacking
+
+        # Rename weight_packed to weight, as expected by marlin
+        layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
+        del layer.weight_packed
+        # Rename weight_global_scale to weight_scale_2, as expected by marlin
+        # Note: compressed-tensors stores the inverse of what marlin expects
+        layer.weight_scale_2 = Parameter(
+            1 / layer.weight_global_scale.max().to(torch.float32),
            requires_grad=False)
-        # Note: a post weight loading step but not required for the emulation
-        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
-        layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
-                                                requires_grad=False)
+        del layer.weight_global_scale
+
+        prepare_fp4_layer_for_marlin(layer)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-
-        w_fp4 = layer.weight_packed.data
-        w_global_scale = layer.weight_global_scale
-        w_blockscale = layer.weight_scale_swizzled.data
-        w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
-                                   x.dtype, x.device, self.group_size)
-        out = F.linear(x, w_dq)
-        del w_dq, w_fp4, w_global_scale, w_blockscale
-        return out
+        return apply_fp4_marlin_linear(input=x,
+                                       weight=layer.weight,
+                                       weight_scale=layer.weight_scale,
+                                       weight_scale_2=layer.weight_scale_2,
+                                       workspace=layer.workspace,
+                                       size_n=layer.output_size_per_partition,
+                                       size_k=layer.input_size_per_partition,
+                                       bias=bias)
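
For reviewers who want to sanity-check the scale handling without a GPU, the renames and global-scale inversion in `process_weights_after_loading` can be reproduced on a dummy layer. A minimal sketch with made-up shapes and scale values; `prepare_fp4_layer_for_marlin` itself is omitted since it needs the marlin utilities and a CUDA device:

```python
import torch
from types import SimpleNamespace
from torch.nn.parameter import Parameter

# Dummy stand-in for a loaded layer; the packed-weight shape and the
# per-partition global scales below are made up for illustration.
layer = SimpleNamespace(
    weight_packed=Parameter(torch.zeros(128, 64, dtype=torch.uint8),
                            requires_grad=False),
    weight_global_scale=Parameter(torch.tensor([0.02, 0.025, 0.02]),
                                  requires_grad=False),
)

# Mirror the renames above: marlin looks for `weight`, not `weight_packed`.
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
del layer.weight_packed

# Collapse the per-partition global scales to one fp32 scalar and invert,
# since compressed-tensors stores the reciprocal of what marlin consumes.
layer.weight_scale_2 = Parameter(
    1 / layer.weight_global_scale.max().to(torch.float32),
    requires_grad=False)
del layer.weight_global_scale

print(layer.weight_scale_2.item())  # 40.0 (= 1 / 0.025)
# prepare_fp4_layer_for_marlin(layer) would then repack the weight in-place.
```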