-
Notifications
You must be signed in to change notification settings - Fork 192
Sync amax & AWQ-Lite act_scale in context parallel/data parallel [OMNIML-2813] #359
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f17131f
42519cc
264adbb
7cbe5b9
1f7d17e
71a9f7a
d02365c
5a572da
fc0bb88
95da832
34c11ef
10e3e2b
9f0691f
fa8f4c8
d1fac44
22b8b73
ca7c0e8
3f857a3
93bfd52
6761109
291cfa3
a106dd9
50000dd
2664563
440ca48
2e8ef58
5cb380c
afe6f34
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -26,7 +26,7 @@ | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| from modelopt.torch.opt.searcher import ForwardLoop | ||||||||||||||||||||||
| from modelopt.torch.utils import print_rank_0 | ||||||||||||||||||||||
| from modelopt.torch.utils.distributed import ParallelState | ||||||||||||||||||||||
| from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState | ||||||||||||||||||||||
| from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context | ||||||||||||||||||||||
|
|
@@ -81,6 +81,7 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis | |||||||||||||||||||||
| return | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def sync_quantizer_amax_across_dp(quantizer, parallel_state): | ||||||||||||||||||||||
| """Synchronize the amax across all ranks in the data parallel group.""" | ||||||||||||||||||||||
| if isinstance(quantizer, SequentialQuantizer): | ||||||||||||||||||||||
| for _q in quantizer: | ||||||||||||||||||||||
| sync_quantizer_amax_across_dp(_q, parallel_state) | ||||||||||||||||||||||
|
|
@@ -94,7 +95,6 @@ def sync_quantizer_amax_across_dp(quantizer, parallel_state): | |||||||||||||||||||||
| for child in module.children(): | ||||||||||||||||||||||
| if isinstance(child, (TensorQuantizer, SequentialQuantizer)): | ||||||||||||||||||||||
| sync_quantizer_amax_across_dp(child, module.parallel_state) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # TP sync: | ||||||||||||||||||||||
| # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -114,6 +114,7 @@ def sync_quantizer_amax_across_tp( | |||||||||||||||||||||
| axes_for_sync: list, | ||||||||||||||||||||||
| parallel_state: ParallelState, | ||||||||||||||||||||||
| ): | ||||||||||||||||||||||
| # Syncing amax across TP for sequential quantizer | ||||||||||||||||||||||
| if isinstance(quantizer, SequentialQuantizer): | ||||||||||||||||||||||
| for _q in quantizer: | ||||||||||||||||||||||
| sync_quantizer_amax_across_tp( | ||||||||||||||||||||||
|
|
@@ -598,19 +599,37 @@ def forward(self, input, *args, **kwargs): | |||||||||||||||||||||
| # This will also perform distributed amax sync for input_quantizers | ||||||||||||||||||||||
| max_calibrate(model, lambda model: None) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def sync_act_scale_across_dp(module, data_parallel_group): | ||||||||||||||||||||||
| """Sync activation scale across Data Parallel (DP).""" | ||||||||||||||||||||||
| if data_parallel_group.is_initialized(): | ||||||||||||||||||||||
| dist.all_reduce( | ||||||||||||||||||||||
| module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| for name, module in model.named_modules(): | ||||||||||||||||||||||
| if ( | ||||||||||||||||||||||
| is_quantized_linear(module) | ||||||||||||||||||||||
| and hasattr(module, "awq_lite") | ||||||||||||||||||||||
| and module.awq_lite.num_cache_steps > 0 | ||||||||||||||||||||||
| ): | ||||||||||||||||||||||
| # Hack: MoEs forward all tokens through all experts if _if_calib is True | ||||||||||||||||||||||
| module._if_calib = True | ||||||||||||||||||||||
| module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps | ||||||||||||||||||||||
| if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any( | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any( | ||||||||||||||||||||||
| torch.isnan(module.awq_lite.weight_scale) | ||||||||||||||||||||||
| ): | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
Comment on lines
+619
to
+621
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix tensor boolean evaluation before distributed sync.
- has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
- torch.isnan(module.awq_lite.weight_scale)
- )
+ act_nan = torch.isnan(module.awq_lite.act_scale).any().item()
+ weight_nan = torch.isnan(module.awq_lite.weight_scale).any().item()
+ has_nan_local = bool(act_nan or weight_nan)📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||
| has_nan = DistributedProcessGroup.get_dist_syncd_obj( | ||||||||||||||||||||||
| has_nan_local, module.parallel_state.data_parallel_group, lambda objs: any(objs) | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if has_nan: | ||||||||||||||||||||||
| module.awq_lite.is_enabled = False | ||||||||||||||||||||||
| # Hack: MoEs forward all tokens through all experts if _if_calib is True | ||||||||||||||||||||||
| module._if_calib = True | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| sync_act_scale_across_dp( | ||||||||||||||||||||||
| module, | ||||||||||||||||||||||
| module.parallel_state.data_parallel_group, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| AWQLiteHelper.cache_mode = False | ||||||||||||||||||||||
| print_rank_0("awq_lite: Searching parameters...") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like we dont need separate methods |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch on this.
Can we use