Sync amax & AWQ-Lite act_scale in context parallel/data parallel [OMNIML-2813] #359
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Walkthrough

Adds DP-aware activation-scale and amax synchronization with local NaN detection and DP-wide NaN propagation/disable during calibration (including AWQ-Lite); makes the Megatron plugin request a CP-aware data-parallel group with fallback and a warning; tweaks the ParallelState repr; expands tests/fixtures for DP/TP/CP (adds an 8-GPU fixture); updates the example README/docker run command.
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
participant Calib as Calibration Loop
participant Mod as Quantized Module
participant Q as Quantizer
participant DP as Data-Parallel Group
Note right of DP #D3E4CD: DP sync & NaN propagation
Calib->>Mod: run forward / collect stats
Mod->>Q: compute local amax / act_scale
alt local NaN detected
Mod->>DP: all-reduce NaN flag
DP-->>Mod: consensus NaN (any rank true)
Mod-->>Calib: mark module disabled due to NaN
else
Mod->>DP: all-reduce amax/act_scale (DP sync)
Q-->>Mod: updated synchronized scales
Mod-->>Calib: continue calibration
end
```

```mermaid
sequenceDiagram
autonumber
participant Init as Megatron Plugin Init
participant MPG as megatron.core.parallel_state
participant PS as ParallelState
Note right of MPG #F0F4FF: try CP-aware DP group, fallback if uninitialized
Init->>MPG: get_data_parallel_group(with_context_parallel=true)
alt context-parallel available
MPG-->>Init: dp_group (CP-aware)
else fallback
Init->>MPG: get_data_parallel_group()
MPG-->>Init: dp_group (fallback)
Init->>Init: log warning about fallback
end
Init->>MPG: get_tensor_model_parallel_group()
Init->>PS: __init__(data_parallel_group=dp_group, tensor_parallel_group=tp_group, ...)
PS-->>Init: ParallelState ready
```
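In code, the DP-side NaN handling sketched in the first diagram boils down to something like the following. This is a minimal illustration, not the exact implementation; the `disable()` call on the quantizer and the raw group handle are assumptions here.

```python
import torch
import torch.distributed as dist


def sync_amax_with_nan_guard(quantizer, dp_group):
    """Illustrative only: reach DP-wide consensus on NaNs, then sync amax."""
    if quantizer.amax is None or dp_group is None:
        return
    # Propagate a local NaN to every DP rank so all replicas disable together.
    has_nan = torch.tensor(
        float(torch.isnan(quantizer.amax).any()), device=quantizer.amax.device
    )
    dist.all_reduce(has_nan, op=dist.ReduceOp.MAX, group=dp_group)
    if has_nan.item() > 0:
        quantizer.disable()  # assumed API: mark this quantizer as skipped for calibration
        return
    # No NaNs anywhere: align the calibration statistic across DP replicas.
    dist.all_reduce(quantizer.amax, op=dist.ReduceOp.MAX, group=dp_group)
```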
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Codecov Report

❌ Patch coverage is …

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #359      +/-   ##
==========================================
- Coverage   73.36%   73.36%   -0.01%
==========================================
  Files         180      180
  Lines       17919    17925       +6
==========================================
+ Hits        13147    13151       +4
- Misses       4772     4774       +2
```

☔ View full report in Codecov by Sentry.
@jenchen13 could you please add unit tests for context parallel quantization (similar to tensor parallel) here? Basically the TP test checks whether amax is similar across the TP group; see tests/_test_utils/torch_quantization/quantize_common.py, line 119 at commit 26c203a.
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Force-pushed from e764e79 to 42519cc.
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Force-pushed from aa5b8fd to 264adbb.
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/_test_utils/torch_quantization/quantize_common.py (1)
132-149: Replace config-based guards with attribute-presence checks and sync AWQ pre_quant_scale
• For amax:

```python
if model.fc2.input_quantizer.amax is not None:
    activation_amax = model.fc2.input_quantizer.amax.clone()
    dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group)
    assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax)
```

• For scales (SmoothQuant, AWQ/AWQ-Lite):

```python
# input scale
if (scale := model.fc1.input_quantizer.pre_quant_scale) is not None:
    scale_clone = scale.clone()
    dist.all_reduce(scale_clone, op=dist.ReduceOp.MAX, group=tp_group)
    assert torch.allclose(scale_clone, scale)

# weight scale (AWQ-Lite)
if (wscale := model.fc1.weight_quantizer.pre_quant_scale) is not None:
    wscale_clone = wscale.clone()
    dist.all_reduce(wscale_clone, op=dist.ReduceOp.MAX, group=tp_group)
    assert torch.allclose(wscale_clone, wscale)
```

Drop all `if config in […]` checks.
🧹 Nitpick comments (2)
modelopt/torch/quantization/model_calib.py (1)
628-651: Weight DP/CP averages by token counts

Right now we average `act_scale` equally across ranks. In mixed-workload runs (e.g., MoE routing) we can see uneven `num_tokens`, so the lighter ranks end up pulling the mean down. Since we already track `num_tokens`, we can switch to a weighted reduction (sum of scale * tokens and sum of tokens) before normalizing. A sketch:

```diff
- module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
- sync_act_scale_across_dp_cp(...)
+ module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
+ token_count = torch.tensor(
+     [module.awq_lite.num_tokens],
+     device=module.awq_lite.act_scale.device,
+     dtype=module.awq_lite.act_scale.dtype,
+ )
+ scale_sum = module.awq_lite.act_scale * token_count.item()
+ sync_reduction_across_dp_cp(scale_sum, token_count, module.parallel_state)
+ module.awq_lite.act_scale = scale_sum / token_count
```

(`sync_reduction_across_dp_cp` would all-reduce both tensors across the DP/CP groups.)

tests/_test_utils/torch_quantization/quantize_common.py (1)
215-237: 3D helper: sync input pre_quant_scale across TP/CP/DP
- Replace the config-based presence check with an attribute check for input amax (e.g. if getattr(model.fc1.input_quantizer, "amax", None) is not None) in tests/_test_utils/torch_quantization/quantize_common.py → data_tensor_context_parallel_test_helper.
- If input_quantizer.pre_quant_scale is present, clone it, all_reduce (MAX) across tp_group, cp_group, dp_group and assert torch.allclose(reduced, input_quantizer.pre_quant_scale).
- AWQ-Lite’s activation scale is stored on module.awq_lite.act_scale and model_calib syncs it using AVG across DP/CP — if you want to validate AWQ-Lite end-to-end, also check module.awq_lite.act_scale is synchronized (use the same group ops as model_calib).
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- examples/nemo_run/qat/README.md (1 hunks)
- modelopt/torch/quantization/model_calib.py (2 hunks)
- modelopt/torch/quantization/plugins/megatron.py (2 hunks)
- modelopt/torch/utils/distributed.py (1 hunks)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (2 hunks)
- tests/_test_utils/torch_quantization/quantize_common.py (2 hunks)
- tests/gpu/torch/conftest.py (1 hunks)
- tests/gpu/torch/quantization/plugins/test_megatron.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/torch/quantization/plugins/megatron.py (2)
- modelopt/torch/opt/dynamic.py (2): parallel_state (876-878), parallel_state (881-886)
- modelopt/torch/utils/distributed.py (1): ParallelState (232-256)

tests/gpu/torch/quantization/plugins/test_megatron.py (3)
- tests/_test_utils/torch_quantization/quantize_common.py (4): context_parallel_test_helper (179-202), data_parallel_test_helper (153-176), data_tensor_context_parallel_test_helper (205-237), tensor_parallel_test_helper (119-150)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (2): initialize_for_megatron (386-398), MegatronModel (85-131)
- tests/gpu/torch/conftest.py (2): need_2_gpus (32-34), need_8_gpus (38-40)

tests/_test_utils/torch_quantization/quantize_common.py (2)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (1): get_dummy_input (130-131)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2): amax (231-236), amax (239-250)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
🔇 Additional comments (7)
modelopt/torch/quantization/model_calib.py (1)
82-98: Nice DP/CP amax sync integration

The recursive helper cleanly reuses the existing SequentialQuantizer handling while adding the CP hop, so both DP and CP replicas end up aligned. 👍
tests/gpu/torch/conftest.py (1)
37-40: `need_8_gpus` fixture looks good

The skip guard mirrors `need_2_gpus`, so multi-rank tests will short-circuit cleanly when the hardware isn't there.

tests/gpu/torch/quantization/plugins/test_megatron.py (1)
176-205: Great coverage for the 2×2×2 scenario

Spinning up the combined DP/TP/CP path ensures the new sync logic is exercised end-to-end; thanks for wiring the groups explicitly.
tests/_test_utils/torch_quantization/quantize_common.py (2)
129-131: Good addition: post-quantization sanity forward.

Running a forward pass after quantize helps catch latent issues. Looks good.

205-212: LGTM on the sequential TP→CP→DP all-reduce pattern.

This correctly propagates maxima across orthogonal groups in 3D parallelism.
Please confirm group construction matches orthogonal decomposition (i.e., each rank belongs to exactly one group per dimension). If not, propagation may be incomplete.
examples/nemo_run/qat/README.md (1)
95-95: Doc command update looks good.

Nice to see the example reflecting the new tensor-parallel setup.
modelopt/torch/utils/distributed.py (1)
244-256: Context group wiring looks solid.
`ParallelState` now mirrors DP/TP/CP consistently, so downstream logging/debugging will show the full layout.
```python
data_parallel_group = None
try:
    data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
    data_parallel_group = get_data_parallel_group()
self.parallel_state = ParallelState(
    getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
    data_parallel_group,
    mcore_parallel.get_tensor_model_parallel_group(),
    mcore_parallel.get_context_parallel_group(),
)
```
Guard get_context_parallel_group() when CP is disabled
get_context_parallel_group() asserts that context parallelism was initialized. When we run TP/DP-only (the default in plenty of setups), that assertion fires and _MegatronParallelLinear._setup() will crash. Please mirror the DP guard and fall back to -1 (unused) when the call raises.
Something along these lines keeps the DP-only path working:
```diff
-        self.parallel_state = ParallelState(
-            data_parallel_group,
-            mcore_parallel.get_tensor_model_parallel_group(),
-            mcore_parallel.get_context_parallel_group(),
-        )
+        try:
+            context_parallel_group = mcore_parallel.get_context_parallel_group()
+        except AssertionError:
+            context_parallel_group = -1
+        self.parallel_state = ParallelState(
+            data_parallel_group,
+            mcore_parallel.get_tensor_model_parallel_group(),
+            context_parallel_group,
+        )
```

📝 Committable suggestion
Current:

```python
data_parallel_group = None
try:
    data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
    data_parallel_group = get_data_parallel_group()
self.parallel_state = ParallelState(
    getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
    data_parallel_group,
    mcore_parallel.get_tensor_model_parallel_group(),
    mcore_parallel.get_context_parallel_group(),
)
```

Suggested:

```python
data_parallel_group = None
try:
    data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
    data_parallel_group = get_data_parallel_group()
try:
    context_parallel_group = mcore_parallel.get_context_parallel_group()
except AssertionError:
    context_parallel_group = -1
self.parallel_state = ParallelState(
    data_parallel_group,
    mcore_parallel.get_tensor_model_parallel_group(),
    context_parallel_group,
)
```
🤖 Prompt for AI Agents
In modelopt/torch/quantization/plugins/megatron.py around lines 221 to 230, the
call to mcore_parallel.get_context_parallel_group() is unguarded and will assert
(and crash) when context-parallelism is disabled; mirror the data-parallel
guard: try to call get_context_parallel_group() and if it raises
(AssertionError) set the context group to -1 (or the sentinel used for
"unused"), then pass that value into ParallelState so TP/DP-only setups won't
fail. Ensure you only catch the assertion from the context-group call and keep
the existing fallback for get_data_parallel_group() unchanged.
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- tests/_test_utils/torch_quantization/quantize_common.py (3 hunks)
- tests/gpu/torch/quantization/plugins/test_megatron.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/_test_utils/torch_quantization/quantize_common.py (1)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3): SequentialQuantizer (1146-1254), amax (231-236), amax (239-250)

tests/gpu/torch/quantization/plugins/test_megatron.py (3)
- tests/_test_utils/torch_quantization/quantize_common.py (3): data_tensor_context_parallel_test_helper (185-217), dp_cp_parallel_test_helper (154-182), tensor_parallel_test_helper (120-151)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (2): initialize_for_megatron (386-398), MegatronModel (85-131)
- tests/gpu/torch/conftest.py (2): need_2_gpus (32-34), need_8_gpus (38-40)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (2)
tests/_test_utils/torch_quantization/quantize_common.py (2)
151-151: Remove process group teardown from helper

We're still tearing down the global process group inside the helper, which the test harness owns. Subsequent jobs in the same worker will hit a destroyed NCCL group. Please drop this call.

```diff
-    dist.destroy_process_group()
```
162-182: Guard quantizer sync for AWQ-Lite and validate act_scale

AWQ-Lite quantizers set `is_mx_format`, so `TensorQuantizer.amax` returns `None` (see modelopt/torch/quantization/nn/modules/tensor_quantizer.py lines 230-235). The current helper unconditionally clones `.amax`, which will raise `AttributeError` as soon as we run the AWQ-Lite DP/CP test. Even when it doesn't explode, we never exercise the newly synchronized `act_scale`, so regressions slide through. Please guard for missing `amax` and also check `pre_quant_scale`/`act_scale` when present.

```diff
-    def reduce_amax(quantizer):
-        amax = quantizer.amax.clone()
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
-        assert torch.allclose(amax, quantizer.amax)
+    def reduce_quant_attrs(quantizer):
+        validated = False
+        for attr in ("amax", "pre_quant_scale", "act_scale"):
+            tensor = getattr(quantizer, attr, None)
+            if tensor is None:
+                continue
+            synced = tensor.detach().clone()
+            dist.all_reduce(synced, op=dist.ReduceOp.MAX, group=group)
+            assert torch.allclose(synced, tensor), f"{attr} mismatch on {quantizer}"
+            validated = True
+        if not validated:
+            pytest.fail(f"{quantizer} exposed no distributed attribute to validate")
@@
-    reduce_amax(model.fc1.input_quantizer)
-    reduce_amax(model.fc2.input_quantizer)
+    reduce_quant_attrs(model.fc1.input_quantizer)
+    reduce_quant_attrs(model.fc2.input_quantizer)
@@
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
@@
-        reduce_amax(model.fc1.weight_quantizer)
+        reduce_quant_attrs(model.fc1.weight_quantizer)
@@
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
@@
-        reduce_amax(model.fc2.weight_quantizer)
+        reduce_quant_attrs(model.fc2.weight_quantizer)
```
Signed-off-by: Jenny Chen <jennifchen@nvidia.com>
Actionable comments posted: 1
🧹 Nitpick comments (1)
modelopt/torch/quantization/model_calib.py (1)
602-612: Consider adding a docstring for consistency.

The act_scale synchronization logic is correct (using AVG for statistics is appropriate), but the function lacks a docstring.

Apply this diff to add a docstring:

```diff
 def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group):
+    """Synchronize the act_scale across all ranks in the data parallel and context parallel groups."""
     # Sync across Data Parallel (DP)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- modelopt/torch/quantization/model_calib.py (2 hunks)
- modelopt/torch/quantization/plugins/megatron.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/quantization/plugins/megatron.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/model_calib.py (3)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4): SequentialQuantizer (1114-1222), sync_amax_across_distributed_group (1071-1083), TensorQuantizer (65-1111), is_enabled (395-397)
- modelopt/torch/utils/distributed.py (2): is_initialized (49-51), is_initialized (196-198)
- modelopt/torch/quantization/utils.py (1): is_quantized_linear (246-256)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (2)
modelopt/torch/quantization/model_calib.py (2)
83-91: LGTM! CP support properly added to amax synchronization.

The extension from DP-only to DP+CP synchronization is implemented correctly. The function handles both `SequentialQuantizer` recursion and direct `TensorQuantizer` sync, and the docstring accurately reflects the new behavior.

98-98: Correct update to use CP-aware synchronization.

The call site is properly updated to invoke the new CP-aware sync function.
Actionable comments posted: 4
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- modelopt/torch/quantization/model_calib.py (2 hunks)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (2 hunks)
- tests/_test_utils/torch_quantization/quantize_common.py (3 hunks)
- tests/gpu/torch/quantization/plugins/test_megatron.py (4 hunks)
- tests/gpu/torch/quantization/test_model_calib.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/_test_utils/torch_dist/plugins/megatron_common.py
🧰 Additional context used
🧬 Code graph analysis (4)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
- tests/_test_utils/torch_quantization/quantize_common.py (3): data_tensor_context_parallel_test_helper (185-221), dp_cp_parallel_test_helper (154-182), tensor_parallel_test_helper (120-151)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (2): initialize_for_megatron (386-401), MegatronModel (85-131)
- tests/gpu/torch/conftest.py (2): need_2_gpus (32-34), need_8_gpus (38-40)

tests/gpu/torch/quantization/test_model_calib.py (2)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (3): MegatronModel (85-131), initialize_for_megatron (386-401), get_dummy_input (130-131)
- modelopt/torch/quantization/model_calib.py (1): awq_lite (424-677)

tests/_test_utils/torch_quantization/quantize_common.py (2)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3): SequentialQuantizer (1114-1222), amax (236-241), amax (244-255)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (1): get_dummy_input (130-131)

modelopt/torch/quantization/model_calib.py (4)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4): SequentialQuantizer (1114-1222), sync_amax_across_distributed_group (1071-1083), TensorQuantizer (65-1111), is_enabled (395-397)
- modelopt/torch/trace/symbols.py (1): named_modules (444-447)
- modelopt/torch/quantization/nn/modules/quant_module.py (1): QuantModule (37-96)
- modelopt/torch/utils/distributed.py (2): is_initialized (49-51), is_initialized (196-198)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (13)
modelopt/torch/quantization/model_calib.py (2)
83-98: LGTM! Well-structured CP-aware amax synchronization.

The implementation correctly extends amax synchronization to context parallel groups while maintaining backward compatibility with data parallel. The recursive handling of `SequentialQuantizer` is appropriate, and using `ReduceOp.MAX` for both DP and CP groups ensures correct semantics (maximum amax across all ranks).

602-613: LGTM! Correct use of averaging for act_scale synchronization.

The function properly synchronizes `act_scale` using `ReduceOp.AVG` across both DP and CP groups, which is the correct reduction operation for averaging activation scales. The guard checks ensure synchronization only occurs when the groups are initialized.
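For reference, the behavior described above can be pictured as a short sketch, assuming the groups passed in are the `DistributedProcessGroup` wrappers used elsewhere in this PR (the real function body may differ):

```python
import torch.distributed as dist


def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group):
    """Sketch: average AWQ-Lite act_scale over DP and CP replicas, guarded by init checks."""
    for group in (data_parallel_group, context_parallel_group):
        if group is not None and group.is_initialized():
            # AVG keeps act_scale an average over replicas, matching calibration semantics.
            dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=group.group)
```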
tests/gpu/torch/quantization/test_model_calib.py (1)

32-33: LGTM! Test setup is correct.

The test properly uses `spawn_multiprocess_job` with 2 GPUs and the NCCL backend for distributed testing.

tests/gpu/torch/quantization/plugins/test_megatron.py (5)
34-35: LGTM!

The new imports are necessary for the DP/CP test helpers and context-parallel group retrieval used in the tests below.

Also applies to: 45-45

101-103: LGTM!

The explicit `tp_size` keyword argument improves clarity, and removing `dp_group` from the `tensor_parallel_test_helper` call aligns with the updated signature in `quantize_common.py`.

124-130: Per-rank seed is overridden; test won't catch broken DP sync.

Passing `SEED + rank` to `initialize_for_megatron` is overridden by the internal call to `model_parallel_cuda_manual_seed(seed)` (see `tests/_test_utils/torch_dist/plugins/megatron_common.py`, lines 385-400), so all ranks still produce identical `get_dummy_input()` activations. The test will pass even if DP synchronization is broken. Introduce a rank-dependent perturbation after initialization, e.g., reseed or add a small offset before calling `dp_cp_parallel_test_helper`.

Based on past review comments.

148-156: Per-rank seed is overridden; test won't catch broken CP sync.

The same issue from the DP test applies here: `initialize_for_megatron` internally calls `model_parallel_cuda_manual_seed(seed)` with the provided seed, overriding the per-rank divergence you intended with `SEED + rank`. All CP ranks will produce identical calibration data, so the test won't fail if CP synchronization regresses. Add a rank-dependent perturbation after initialization.

Based on past review comments.

176-187: Fixed seed produces identical calibration data; test won't catch broken DP/CP sync.

Line 178 uses `SEED` without rank-dependent divergence. Since `initialize_for_megatron` calls `model_parallel_cuda_manual_seed(SEED)` uniformly across all 8 ranks, every rank will produce identical `get_dummy_input()` activations, so the assertions in `data_tensor_context_parallel_test_helper` will pass even if DP or CP synchronization is broken. Introduce rank-dependent perturbation (e.g., `SEED + rank + 1`) or add a small offset after initialization to ensure different calibration data per DP/CP rank.

Based on past review comments.
tests/_test_utils/torch_quantization/quantize_common.py (5)
26-26: LGTM!

The `SequentialQuantizer` import is necessary for the new helpers to handle multi-format weight quantization correctly.

120-120: LGTM!

Removing the unused `dp_group` parameter simplifies the signature; the function only validates tensor-parallel synchronization.
151-151: Remove global process group destruction from helper.

This call unconditionally tears down the global process group, breaking any subsequent DP/CP/TP tests that run in the same process. The test harness owns the process group lifecycle. Remove this line.

Based on past review comments.

Apply this diff:

```diff
-    dist.destroy_process_group()
```
154-183: Guard amax access and add AWQ-Lite scale validation.

Line 163 unconditionally accesses `quantizer.amax`, which returns `None` for MX formats (see modelopt/torch/quantization/nn/modules/tensor_quantizer.py, line 237) and will crash on `.clone()`. The config equality check at line 168 is brittle and misses future configs. Additionally, the PR objective includes synchronizing AWQ-Lite `act_scale`, but this helper doesn't validate it.

Based on past review comments.

Replace config checks with attribute guards and add scale validation:

```diff
-    def reduce_amax(quantizer):
-        amax = quantizer.amax.clone()
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
-        assert torch.allclose(amax, quantizer.amax)
+    def reduce_quant_attrs(quantizer):
+        for attr in ("amax", "pre_quant_scale", "act_scale"):
+            tensor = getattr(quantizer, attr, None)
+            if tensor is not None:
+                synced = tensor.clone()
+                dist.all_reduce(synced, op=dist.ReduceOp.MAX, group=group)
+                assert torch.allclose(synced, tensor), f"{attr} mismatch"

     # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        reduce_amax(model.fc1.input_quantizer)
-        reduce_amax(model.fc2.input_quantizer)
+    if hasattr(model.fc1, "input_quantizer"):
+        reduce_quant_attrs(model.fc1.input_quantizer)
+    if hasattr(model.fc2, "input_quantizer"):
+        reduce_quant_attrs(model.fc2.input_quantizer)

     # Weight quantizer amax
     if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc1.weight_quantizer:
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
     else:
-        reduce_amax(model.fc1.weight_quantizer)
+        reduce_quant_attrs(model.fc1.weight_quantizer)
     if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc2.weight_quantizer:
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
     else:
-        reduce_amax(model.fc2.weight_quantizer)
+        reduce_quant_attrs(model.fc2.weight_quantizer)
```
185-222: Guard amax access, add AWQ-Lite scale validation, and remove debug prints.

Lines 196, 197-198, 202-203 have the same issues as `dp_cp_parallel_test_helper`: unconditional `.amax.clone()` crashes for MX formats, config equality checks are brittle, and AWQ-Lite `act_scale`/`pre_quant_scale` aren't validated. Additionally, the debug `print` statements should be removed.

Based on past review comments.

Apply the same attribute-guard pattern from `dp_cp_parallel_test_helper`, replicate it for all three groups (`dp_group`, `cp_group`, `tp_group`), and remove the print statements:

```diff
-    def reduce_amax(quantizer):
-        amax = quantizer.amax.clone()
-        print("amax before reduce", amax)
-        print("quantizer.amax before reduce", quantizer.amax)
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group)
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group)
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group)
-        print("amax after reduce", amax)
-        print("quantizer.amax after reduce", quantizer.amax)
-        assert torch.allclose(amax, quantizer.amax)
+    def reduce_quant_attrs(quantizer):
+        for attr in ("amax", "pre_quant_scale", "act_scale"):
+            tensor = getattr(quantizer, attr, None)
+            if tensor is not None:
+                synced = tensor.clone()
+                for g in (dp_group, cp_group, tp_group):
+                    dist.all_reduce(synced, op=dist.ReduceOp.MAX, group=g)
+                assert torch.allclose(synced, tensor), f"{attr} mismatch"

     # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        reduce_amax(model.fc1.input_quantizer)
-        reduce_amax(model.fc2.input_quantizer)
+    if hasattr(model.fc1, "input_quantizer"):
+        reduce_quant_attrs(model.fc1.input_quantizer)
+    if hasattr(model.fc2, "input_quantizer"):
+        reduce_quant_attrs(model.fc2.input_quantizer)

     if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc1.weight_quantizer:
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
     else:
-        reduce_amax(model.fc1.weight_quantizer)
+        reduce_quant_attrs(model.fc1.weight_quantizer)
     if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc2.weight_quantizer:
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
     else:
-        reduce_amax(model.fc2.weight_quantizer)
+        reduce_quant_attrs(model.fc2.weight_quantizer)
```
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Looks like we don't need separate methods `tensor_parallel_test_helper`, `dp_cp_parallel_test_helper`, and `data_tensor_context_parallel_test_helper` for testing all the combinations. Can we merge them into one and do `data_tensor_context_parallel_test_helper(..., tp_group=None, dp_group=None, cp_group=None)`?
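A sketch of what that consolidated helper could look like (the `calib_fn` argument and the body are illustrative assumptions; `_distributed_attr_check` is the helper already added in this PR):

```python
def data_tensor_context_parallel_test_helper(
    model, config, calib_fn, tp_group=None, dp_group=None, cp_group=None
):
    """Sketch: one helper covering TP-only, DP/CP-only, and combined runs."""
    model = mtq.quantize(model, config, calib_fn)
    groups = [g for g in (tp_group, cp_group, dp_group) if g is not None]
    # Reuse the per-quantizer check against whichever groups were actually passed in.
    _distributed_attr_check(model.fc1.input_quantizer, "amax", groups=groups)
    _distributed_attr_check(model.fc2.input_quantizer, "amax", groups=groups)
```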
Same as https://github.com/NVIDIA/TensorRT-Model-Optimizer/pull/359/files#r2410523895
Can we combine the DP, TP, and CP tests by parameterizing them?
realAsma left a comment
Looks very good!
I left a few more comments on consolidating the unit tests. Let's try to consolidate the tests as much as we can to reduce the code maintenance burden.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- modelopt/torch/quantization/model_calib.py (3 hunks)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (5 hunks)
- tests/_test_utils/torch_quantization/quantize_common.py (3 hunks)
- tests/gpu/torch/quantization/plugins/test_megatron.py (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/quantization/model_calib.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
- tests/_test_utils/torch_quantization/quantize_common.py (1): data_tensor_context_parallel_test_helper (181-248)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (2): initialize_for_megatron (394-409), MegatronModel (85-139)
- tests/gpu/torch/conftest.py (2): need_2_gpus (32-34), need_8_gpus (38-40)

tests/_test_utils/torch_quantization/quantize_common.py (3)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1): SequentialQuantizer (1114-1222)
- modelopt/torch/quantization/model_calib.py (1): awq_lite (424-680)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (1): get_dummy_input (134-139)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
```python
def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=[]):
    for group in groups:
        if group is not None:
            quantizer_attr = getattr(quantizer, attr).clone()
            dist.all_reduce(quantizer_attr, op=op, group=group)
    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
```
Fix distributed attribute check to cover all groups
_distributed_attr_check clones and reduces for every group, but the sole assertion after the loop only evaluates the final non‑None group. When we pass [dp_group, tp_group] this lets a DP regression slip through silently because the reduced tensor from the first group is discarded. Please assert inside the loop (or otherwise accumulate checks) so every provided group is actually validated.
```diff
-def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=[]):
-    for group in groups:
-        if group is not None:
-            quantizer_attr = getattr(quantizer, attr).clone()
-            dist.all_reduce(quantizer_attr, op=op, group=group)
-    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
+def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=()):
+    attr_value = getattr(quantizer, attr)
+    checked = False
+    for group in groups:
+        if group is None:
+            continue
+        reduced = attr_value.clone()
+        dist.all_reduce(reduced, op=op, group=group)
+        assert torch.allclose(reduced, attr_value)
+        checked = True
+    assert checked, "expected at least one distributed group to validate"
```

📝 Committable suggestion
+ assert checked, "expected at least one distributed group to validate"📝 Committable suggestion
Current:

```python
def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=[]):
    for group in groups:
        if group is not None:
            quantizer_attr = getattr(quantizer, attr).clone()
            dist.all_reduce(quantizer_attr, op=op, group=group)
    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
```

Suggested:

```python
def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=()):
    attr_value = getattr(quantizer, attr)
    checked = False
    for group in groups:
        if group is None:
            continue
        reduced = attr_value.clone()
        dist.all_reduce(reduced, op=op, group=group)
        assert torch.allclose(reduced, attr_value)
        checked = True
    assert checked, "expected at least one distributed group to validate"
```
🤖 Prompt for AI Agents
In tests/_test_utils/torch_quantization/quantize_common.py around lines 122-127,
the distributed attribute check clones and reduces for each group but only
asserts once after the loop, so only the last group's reduction is validated;
modify the function to perform the assertion inside the loop (i.e., after
dist.all_reduce for each non-None group, compare the reduced tensor to the
original attribute and assert they match) or alternatively accumulate/compare
reductions for every group before exiting so every provided group is actually
validated.
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Actionable comments posted: 0
♻️ Duplicate comments (2)
tests/_test_utils/torch_quantization/quantize_common.py (2)
122-127: Assert per group in `_distributed_attr_check`

We still only compare the reduction result from the final non-`None` group against the local tensor, so any earlier group (e.g., DP) can drift without being caught. Please assert after each `all_reduce` (and switch the mutable default) so every provided group is validated.

```diff
-def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=[]):
-    for group in groups:
-        if group is not None:
-            quantizer_attr = getattr(quantizer, attr).clone()
-            dist.all_reduce(quantizer_attr, op=op, group=group)
-    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
+def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=()):
+    attr_value = getattr(quantizer, attr)
+    for group in groups:
+        if group is None:
+            continue
+        reduced = attr_value.clone()
+        dist.all_reduce(reduced, op=op, group=group)
+        assert torch.allclose(reduced, attr_value)
```
203-206: Also validate `fc2.awq_lite.act_scale`

Row-parallel AWQ-Lite sync can regress silently because we never touch `fc2.awq_lite.act_scale`. Mirror the `fc1` check (guarded with `getattr` if needed) so both halves of the layer are exercised.

```diff
     if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
         _distributed_attr_check(
             model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, groups=[dp_group, tp_group]
         )
+        fc2_awq = getattr(model.fc2, "awq_lite", None)
+        if fc2_awq is not None:
+            _distributed_attr_check(
+                fc2_awq, "act_scale", dist.ReduceOp.AVG, groups=[dp_group, tp_group]
+            )
```
🧹 Nitpick comments (1)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
375-378: Limit inference debug prints to a single rank

These prints fire on every rank for every test run, flooding logs. Gate them behind a rank check (or drop them) so only rank 0 emits the shapes.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (6 hunks)
- tests/_test_utils/torch_quantization/quantize_common.py (3 hunks)
- tests/gpu/torch/quantization/plugins/test_megatron.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/_test_utils/torch_quantization/quantize_common.py (3)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1): SequentialQuantizer (1114-1222)
- modelopt/torch/quantization/model_calib.py (1): awq_lite (424-680)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (1): get_dummy_input (134-139)

tests/gpu/torch/quantization/plugins/test_megatron.py (3)
- tests/_test_utils/torch_quantization/quantize_common.py (1): data_tensor_context_parallel_test_helper (139-206)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (2): initialize_for_megatron (396-411), MegatronModel (85-139)
- tests/gpu/torch/conftest.py (2): need_2_gpus (32-34), need_8_gpus (38-40)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
Actionable comments posted: 0
🧹 Nitpick comments (3)
tests/gpu/torch/quantization/plugins/test_megatron.py (2)
100-100: Clarify the purpose of `use_rank_in_seed`.

The parameter `use_rank_in_seed` controls whether the global seed passed to `initialize_for_megatron` includes the rank offset. However, the actual calibration data divergence across DP/CP ranks is already handled by `data_tensor_context_parallel_test_helper` seeding `get_dummy_input` with `dp_rank` (see quantize_common.py line 143). Consider documenting why this parameter affects `initialize_for_megatron`'s seed; perhaps it ensures model initialization RNG states differ across ranks, which could matter for edge cases. If it's not needed, consider removing it to reduce confusion.

159-224: Consider parameterizing tests to reduce duplication.

As noted by @realAsma, the four test functions (`test_tensor_parallel`, `test_data_parallel`, `test_context_parallel`, `test_data_tensor_context_parallel`) share nearly identical structure. You could combine them into a single parameterized test:

```python
@pytest.mark.parametrize(
    "size,tp_size,cp_size,use_rank_in_seed,fixture",
    [
        (2, 2, 1, False, "need_2_gpus"),  # TP
        (2, 1, 1, True, "need_2_gpus"),  # DP
        (2, 1, 2, True, "need_2_gpus"),  # CP
        (8, 2, 2, True, "need_8_gpus"),  # DP+TP+CP
    ],
)
@pytest.mark.parametrize("config", [...])
def test_parallelism(request, size, tp_size, cp_size, use_rank_in_seed, fixture, config):
    request.getfixturevalue(fixture)
    spawn_multiprocess_job(
        size=size,
        job=partial(
            _test_parallelism_helper,
            config,
            tensor_model_parallel_size=tp_size,
            context_parallel_size=cp_size,
            use_rank_in_seed=use_rank_in_seed,
        ),
        backend="nccl",
    )
```

This would improve maintainability by centralizing the test logic.
Based on learnings
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
382-384: Consider replacing `print` with `logging.debug()`.

Debug statements using `logging.debug()` would allow better control over verbosity in CI/local runs.

Apply this change if you want to use logging:

```diff
+import logging
+
+logger = logging.getLogger(__name__)
 ...
-    print("inference_input size", inference_input["tokens"].shape)
+    logger.debug("inference_input size %s", inference_input["tokens"].shape)
     logits = wrapped_model.run_one_forward_step(inference_input)
-    print("logits size", logits.shape)
+    logger.debug("logits size %s", logits.shape)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (6 hunks)
- tests/gpu/torch/quantization/plugins/test_megatron.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
- tests/_test_utils/torch_quantization/quantize_common.py (1): data_tensor_context_parallel_test_helper (139-206)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (2): initialize_for_megatron (403-418), MegatronModel (86-140)
- tests/gpu/torch/conftest.py (2): need_2_gpus (32-34), need_8_gpus (38-40)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (3)
tests/gpu/torch/quantization/plugins/test_megatron.py (1)
93-136: LGTM: Unified helper consolidates TP/DP/CP test logic.

The refactor to a single `_test_parallelism_helper` reduces duplication and correctly wires parallelism groups. The logic properly:

- Determines `tp_group` and `dp_group` based on the parallelism config
- Passes them to `data_tensor_context_parallel_test_helper` for validation
- Uses rank-based seeding when `use_rank_in_seed=True` for DP/CP scenarios

tests/_test_utils/torch_dist/plugins/megatron_common.py (2)

87-140: LGTM: MegatronModel extensions support CP and seeded inputs.

The additions properly support context parallelism and enable per-rank calibration data divergence:

- `cp_size` and `tp_group` parameters thread through to TransformerConfig and the parallel layers
- `get_dummy_input(seed)` allows deterministic, per-rank data generation when `seed` is provided
- Backward compatibility is maintained by defaulting the new parameters

403-418: LGTM: Parameter order maintains backward compatibility.

The parameter order (`tensor_model_parallel_size`, `pipeline_model_parallel_size`, `seed`, `context_parallel_size`) ensures existing positional calls passing `seed` continue to work, addressing the concern raised in previous reviews.
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
401-416: LGTM! Backward compatibility correctly restored.

The parameter ordering now correctly places `seed` in the third position, maintaining backward compatibility with existing code that passes it positionally. The new `context_parallel_size` parameter is appropriately placed at the end and correctly passed to `initialize_model_parallel`.

Consider updating the docstring to document the new `context_parallel_size` parameter:

```diff
 def initialize_for_megatron(
     tensor_model_parallel_size=1,
     pipeline_model_parallel_size=1,
     seed=1234,
     context_parallel_size=1,
 ):
     """Initialize Megatron model parallelism.
+
+    Args:
+        tensor_model_parallel_size: Size of tensor model parallel group.
+        pipeline_model_parallel_size: Size of pipeline model parallel group.
+        seed: Random seed for model parallel initialization.
+        context_parallel_size: Size of context parallel group.

     NOTE: If used in a non-spawned process, make sure to call
     `megatron.core.parallel_state.destroy_model_parallel()`.
     """
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- examples/nemo_run/qat/README.md (2 hunks)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (5 hunks)
- tests/gpu/torch/quantization/plugins/test_apex.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/nemo_run/qat/README.md
🧰 Additional context used
🧬 Code graph analysis (2)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
- tests/gpu/torch/quantization/plugins/test_apex.py (1): get_dummy_input (61-66)

tests/gpu/torch/quantization/plugins/test_apex.py (2)
- tests/_test_utils/torch_quantization/quantize_common.py (1): data_tensor_context_parallel_test_helper (139-206)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (1): get_dummy_input (135-140)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (5)
tests/_test_utils/torch_dist/plugins/megatron_common.py (3)
87-92: LGTM! Context parallel and tensor parallel group parameters added correctly.

The additions of `cp_size` and `tp_group` parameters enable fine-grained control over parallel configurations for testing. Default values maintain backward compatibility for existing callers.

110-110: LGTM! Tensor parallel group correctly propagated to parallel layers.

Passing `tp_group` to both `ColumnParallelLinear` and `RowParallelLinear` enables explicit control over the tensor parallel group, which is essential for testing scenarios with multiple parallel groups.

Also applies to: 125-125

135-140: LGTM! Deterministic input generation implemented correctly.

The optional `seed` parameter enables reproducible test inputs when needed, while maintaining non-seeded behavior by default. The implementation correctly uses `torch.Generator` with `manual_seed` and follows the established pattern from `test_apex.py`.
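A minimal sketch of that seeded-input pattern (the shape and `hidden_size` attribute here are placeholders, not the actual values in `megatron_common.py`):

```python
import torch


def get_dummy_input(self, seed: int | None = None) -> torch.Tensor:
    # A per-call generator keeps inputs deterministic per rank without
    # disturbing the global RNG used for model-parallel initialization.
    gen = None
    if seed is not None:
        gen = torch.Generator()
        gen.manual_seed(seed)
    return torch.randn(2, 8, self.hidden_size, generator=gen)  # placeholder shape
```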
tests/gpu/torch/quantization/plugins/test_apex.py (2)

26-26: LGTM!

The import name change from `tensor_parallel_test_helper` to `data_tensor_context_parallel_test_helper` correctly reflects the updated test utility that now supports combined DP/TP/CP testing.
61-66: LGTM!

The seeded input generation implementation is correct and matches the pattern used in the Megatron common test utilities (see `tests/_test_utils/torch_dist/plugins/megatron_common.py`, lines 134-139). This enables deterministic data generation per DP rank for proper distributed testing.
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/nemo_run/qat/README.md(2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: wait-checks / wait
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
examples/nemo_run/qat/README.md (outdated)

Run docker command:

```diff
-docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash
+docker run -v /home/user/:/home/user/ -v /home/user/TensorRT-Model-Optimizer/:/opt/TensorRT-Model-Optimizer/ --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.09 bash
```
Pass HF_TOKEN into the container
You tell users to export HF_TOKEN, but the subsequent docker run does not propagate that variable, so the container still can’t access the token and downloads will fail. Add -e HF_TOKEN=$HF_TOKEN (or equivalent) to the command.
🤖 Prompt for AI Agents
In examples/nemo_run/qat/README.md around lines 66 to 70, the docker run command
does not propagate the HF_TOKEN environment variable despite instructing users
to export it; update the docker run invocation to pass the token into the
container (for example by adding -e HF_TOKEN=$HF_TOKEN or an equivalent --env
specification) so the container can access the token for downloads.
```python
        mtq.NVFP4_DEFAULT_CFG,
    ],
)
def test_data_parallel(need_2_gpus, config):
```
For data parallel and context parallel, do we really need to test all configs? Or testing one sufficient given that we have extensive tensor parallel tests?
Thoughts @realAsma @jenchen13
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
```diff
             module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
-            if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
+            has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
```
Good catch on this. Can we use `DistributedProcessGroup.get_dist_syncd_obj(has_nan_local, module.parallel_state.data_parallel_group, lambda objs: any(objs))`?
```python
            )
            has_nan = torch.tensor(int(has_nan_local), device=module.awq_lite.act_scale.device)
            if module.parallel_state.data_parallel_group.is_initialized():
                dist.all_reduce(
                    has_nan,
                    op=dist.ReduceOp.MAX,
                    group=module.parallel_state.data_parallel_group.group,
                )

            if has_nan.item() > 0:
```
Can we use the `get_dist_syncd_obj` helper to sync `has_nan`?
```diff
             )
-            has_nan = torch.tensor(int(has_nan_local), device=module.awq_lite.act_scale.device)
-            if module.parallel_state.data_parallel_group.is_initialized():
-                dist.all_reduce(
-                    has_nan,
-                    op=dist.ReduceOp.MAX,
-                    group=module.parallel_state.data_parallel_group.group,
-                )
-            if has_nan.item() > 0:
+            has_nan = DistributedProcessGroup.get_dist_syncd_obj(
+                has_nan_local, module.parallel_state.data_parallel_group, lambda objs: any(objs)
+            )
+            if has_nan:
```
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- examples/nemo_run/qat/README.md (2 hunks)
- modelopt/torch/quantization/model_calib.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/nemo_run/qat/README.md
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/model_calib.py (3)
- modelopt/torch/utils/distributed.py (5): DistributedProcessGroup (189-229), ParallelState (232-253), is_initialized (49-51), is_initialized (196-198), get_dist_syncd_obj (212-229)
- modelopt/torch/trace/symbols.py (1): named_modules (444-447)
- modelopt/torch/quantization/utils.py (1): is_quantized_linear (246-256)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
```diff
             has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
                 torch.isnan(module.awq_lite.weight_scale)
-            ):
+            )
```
Fix tensor boolean evaluation before distributed sync.
torch.any(...) returns a 0‑D tensor, so the Python or tries to convert that tensor to bool, triggering RuntimeError: Boolean value of Tensor with more than one value is ambiguous at runtime. AWQ-Lite calibration will crash the first time it hits this branch. Convert the NaN checks to Python booleans before combining them.
```diff
-            has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
-                torch.isnan(module.awq_lite.weight_scale)
-            )
+            act_nan = torch.isnan(module.awq_lite.act_scale).any().item()
+            weight_nan = torch.isnan(module.awq_lite.weight_scale).any().item()
+            has_nan_local = bool(act_nan or weight_nan)
```

📝 Committable suggestion
```diff
             has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
                 torch.isnan(module.awq_lite.weight_scale)
             ):
             )
@@ modelopt/torch/quantization/model_calib.py:619
-            has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
-                torch.isnan(module.awq_lite.weight_scale)
             act_nan = torch.isnan(module.awq_lite.act_scale).any().item()
             weight_nan = torch.isnan(module.awq_lite.weight_scale).any().item()
             has_nan_local = bool(act_nan or weight_nan)
```
🤖 Prompt for AI Agents
In modelopt/torch/quantization/model_calib.py around lines 619 to 621, the code
uses torch.any(...) twice and combines them with the Python "or", which attempts
to convert 0-D tensors to bools and raises a RuntimeError; change the checks to
produce Python booleans before combining (for example call .any().item() or wrap
each check with bool(...).cpu().item() as appropriate) so has_nan_local becomes
a plain Python bool, then use that to drive the subsequent logic.
What does this PR do?
Type of change: New Feature
Overview: Sync quantizer amax in context parallelism, and AWQ-Lite `act_scale` in CP/DP.

Usage

```python
# Add a code snippet demonstrating how to use this
```
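For illustration, a minimal sketch of the calibration flow this change affects (the model and `calib_dataloader` are placeholders; `mtq.quantize` is the standard ModelOpt entry point):

```python
import modelopt.torch.quantization as mtq


def forward_loop(model):
    # Placeholder calibration loop over a DP/CP-sharded dataloader.
    for batch in calib_dataloader:
        model(batch)


# With this PR, quantizer amax (and AWQ-Lite act_scale) collected during the
# forward loop are also synchronized across data-parallel and context-parallel
# ranks, not just tensor-parallel ones.
model = mtq.quantize(model, mtq.INT4_AWQ_CFG, forward_loop)
```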
Testing

- Unit tests for `act_scale` sync

Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
- New Features
- Documentation
- Bug Fixes
- Style
- Tests
- Performance