 # limitations under the License.
 
 from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
-
+from ..activations import ACT2FN
 
 if is_torch_available():
     import torch
 logger = logging.get_logger(__name__)
 
 
-class FbgemmFp8Linear(torch.nn.Module):
+class FbgemmFp8Linear(torch.nn.Linear):
     def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32):
-        super().__init__()
+        super().__init__(in_features, out_features, bias)
         self.in_features = in_features
         self.out_features = out_features
 
-        self.register_buffer("weight", torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
-        self.register_buffer("weight_scale", torch.zeros((out_features, 1), dtype=weight_dtype))
+        self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
+        self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=weight_dtype))
         self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
 
         if bias:
-            self.register_buffer("bias", torch.zeros((self.out_features), dtype=weight_dtype))
+            self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=weight_dtype))
         else:
             self.bias = None
 
@@ -50,15 +50,16 @@ def forward(self, x):
         # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
         # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
         x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
-            x.view(-1, x.shape[-1]), num_tokens, self.input_scale_ub
+            x.view(-1, x.shape[-1]), scale_ub=self.input_scale_ub
         )
         # moving x_quantized, x_scale here creates gibberish output ... However, if we move the output, it works
         # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
 
         # The computation still happens on the device where self.weight is, even if x_quantized is not on the same device as self.weight
+        weight_scale_float32 = self.weight_scale.to(torch.float32)
         output = torch.ops.fbgemm.f8f8bf16_rowwise(
-            x_quantized, self.weight, x_scale, self.weight_scale, use_fast_accum=True
-        )
+            x_quantized, self.weight, x_scale, weight_scale_float32, use_fast_accum=True
+        )
         output = output + self.bias if self.bias is not None else output
         # Hacky for now, we move the output to the device of x
         output = output.to(x.device)
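
As a side note, here is a minimal sketch of how the rewritten layer might be exercised, assuming a CUDA device with fbgemm-gpu installed and that the class is importable from transformers.integrations.fbgemm_fp8; the sizes and scale bound below are illustrative only:

```python
import torch

from transformers.integrations.fbgemm_fp8 import FbgemmFp8Linear

# Hypothetical sizes; the FP8 weights keep their zero init, so this only exercises shapes/dtypes.
layer = FbgemmFp8Linear(in_features=128, out_features=64, bias=False, weight_dtype=torch.bfloat16).to("cuda")
layer.input_scale_ub = torch.tensor([1200.0], dtype=torch.float, device="cuda")  # illustrative bound

x = torch.randn(16, 128, dtype=torch.bfloat16, device="cuda")
y = layer(x)  # quantize_fp8_per_row -> f8f8bf16_rowwise, result moved back to x's device
print(y.shape, y.dtype)  # expected: torch.Size([16, 64]) torch.bfloat16
```
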
@@ -67,19 +68,104 @@ def forward(self, x):
         return output
 
 
+class FbgemmFp8Llama4TextExperts(nn.Module):
+    def __init__(self, config, dtype=torch.float32):
+        super().__init__()
+        self.num_experts = config.num_local_experts
+        self.intermediate_size = config.intermediate_size
+        self.hidden_size = config.hidden_size
+        self.expert_dim = self.intermediate_size
+        self.act_fn = ACT2FN[config.hidden_act]
+        # FP8 parameters for gate_up_proj
+        self.gate_up_proj = torch.nn.Parameter(torch.zeros((self.num_experts, self.hidden_size, 2 * self.expert_dim), dtype=torch.float8_e4m3fn))
+        self.gate_up_proj_scale = torch.nn.Parameter(torch.zeros((self.num_experts, 1, self.expert_dim * 2), dtype=torch.float32))
+        # FP8 parameters for down_proj
+        self.down_proj = torch.nn.Parameter(torch.zeros((self.num_experts, self.expert_dim, self.hidden_size), dtype=torch.float8_e4m3fn))
+        self.down_proj_scale = torch.nn.Parameter(torch.zeros((self.num_experts, self.hidden_size, 1), dtype=torch.float32))
+        # Register input scale upper bound
+        self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
+
+    def forward(self, hidden_states):
+        """
+        Args:
+            hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
+        Returns:
+            torch.Tensor: (batch_size * token_num, hidden_size)
+        """
+        # Reshape hidden states for expert computation
+        hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
+        num_tokens = None
+
+        # Pre-allocate a tensor for all expert outputs with the same shape as hidden_states
+        next_states = torch.empty_like(hidden_states)
+
+        for i in range(self.num_experts):
+            # Extract this expert's hidden states
+            expert_hidden = hidden_states[i]
+            expert_hidden_reshaped = expert_hidden.reshape(-1, self.hidden_size)
+            # Quantize the activations for this expert
+            expert_quantized, expert_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+                expert_hidden_reshaped, num_tokens, self.input_scale_ub
+            )
+            sharded_expert_dim = self.gate_up_proj.shape[-1] // 2
+            gate_up_proj_scale_float32 = self.gate_up_proj_scale.to(torch.float32)
+
+            gate = torch.ops.fbgemm.f8f8bf16_rowwise(
+                expert_quantized,
+                self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous(),
+                expert_scale,
+                gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous(),
+                use_fast_accum=True
+            )
+
+            up = torch.ops.fbgemm.f8f8bf16_rowwise(
+                expert_quantized,
+                self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous(),
+                expert_scale,
+                gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous(),
+                use_fast_accum=True
+            )
+
+            activated = up * self.act_fn(gate)
+
+            activated_quantized, activated_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+                activated, num_tokens, self.input_scale_ub
+            )
+
+            down_proj_scale_float32 = self.down_proj_scale.to(torch.float32)
+            expert_output = torch.ops.fbgemm.f8f8bf16_rowwise(
+                activated_quantized,
+                self.down_proj[i].transpose(0, 1).contiguous(),
+                activated_scale,
+                down_proj_scale_float32[i].view(-1, 1).contiguous(),
+                use_fast_accum=True
+            )
+
+            next_states[i] = expert_output
+        next_states = next_states.to(hidden_states.device)
+        return next_states.view(-1, self.hidden_size)
+
+
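
For orientation, a small fbgemm-free sketch of the shape bookkeeping the expert loop above relies on; the sizes stand in for config.num_local_experts, hidden_size and intermediate_size and are made up:

```python
import torch

# Hypothetical sizes standing in for the Llama4 text config values.
num_experts, hidden_size, expert_dim, tokens_per_expert = 4, 128, 256, 8

hidden_states = torch.randn(num_experts * tokens_per_expert, hidden_size)
hidden_states = hidden_states.view(num_experts, -1, hidden_size)      # (4, 8, 128): one slice per expert

gate_up_proj = torch.zeros(num_experts, hidden_size, 2 * expert_dim)  # fp8 Parameter in the real module
sharded_expert_dim = gate_up_proj.shape[-1] // 2                      # 256

# For expert i, the fused gate/up projection is split by slicing the transposed weight,
# mirroring the two f8f8bf16_rowwise calls in the forward pass above.
i = 0
w = gate_up_proj[i].transpose(0, 1)                                   # (512, 128)
gate_w, up_w = w[:sharded_expert_dim], w[sharded_expert_dim:]         # each (256, 128)
print(hidden_states[i].shape, gate_w.shape, up_w.shape)
```
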
 def _replace_with_fbgemm_fp8_linear(
     model,
     modules_to_not_convert=None,
     current_key_name=None,
     quantization_config=None,
     has_been_replaced=False,
     pre_quantized=False,
+    config=None,
+    tp_plan=None
 ):
     """
     Private method that wraps the recursion for module replacement.
 
     Returns the converted model and a boolean that indicates if the conversion has been successful or not.
     """
+
+    from transformers.models.llama4.modeling_llama4 import Llama4TextExperts
+    import re
+
     if current_key_name is None:
         current_key_name = []
 
@@ -105,9 +191,24 @@ def _replace_with_fbgemm_fp8_linear(
                     # Force requires grad to False to avoid unexpected errors
                     model._modules[name].requires_grad_(False)
                 # set non persistent buffer outside of init_empty_weights
+                model._modules[name].input_scale_ub = torch.tensor(
+                    [quantization_config.activation_scale_ub], dtype=torch.float,
+                )
+        if module.__class__.__name__ == "Llama4TextExperts" and name not in modules_to_not_convert:
+            current_key_name_str = ".".join(current_key_name)
+            if not any(
+                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
+            ):
+                with init_empty_weights(include_buffers=True):
+                    tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj_scale")] = tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj")]
+                    tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None
+                    model._modules[name] = FbgemmFp8Llama4TextExperts(
+                        config.text_config,
+                    )
                 model._modules[name].input_scale_ub = torch.tensor(
                     [quantization_config.activation_scale_ub], dtype=torch.float
                 )
+
         if len(list(module.children())) > 0:
             _, has_been_replaced = _replace_with_fbgemm_fp8_linear(
                 module,
@@ -116,14 +217,16 @@ def _replace_with_fbgemm_fp8_linear(
                 quantization_config,
                 has_been_replaced=has_been_replaced,
                 pre_quantized=pre_quantized,
+                config=config,
+                tp_plan=tp_plan
             )
         # Remove the last key for recursion
         current_key_name.pop(-1)
     return model, has_been_replaced
 
 
 def replace_with_fbgemm_fp8_linear(
-    model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
+    model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False, config=None, tp_plan=None
 ):
     """
     A helper function to replace all `torch.nn.Linear` modules by `FbgemmFp8Linear` modules.
@@ -151,9 +254,8 @@ def replace_with_fbgemm_fp8_linear(
         modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
     modules_to_not_convert = list(set(modules_to_not_convert))
     model, has_been_replaced = _replace_with_fbgemm_fp8_linear(
-        model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized
+        model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized, config=config, tp_plan=tp_plan
     )
-
     if not has_been_replaced:
         logger.warning(
             "You are loading your model using FP8 quantization but no linear modules were found in your model."
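
Finally, a sketch of how the updated helper might be driven by hand on a toy module; normally the FbgemmFp8 quantizer invokes it during from_pretrained. This assumes accelerate and fbgemm-gpu are installed, leaves config and tp_plan at their None defaults (no Llama4TextExperts in the toy model), and relies on FbgemmFp8Config's default activation_scale_ub:

```python
import torch

from transformers import FbgemmFp8Config
from transformers.integrations.fbgemm_fp8 import replace_with_fbgemm_fp8_linear

# Toy model with two plain Linear layers; their names ("0", "1") are not in modules_to_not_convert.
toy = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8))
qcfg = FbgemmFp8Config()  # its activation_scale_ub feeds each layer's input_scale_ub

toy = replace_with_fbgemm_fp8_linear(toy, quantization_config=qcfg)
print(toy)  # both children should now be FbgemmFp8Linear (weights stay on meta until real ones are loaded)
```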