
Commit 189a103

liuzijing2014, pcuenca, and luccafong committed
* 128 experts * Use default rope * Unfuse mlp * Address feedback * Use None "default" for rope_scaling. Add eot. * Meta/llama quant compat (#7) * add quant compatible model & conversion code for llama4 * fix a few issues * fix a few issues * minor type mapping fix --------- Co-authored-by: Lu Fang <fanglu@fb.com> * use a new config parameter to determine which model definition to use for MoE --------- Co-authored-by: Pedro Cuenca <pedro@huggingface.co> Co-authored-by: Lu Fang <fanglu@fb.com>
1 parent fb748af commit 189a103

4 files changed, +48 −14 lines changed


src/transformers/modeling_utils.py

Lines changed: 2 additions & 1 deletion
@@ -529,6 +529,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
     "F32": torch.float32,
     "F64": torch.float64,
     "I64": torch.int64,
+    "F8_E4M3": torch.float8_e4m3fn
 }
 
 if is_torch_greater_or_equal("2.3.0"):

@@ -4061,7 +4062,7 @@ def from_pretrained(
         if not torch.distributed.is_initialized():
             try:
                 rank = int(os.environ["LOCAL_RANK"])
-                world_size = int(os.environ["ROLE_WORLD_SIZE"])
+                world_size = int(os.environ["ROLE_WORLD_SIZE"])
                 logger.warning(
                     "Tensor Parallel requires torch.distributed to be initialized first."
                     f"Initializing with world size {world_size} on rank {rank}"

src/transformers/models/llama4/configuration_llama4.py

Lines changed: 6 additions & 0 deletions
@@ -177,6 +177,7 @@ def __init__(
         router_aux_loss_coef=0.001,
         router_jitter_noise=0.0,
         rope_scaling=None,
+        for_llm_compressor=False,
         **kwargs,
     ):
         super().__init__(

@@ -217,6 +218,8 @@ def __init__(
         self.router_aux_loss_coef = router_aux_loss_coef
         self.router_jitter_noise = router_jitter_noise
 
+        self.for_llm_compressor = for_llm_compressor
+
 
 class Llama4Config(PretrainedConfig):
     r"""

@@ -290,6 +293,9 @@ class Llama4Config(PretrainedConfig):
             The aux loss factor for the total loss.
         router_jitter_noise (`float`, *optional*, defaults to 0.0):
             Amount of noise to add to the router.
+        for_llm_compressor (`bool`, *optional*, defaults to `False`):
+            Whether this config is for a checkpoint that will be quantized to fp8 with LLM Compressor.
+            If `True`, the MoE part of the model uses `Linear` layers instead of the fused MoE implementation.
 
     ```python
     >>> from transformers import Llama4Model, Llama4Config
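For context, a hedged usage sketch of the new flag. The text-config class name `Llama4TextConfig` is assumed from the surrounding file, and the flag only exists on builds that include this commit:

```python
from transformers.models.llama4.configuration_llama4 import Llama4TextConfig

# Assumed class name; `for_llm_compressor` is only accepted on builds with this commit.
config = Llama4TextConfig(for_llm_compressor=True)

# Downstream, the MoE module checks this flag to decide between per-expert
# Linear-based MLPs (LLM Compressor friendly) and the fused expert module.
print(config.for_llm_compressor)  # True
```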

src/transformers/models/llama4/convert_llama4_weights_to_hf.py

Lines changed: 24 additions & 8 deletions
@@ -21,6 +21,8 @@
 from transformers.integrations.tiktoken import TikTokenConverter
 
 
+_OFFLINE_QUANT_COMPATIBLE = os.environ.get("OFFLINE_QUANT_COMPATIBLE", "0") == "1"
+
 torch.serialization.add_safe_globals([io.BytesIO])
 # fmt: off
 

@@ -29,6 +31,8 @@
 # Still not sure what to do with those!
 # `None` means we drop the key
 
+
+weight_postfix = ".weight" if _OFFLINE_QUANT_COMPATIBLE else ""
 ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
     # CausalLM keys
     r"output.weight": r"language_model.lm_head.weight",

@@ -44,9 +48,9 @@
     r"layers.(\d+).attention.wqkv.weight": r"language_model.model.layers.\1.self_attn.qkv_proj.weight",
 
     # MoE keys: no simple MLP model.
-    r"layers.(\d+).feed_forward.experts.moe_w_in_eD_F": r"language_model.model.layers.\1.feed_forward.experts.gate_proj",  # will be fused with up
-    r"layers.(\d+).feed_forward.experts.moe_w_out_eF_D": r"language_model.model.layers.\1.feed_forward.experts.down_proj",  # expert win
-    r"layers.(\d+).feed_forward.experts.moe_w_swiglu_eD_F": r"language_model.model.layers.\1.feed_forward.experts.up_proj",  # fused with up
+    r"layers.(\d+).feed_forward.experts.moe_w_in_eD_F": r"language_model.model.layers.\1.feed_forward.experts.gate_proj" + weight_postfix,  # will be fused with up
+    r"layers.(\d+).feed_forward.experts.moe_w_out_eF_D": r"language_model.model.layers.\1.feed_forward.experts.down_proj" + weight_postfix,  # expert win
+    r"layers.(\d+).feed_forward.experts.moe_w_swiglu_eD_F": r"language_model.model.layers.\1.feed_forward.experts.up_proj" + weight_postfix,  # fused with up
     r"layers.(\d+).feed_forward.router_DE": r"language_model.model.layers.\1.feed_forward.router.weight",  # used for top
     r"layers.(\d+).feed_forward.w_in_shared_FD": r"language_model.model.layers.\1.feed_forward.shared_expert.gate_proj",  # might need to be fused for efficiency?
     r"layers.(\d+).feed_forward.w_out_shared_DF": r"language_model.model.layers.\1.feed_forward.shared_expert.down_proj",  # might need to be fused for efficiency?

@@ -262,6 +266,7 @@ def write_model(
         pad_token_id=pad_token_id,
         tie_word_embeddings=False,  # Constant set to False
         torch_dtype=torch_dtype,
+        for_llm_compressor=_OFFLINE_QUANT_COMPATIBLE,
         **config_kwargs,
     )
     # default vision config from params

@@ -380,6 +385,16 @@ def write_model(
                 v = new_key.replace("qkv", "v")
                 tqdm.write(f"Processing: {key.ljust(50)} ->\t {v}, {values.shape}")
                 state_dict[v] = values
+            elif _OFFLINE_QUANT_COMPATIBLE and "feed_forward.experts." in new_key:
+                # for experts, split the fused tensor per expert for offline quantization; no fusing is needed
+                expert_lists = []
+                for k in current_parameter:
+                    expert_lists.append(list(k.reshape(num_experts, -1, k.shape[-1]).unbind(0)))  # [#experts * IN, OUT] -> #experts * [IN, OUT]
+                for i in range(num_experts):
+                    expert = torch.cat([expert_list[i] for expert_list in expert_lists], dim=concat_dim)
+                    expert_key = new_key.replace("experts.", f"experts.{i}.")
+                    state_dict[expert_key] = expert.transpose(0, 1).contiguous()  # [OUT, IN]
+                    tqdm.write(f"Processing: {key.ljust(50)} ->\t {expert_key}, {state_dict[expert_key].shape}")
             elif re.search(r"(gate|up)_proj", new_key):
                 path = new_key.split(".")
                 gate_key = re.sub(r"(gate|up)_proj", lambda m: "gate_proj", new_key)

@@ -408,6 +423,7 @@ def write_model(
                 gate_up_proj = torch.cat((gate_proj, up_proj), dim=-1)
                 new_key = new_key.replace("up_proj", "gate_up_proj")
                 state_dict[new_key] = gate_up_proj.contiguous()
+
                 tqdm.write(f"Processing: {key.ljust(50)} ->\t {new_key}, {state_dict[new_key].shape}")
             elif "down_proj" in new_key:
                 current_parameter = torch.cat(current_parameter, dim=concat_dim)

@@ -710,11 +726,11 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False):
     )
 
     args = parser.parse_args()
-    write_tokenizer(
-        tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"),
-        save_dir=args.output_dir,
-        instruct=args.instruct,
-    )
+    # write_tokenizer(
+    #     tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"),
+    #     save_dir=args.output_dir,
+    #     instruct=args.instruct,
+    # )
 
     write_model(
         model_path=args.output_dir,
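With `OFFLINE_QUANT_COMPATIBLE=1` exported before running the converter, expert weights are kept unfused and written one expert at a time, transposed to the `[OUT, IN]` layout that `nn.Linear` (and hence LLM Compressor) expects. A standalone shape sketch of that split, using made-up dimensions rather than the converter's real variables:

```python
import torch

# Illustrative sizes only; the real converter reads num_experts and shapes from the checkpoint.
num_experts, d_in, d_out = 4, 8, 6
fused = torch.randn(num_experts * d_in, d_out)  # checkpoint layout: [#experts * IN, OUT]

# Split into per-expert [IN, OUT] slices, then transpose to the [OUT, IN] layout nn.Linear expects.
per_expert = [
    w.transpose(0, 1).contiguous()
    for w in fused.reshape(num_experts, d_in, d_out).unbind(0)
]

assert len(per_expert) == num_experts
assert per_expert[0].shape == (d_out, d_in)  # [OUT, IN]
```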

src/transformers/models/llama4/modeling_llama4.py

Lines changed: 16 additions & 5 deletions
@@ -23,6 +23,7 @@
 from dataclasses import dataclass
 from typing import Callable, List, Optional, Tuple, Union
 
+import os
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

@@ -61,7 +62,6 @@
 _CHECKPOINT_FOR_DOC = "meta-ai/Llama-4-17B"
 _CONFIG_FOR_DOC = "Llama4Config"
 
-
 class Llama4TextExperts(nn.Module):
     def __init__(self, config: Llama4Config):
         super().__init__()

@@ -153,7 +153,12 @@ def __init__(self, config)
         super().__init__()
         self.top_k = config.num_experts_per_tok
         self.hidden_dim = config.hidden_size
-        self.experts = Llama4TextExperts(config)
+        self.num_experts = config.num_local_experts
+        self.for_llm_compressor = config.for_llm_compressor
+        if self.for_llm_compressor:
+            self.experts = nn.ModuleList([Llama4TextMLP(config) for _ in range(self.num_experts)])
+        else:
+            self.experts = Llama4TextExperts(config)
         self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
         self.shared_expert = Llama4TextMLP(config)
 

@@ -184,8 +189,14 @@ def forward(self, hidden_states)
         )
         # we gather inputs corresponding to each expert based on the router indices
         routed_in = routed_in * router_scores.reshape(-1, 1)
-        routed_out = self.experts(routed_in)  # routed in is "sorted" / ready for EP
-
+        expert_routed_out_list = []
+        if self.for_llm_compressor:
+            routed_in = routed_in.reshape(self.num_experts, -1, routed_in.shape[-1])
+            for expert_idx in range(self.num_experts):
+                expert_routed_out_list.append(self.experts[expert_idx](routed_in[expert_idx]))
+            routed_out = torch.cat(expert_routed_out_list, dim=0)
+        else:
+            routed_out = self.experts(routed_in)
         out = self.shared_expert(hidden_states)
         # now that we finished expert computation -> we scatter add because we gathered previously
         # we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound

@@ -1706,7 +1717,7 @@ def forward(
         projected_vision_flat = self.multi_modal_projector(vision_flat)
 
         special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
-        final_mask = special_image_mask.to(inputs_embeds.device)
+        final_mask = special_image_mask.to(inputs_embeds.device)
         inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))
 
         final_mask_1d = final_mask[..., 0].reshape(-1)
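When `for_llm_compressor` is set, the gathered tokens are reshaped to `[num_experts, tokens_per_expert, hidden]`, each expert runs on its own slice, and the outputs are concatenated back into the layout the fused path produces. A standalone sketch with plain `nn.Linear` standing in for `Llama4TextMLP` and made-up sizes:

```python
import torch
import torch.nn as nn

# Made-up sizes; nn.Linear stands in for the per-expert Llama4TextMLP modules.
num_experts, tokens_per_expert, hidden = 4, 3, 16
experts = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_experts)])

routed_in = torch.randn(num_experts * tokens_per_expert, hidden)  # gathered tokens, expert-major order
routed_in = routed_in.reshape(num_experts, -1, hidden)            # [E, tokens_per_expert, hidden]

# Run each expert on its slice, then concatenate back to [E * tokens_per_expert, hidden].
routed_out = torch.cat(
    [experts[i](routed_in[i]) for i in range(num_experts)], dim=0
)
assert routed_out.shape == (num_experts * tokens_per_expert, hidden)
```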
