
Commit aa8daba

Merge pull request huggingface#21 from huggingface/add_fbgemm
Adding fbgemm
2 parents 99f2297 + bdfb573

8 files changed: +221 additions, -30 deletions

src/transformers/integrations/__init__.py
Lines changed: 2 additions & 2 deletions

@@ -53,7 +53,7 @@
         "unset_hf_deepspeed_config",
     ],
     "eetq": ["replace_with_eetq_linear"],
-    "fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"],
+    "fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear", "FbgemmFp8Llama4TextExperts"],
     "finegrained_fp8": ["FP8Linear", "replace_with_fp8_linear"],
     "fsdp": ["is_fsdp_managed_module"],
     "ggml": [
@@ -192,7 +192,7 @@
         unset_hf_deepspeed_config,
     )
     from .eetq import replace_with_eetq_linear
-    from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear
+    from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear, FbgemmFp8Llama4TextExperts
     from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear
     from .fsdp import is_fsdp_managed_module
     from .ggml import (

src/transformers/integrations/fbgemm_fp8.py
Lines changed: 114 additions & 12 deletions

@@ -13,7 +13,7 @@
 # limitations under the License.

 from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
-
+from ..activations import ACT2FN

 if is_torch_available():
     import torch
@@ -28,18 +28,18 @@
 logger = logging.get_logger(__name__)


-class FbgemmFp8Linear(torch.nn.Module):
+class FbgemmFp8Linear(torch.nn.Linear):
     def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32):
-        super().__init__()
+        super().__init__(in_features, out_features, bias)
         self.in_features = in_features
         self.out_features = out_features

-        self.register_buffer("weight", torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
-        self.register_buffer("weight_scale", torch.zeros((out_features, 1), dtype=weight_dtype))
+        self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
+        self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=weight_dtype))
         self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)

         if bias:
-            self.register_buffer("bias", torch.zeros((self.out_features), dtype=weight_dtype))
+            self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=weight_dtype))
         else:
             self.bias = None

@@ -50,15 +50,16 @@ def forward(self, x):
         # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
         # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
         x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
-            x.view(-1, x.shape[-1]), num_tokens, self.input_scale_ub
+            x.view(-1, x.shape[-1]), scale_ub=self.input_scale_ub
         )
         # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works
         # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)

         # The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
+        weight_scale_float32 = self.weight_scale.to(torch.float32)
         output = torch.ops.fbgemm.f8f8bf16_rowwise(
-            x_quantized, self.weight, x_scale, self.weight_scale, use_fast_accum=True
-        )
+            x_quantized, self.weight, x_scale, weight_scale_float32, use_fast_accum=True
+        )
         output = output + self.bias if self.bias is not None else output
         # Hacky for now, we have the output to the device of x
         output = output.to(x.device)
@@ -67,19 +68,104 @@ def forward(self, x):
         return output


+class FbgemmFp8Llama4TextExperts(nn.Module):
+    def __init__(self, config, dtype=torch.float32):
+        super().__init__()
+        self.num_experts = config.num_local_experts
+        self.intermediate_size = config.intermediate_size
+        self.hidden_size = config.hidden_size
+        self.expert_dim = self.intermediate_size
+        self.act_fn = ACT2FN[config.hidden_act]
+        # Register FP8 buffers for gate_up_proj
+        self.gate_up_proj = torch.nn.Parameter(torch.zeros((self.num_experts, self.hidden_size, 2 * self.expert_dim), dtype=torch.float8_e4m3fn))
+        self.gate_up_proj_scale = torch.nn.Parameter(torch.zeros((self.num_experts, 1, self.expert_dim * 2), dtype=torch.float32))
+        # Register FP8 buffers for down_proj
+        self.down_proj = torch.nn.Parameter(torch.zeros((self.num_experts, self.expert_dim, self.hidden_size), dtype=torch.float8_e4m3fn))
+        self.down_proj_scale = torch.nn.Parameter(torch.zeros((self.num_experts, self.hidden_size, 1), dtype=torch.float32))
+        # Register input scale upper bound
+        self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
+
+
+    def forward(self, hidden_states):
+        """
+        Args:
+            hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
+        Returns:
+            torch.Tensor: (batch_size * token_num, hidden_size)
+        """
+        # Reshape hidden states for expert computation
+        hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
+        num_tokens = None
+
+        # Pre-allocate tensor for all expert outputs with same shape as hidden_states
+        next_states = torch.empty_like(hidden_states)
+
+        for i in range(self.num_experts):
+            # Extract expert's hidden states
+            expert_hidden = hidden_states[i]
+            expert_hidden_reshaped = expert_hidden.reshape(-1, self.hidden_size)
+            # Quantize for this expert
+            expert_quantized, expert_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+                expert_hidden_reshaped, num_tokens, self.input_scale_ub
+            )
+            sharded_expert_dim = self.gate_up_proj.shape[-1] // 2
+            gate_up_proj_scale_float32 = self.gate_up_proj_scale.to(torch.float32)
+
+            gate = torch.ops.fbgemm.f8f8bf16_rowwise(
+                expert_quantized,
+                self.gate_up_proj[i].transpose(0,1)[:sharded_expert_dim].contiguous(),
+                expert_scale,
+                gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous(),
+                use_fast_accum=True
+            )
+
+            up = torch.ops.fbgemm.f8f8bf16_rowwise(
+                expert_quantized,
+                self.gate_up_proj[i].transpose(0,1)[sharded_expert_dim:].contiguous(),
+                expert_scale,
+                gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous(),
+                use_fast_accum=True
+            )

+            activated = up * self.act_fn(gate)
+
+            activated_quantized, activated_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+                activated, num_tokens, self.input_scale_ub
+            )
+
+            down_proj_scale_float32 = self.down_proj_scale.to(torch.float32)
+            expert_output = torch.ops.fbgemm.f8f8bf16_rowwise(
+                activated_quantized,
+                self.down_proj[i].transpose(0,1).contiguous(),
+                activated_scale,
+                down_proj_scale_float32[i].view(-1, 1).contiguous(),
+                use_fast_accum=True
+            )
+
+            next_states[i] = expert_output
+        next_states = next_states.to(hidden_states.device)
+        return next_states.view(-1, self.hidden_size)
+
+
 def _replace_with_fbgemm_fp8_linear(
     model,
     modules_to_not_convert=None,
     current_key_name=None,
     quantization_config=None,
     has_been_replaced=False,
     pre_quantized=False,
+    config=None,
+    tp_plan=None
 ):
     """
     Private method that wraps the recursion for module replacement.

     Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
     """
+
+    from transformers.models.llama4.modeling_llama4 import Llama4TextExperts
+    import re
+
     if current_key_name is None:
         current_key_name = []

@@ -105,9 +191,24 @@ def _replace_with_fbgemm_fp8_linear(
                     # Force requires grad to False to avoid unexpected errors
                     model._modules[name].requires_grad_(False)
                 # set non persistant buffer outside of init_empty_weights
+                model._modules[name].input_scale_ub = torch.tensor(
+                    [quantization_config.activation_scale_ub], dtype=torch.float,
+                )
+        if module.__class__.__name__ == "Llama4TextExperts" and name not in modules_to_not_convert:
+            current_key_name_str = ".".join(current_key_name)
+            if not any(
+                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
+            ):
+                with init_empty_weights(include_buffers=True):
+                    tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj_scale")] = tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj")]
+                    tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None
+                    model._modules[name] = FbgemmFp8Llama4TextExperts(
+                        config.text_config,
+                    )
                 model._modules[name].input_scale_ub = torch.tensor(
                     [quantization_config.activation_scale_ub], dtype=torch.float
                 )
+
         if len(list(module.children())) > 0:
             _, has_been_replaced = _replace_with_fbgemm_fp8_linear(
                 module,
@@ -116,14 +217,16 @@ def _replace_with_fbgemm_fp8_linear(
                 quantization_config,
                 has_been_replaced=has_been_replaced,
                 pre_quantized=pre_quantized,
+                config=config,
+                tp_plan=tp_plan
             )
         # Remove the last key for recursion
         current_key_name.pop(-1)
     return model, has_been_replaced


 def replace_with_fbgemm_fp8_linear(
-    model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
+    model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False, config=None, tp_plan=None
 ):
     """
     A helper function to replace all `torch.nn.Linear` modules by `FbgemmFp8Linear` modules.
@@ -151,9 +254,8 @@ def replace_with_fbgemm_fp8_linear(
         modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
     modules_to_not_convert = list(set(modules_to_not_convert))
     model, has_been_replaced = _replace_with_fbgemm_fp8_linear(
-        model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized
+        model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized, config=config, tp_plan=tp_plan
     )
-
     if not has_been_replaced:
         logger.warning(
             "You are loading your model using FP8 quantization but no linear modules were found in your model."

src/transformers/integrations/tensor_parallel.py
Lines changed: 30 additions & 5 deletions

@@ -60,6 +60,19 @@ def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> Li
         single_size = total_size // blocks
         return [single_size] * blocks

+str_to_torch_dtype = {
+    "BOOL": torch.bool,
+    "U8": torch.uint8,
+    "I8": torch.int8,
+    "I16": torch.int16,
+    "F16": torch.float16,
+    "BF16": torch.bfloat16,
+    "I32": torch.int32,
+    "F32": torch.float32,
+    "F64": torch.float64,
+    "I64": torch.int64,
+    "F8_E4M3": torch.float8_e4m3fn
+}

 def get_packed_weights(param, empty_param, device_mesh, rank, dim):
     """
@@ -105,6 +118,12 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
        stop = (rank + 1) * shard_block_size
        tensors_slices += range(block_offset + start, block_offset + stop)
        block_offset += block_size
+
+    slice_dtype = slice_.get_dtype()
+    # Handle F8_E4M3 dtype by converting to float16 before slicing
+    # Without upcasting, the slicing causes : RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn'
+    if slice_dtype == "F8_E4M3":
+        slice_ = slice_[...].to(torch.float16)

     if dim == 0:
         tensor = slice_[tensors_slices, ...]
@@ -114,7 +133,7 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
         tensor = slice_[..., tensors_slices]
     else:
         raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
-    return tensor
+    return tensor.to(str_to_torch_dtype[slice_dtype])


 def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
@@ -539,10 +558,16 @@ def shard_and_distribute_module(
         module_to_tp._is_hooked = True

     if current_module_plan is not None:
-        tp_layer = translate_to_torch_parallel_style(current_module_plan)
-        param = tp_layer.partition_tensor(
-            param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
-        )
+        try:
+            tp_layer = translate_to_torch_parallel_style(current_module_plan)
+            param = tp_layer.partition_tensor(
+                param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
+            )
+        except NotImplementedError as e:
+
+            print(
+                f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
+            )
     else:
         # TODO log no plan modules in set
         print("No plan for", parameter_name,end ="\r")

src/transformers/modeling_utils.py
Lines changed: 2 additions & 2 deletions

@@ -4273,6 +4273,7 @@ def from_pretrained(
             )
             torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
             device_map = hf_quantizer.update_device_map(device_map)
+            config = hf_quantizer.update_tp_plan(config)

             # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
             if hasattr(hf_quantizer.quantization_config.quant_method, "value"):
@@ -4405,9 +4406,8 @@ def from_pretrained(

         if hf_quantizer is not None:
             hf_quantizer.preprocess_model(
-                model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules
+                model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules, config=config
             )
-
         # We store the original dtype for quantized models as we cannot easily retrieve it
         # once the weights have been quantized
         # Note that once you have loaded a quantized model, you can't change its dtype so this will

src/transformers/models/llama4/modeling_llama4.py
Lines changed: 1 addition & 1 deletion

@@ -184,7 +184,7 @@ def forward(self, hidden_states):
             input=hidden_states,
             dim=0,
             index=router_indices,
-        )
+        ).to(hidden_states.device)
         # we gather inputs corresponding to each expert based on the router indices
         routed_in = routed_in * router_scores.reshape(-1, 1)
         expert_routed_out_list = []

src/transformers/quantizers/base.py
Lines changed: 4 additions & 0 deletions

@@ -198,6 +198,10 @@ def validate_environment(self, *args, **kwargs):
         """
         return

+    def update_tp_plan(self, config):
+        "updates the tp plan for the scales"
+        return config
+
     def preprocess_model(self, model: "PreTrainedModel", **kwargs):
         """
         Setting model attributes and/or converting model before weights loading. At this point
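
update_tp_plan is a no-op hook on the base quantizer; from_pretrained now calls it right after update_device_map (see the modeling_utils.py hunk above), so a quantizer can register tensor-parallel styles for its extra scale tensors before weights are sharded. A minimal sketch of an override, assuming the HfQuantizer base class from this file; the subclass name and the plan key are hypothetical and only illustrate the pattern.

from transformers.quantizers.base import HfQuantizer

class MyFp8Quantizer(HfQuantizer):  # hypothetical subclass, not part of this PR
    def update_tp_plan(self, config):
        text_config = config.get_text_config()
        if text_config is not None and text_config.base_model_tp_plan is not None:
            # Shard the scale the same way as the weight it belongs to (hypothetical key)
            text_config.base_model_tp_plan.update(
                {"layers.*.feed_forward.experts.gate_up_proj_scale": "local_colwise"}
            )
        return config

The compressed-tensors quantizer in the next file overrides this hook to shard its expert weight_scale entries alongside the corresponding weights.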

src/transformers/quantizers/quantizer_compressed_tensors.py
Lines changed: 13 additions & 0 deletions

@@ -141,6 +141,19 @@ def _process_model_after_weight_loading(self, model, **kwargs):

            self.compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN
            self.compressor.decompress(model_path=cache_path, model=model)
+
+    def update_tp_plan(self, config):
+        additional_plan = {
+            "layers.*.feed_forward.experts.*.gate_proj.weight": "local_colwise",
+            "layers.*.feed_forward.experts.*.gate_proj.weight_scale": "local_colwise",
+            "layers.*.feed_forward.experts.*.up_proj.weight": "local_colwise",
+            "layers.*.feed_forward.experts.*.up_proj.weight_scale": "local_colwise",
+            "layers.*.feed_forward.experts.*.down_proj.weight": "local_rowwise",
+        }
+        if config.get_text_config() is not None and config.get_text_config().base_model_tp_plan is not None:
+            config.get_text_config().base_model_tp_plan.update(additional_plan)
+
+        return config

     @property
     def is_quantized(self):
