 from typing import Callable, List, Optional
 
 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
-    dequantize_to_dtype)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+    apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -31,6 +30,10 @@ def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
 
         # Weight
         weight = ModelWeightParameter(data=torch.empty(
@@ -60,48 +63,30 @@ def create_weights(self, layer: torch.nn.Module,
 
         layer.register_parameter("weight_scale", weight_scale)
 
-    def swizzle_blockscale(self, scale: torch.tensor):
-        assert (scale.dtype == torch.float8_e4m3fn)
-        # Pad and blockwise interleave weight_scale
-        scale_ndim = scale.ndim
-        if scale.ndim == 2:
-            scale = scale.unsqueeze(0)
-        assert scale.ndim == 3
-        B, M, K = scale.shape
-        round_up_multiple = lambda x, m: (x + m - 1) // m * m
-        M_padded = round_up_multiple(M, 128)
-        K_padded = round_up_multiple(K, 4)
-        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
-        padded_scale[:B, :M, :K] = scale
-        batches, rows, cols = padded_scale.shape
-        assert rows % 128 == 0
-        assert cols % 4 == 0
-        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
-                                            cols // 4, 4)
-        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
-        swizzled_scale = swizzled_scale.contiguous().cuda()
-        return (swizzled_scale.reshape(M, K)
-                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
-
     def process_weights_after_loading(self, layer) -> None:
-        layer.weight_global_scale = Parameter(
-            layer.weight_global_scale.max().to(torch.float32),
+        # Process parameters for marlin repacking
+
+        # Rename weight_packed to weight, which is what marlin expects
+        layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
+        del layer.weight_packed
+        # Rename weight_global_scale to weight_scale_2, which marlin expects
+        # Note: ct stores the inverse of what the marlin kernel expects
+        layer.weight_scale_2 = Parameter(
+            1 / layer.weight_global_scale.max().to(torch.float32),
             requires_grad=False)
-        # Note: a post weight loading step but not required for the emulation
-        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
-        layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
-                                                requires_grad=False)
+        del layer.weight_global_scale
+
+        prepare_fp4_layer_for_marlin(layer)
 
     def apply_weights(self,
                       layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-
-        w_fp4 = layer.weight_packed.data
-        w_global_scale = layer.weight_global_scale
-        w_blockscale = layer.weight_scale_swizzled.data
-        w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
-                                   x.dtype, x.device, self.group_size)
-        out = F.linear(x, w_dq)
-        del w_dq, w_fp4, w_global_scale, w_blockscale
-        return out
+        return apply_fp4_marlin_linear(input=x,
+                                       weight=layer.weight,
+                                       weight_scale=layer.weight_scale,
+                                       weight_scale_2=layer.weight_scale_2,
+                                       workspace=layer.workspace,
+                                       size_n=layer.output_size_per_partition,
+                                       size_k=layer.input_size_per_partition,
+                                       bias=bias)
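Net effect of the diff: apply_weights no longer dequantizes the FP4 weights and runs F.linear in the activation dtype; the packed weights, FP8 block scales, and global scale are handed straight to the Marlin FP4 kernel. Below is a minimal sketch (not part of the diff; the tensor values are made up for illustration) of the scale bookkeeping behind the weight_scale_2 rename: compressed-tensors stores weight_global_scale, Marlin consumes its reciprocal, and dequantization composes as w ≈ fp4_decode(weight_packed) * weight_scale * weight_scale_2.

import torch

# Hypothetical global scales loaded for three fused partitions; after loading
# they are expected to agree, so .max() collapses them to a single scalar.
weight_global_scale = torch.tensor([448.0, 448.0, 448.0])

# Marlin consumes the reciprocal of what compressed-tensors stores.
weight_scale_2 = 1 / weight_global_scale.max().to(torch.float32)

# One FP8-E4M3 block scale per 16-element group (NVFP4 group_size = 16),
# shown here as a single scalar for one group.
block_scale = torch.tensor(2.0)

# A decoded FP4 weight value; the E2M1 grid is +/-{0, 0.5, 1, 1.5, 2, 3, 4, 6}.
fp4_value = torch.tensor(1.5)

# Effective weight the kernel reconstructs: fp4 * block_scale * weight_scale_2
w = fp4_value * block_scale * weight_scale_2
print(w.item())  # 1.5 * 2.0 / 448.0 ≈ 0.006696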