@@ -28,8 +28,6 @@
 from megatron.core.extensions import transformer_engine as megatron_te
 from megatron.core.parallel_state import get_data_parallel_group
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
-from megatron.core.parallel_state import get_data_parallel_group
-from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
 from megatron.core.utils import get_tensor_model_parallel_group_if_none
@@ -63,23 +61,24 @@ def get_sequential_mlp_expert_names(name: str, module: torch.nn.Module):
             expert_name, local_expert_name = name.split(".local_experts.")
             # extract quantizer name by removing local_expert number from the name
             local_expert_name = ".".join(local_expert_name.split(".")[1:])
-            return expert_name, local_expert_name
-        return None, None
+            return f"{expert_name}.{local_expert_name}"
+        return None
 
     # gather amax values from SequentialMLP experts
     for name, module in model.named_modules():
-        expert_name, local_expert_name = get_sequential_mlp_expert_names(name, module)
-        if expert_name and local_expert_name:
-            amax_dict[local_expert_name] = amax_dict.get(local_expert_name, {})
-            amax_dict[local_expert_name][expert_name] = max(
-                amax_dict[local_expert_name].get(expert_name, 0), module.amax
+        expert_name = get_sequential_mlp_expert_names(name, module)
+        if expert_name and module.amax is not None:
+            stored_amax = amax_dict.get(expert_name)
+            amax_tensor = module.amax.detach().clone()
+            amax_dict[expert_name] = (
+                amax_tensor if stored_amax is None else torch.maximum(stored_amax, amax_tensor)
             )
 
     # sync amax values across experts in SequentialMLP
     for name, module in model.named_modules():
-        expert_name, local_expert_name = get_sequential_mlp_expert_names(name, module)
-        if expert_name and local_expert_name:
-            module.amax = amax_dict[local_expert_name][expert_name]
+        expert_name = get_sequential_mlp_expert_names(name, module)
+        if expert_name and module.amax is not None:
+            module.amax = amax_dict[expert_name].detach().clone().to(module.amax.device)
 
 
 CUSTOM_POST_CALIBRATION_PLUGINS.add(sync_amax_across_sequential_mlp)
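
For context, a minimal standalone sketch of the pattern the new code uses: normalize each local expert's module name to a shared key, reduce all experts' amax values with torch.maximum, then read the shared maximum back. This is not the plugin itself; expert_key, named_amax, and the module names below are invented for illustration, whereas the real code walks model.named_modules() and mutates quantizer modules' amax in place.

import torch

def expert_key(name: str) -> str | None:
    # Drop the local-expert index: "<prefix>.local_experts.<i>.<rest>" -> "<prefix>.<rest>"
    if ".local_experts." not in name:
        return None
    prefix, suffix = name.split(".local_experts.")
    return f"{prefix}.{'.'.join(suffix.split('.')[1:])}"

# Hypothetical per-expert amax values keyed by quantizer module name.
named_amax = {
    "mlp.experts.local_experts.0.linear_fc1.weight_quantizer": torch.tensor(0.5),
    "mlp.experts.local_experts.1.linear_fc1.weight_quantizer": torch.tensor(0.9),
}

# Gather: reduce all local experts' amax into one tensor per shared key.
amax_dict: dict[str, torch.Tensor] = {}
for name, amax in named_amax.items():
    key = expert_key(name)
    if key is not None:
        stored = amax_dict.get(key)
        amax_dict[key] = amax if stored is None else torch.maximum(stored, amax)

# Sync: every expert now sees the shared maximum.
print(amax_dict[expert_key("mlp.experts.local_experts.0.linear_fc1.weight_quantizer")])
# tensor(0.9000)

One upside of the torch.maximum reduction over the previous max(..., 0): it is an elementwise tensor op, so it also works when amax is a per-channel (multi-element) tensor, where Python's built-in max would raise an ambiguous-truth-value error.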