Conversation

@jenchen13
Contributor

@jenchen13 jenchen13 commented Sep 24, 2025

What does this PR do?

Type of change: New Feature

Overview: Sync quantizer amax in Context Parallelism & AWQ-Lite act_scale in CP/DP

Usage

# Add a code snippet demonstrating how to use this
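The snippet above was left as a template placeholder. As a rough sketch (not taken from this PR), calibration-driven quantization goes through `mtq.quantize`, and with this change the collected amax and AWQ-Lite act_scale statistics are synchronized across data- and context-parallel ranks during calibration; `model`, `forward_loop`, and `calib_dataloader` below are placeholders:

```python
import modelopt.torch.quantization as mtq

# Sketch only: quantize a Megatron model that is already sharded with TP/CP/DP.
# During calibration, amax (and AWQ-Lite act_scale) are synced across the
# data-parallel and context-parallel groups touched by this PR.
def forward_loop(model):
    for batch in calib_dataloader:  # placeholder calibration dataloader
        model(batch)

model = mtq.quantize(model, mtq.INT4_AWQ_CFG, forward_loop)
```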

Testing

  • Tests for DP, CP, and combined DP/TP/CP
  • Tests for AWQ-Lite act_scale

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Parallel-aware quantization calibration: DP activation-scale synchronization with NaN-safe disable and AWQ‑Lite DP sync.
  • Documentation

    • README updated: HuggingFace token export, adjusted clone/mount steps, Docker example uses NeMo v25.09, and QAT usage notes.
  • Bug Fixes

    • Safer parallel-group initialization with warning/fallback behavior to improve robustness.
  • Style

    • Parallel state string representation moved to multi-line for clearer output.
  • Tests

    • Expanded DP/TP/CP coverage, unified test helpers, AWQ‑Lite mocking, seeded inputs, larger test models, and new need_8_gpus fixture.
  • Performance

    • Reduced data-loading parallelism in example flow (num_workers 1).

@copy-pr-bot

copy-pr-bot bot commented Sep 24, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Sep 24, 2025

Walkthrough

Adds DP-aware activation-scale and amax synchronization with local NaN detection and DP-wide NaN propagation/disable during calibration (including AWQ‑Lite); makes Megatron plugin request a CP-aware data-parallel group with fallback+warning; tweaks ParallelState repr; expands tests/fixtures for DP/TP/CP (adds 8‑GPU fixture); updates example README/docker run.
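To make the synchronization pattern concrete, here is a rough sketch of the idea (function names, guards, and the NaN-disable flag are illustrative, not the PR's exact code):

```python
import torch
import torch.distributed as dist

def sync_amax_across_dp_cp(quantizer, dp_group, cp_group):
    """Sketch: propagate the max amax across DP and CP replicas."""
    amax = getattr(quantizer, "amax", None)
    if amax is None or not dist.is_initialized():
        return
    reduced = amax.clone()
    for group in (dp_group, cp_group):
        if group is not None:
            dist.all_reduce(reduced, op=dist.ReduceOp.MAX, group=group)
    quantizer.amax = reduced

def disable_awq_lite_if_any_rank_saw_nan(module, dp_group):
    """Sketch: agree DP-wide on whether any rank hit NaN, then disable AWQ-Lite."""
    flag = torch.tensor(
        float(torch.isnan(module.awq_lite.act_scale).any()),
        device=module.awq_lite.act_scale.device,
    )
    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=dp_group)
    if flag.item() > 0:
        module.awq_lite.is_enabled = False  # hypothetical flag name
```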

Changes

Cohort / File(s) Summary
Quantization calibration CP/DP sync
modelopt/torch/quantization/model_calib.py
Adds sync_act_scale_across_dp helper and integrates DP-aware amax/act_scale synchronization into max_calibrate and AWQ‑Lite flows; implements local NaN detection, DP-wide NaN propagation, and module-disable path; adds inline documentation.
Megatron plugin — parallel group retrieval & logging
modelopt/torch/quantization/plugins/megatron.py
Imports logging and get_data_parallel_group, creates a logger, tries get_data_parallel_group(with_context_parallel=True) with AssertionError handling and fallback to non-context call, logs a warning on fallback, and passes explicit data_parallel_group into ParallelState.
ParallelState repr formatting
modelopt/torch/utils/distributed.py
Changes ParallelState.__repr__ to multi-line formatting that includes data_parallel_group and tensor_parallel_group (no behavior/API change).
Quantization test helpers & AWQ patching
tests/_test_utils/torch_quantization/quantize_common.py
Adds centralized _distributed_attr_check, introduces _debug_awq_lite and patching support for AWQ‑Lite, imports SequentialQuantizer, and replaces TP-only helper with data_tensor_context_parallel_test_helper supporting DP/TP/CP and patch injection.
Megatron test utilities: CP wiring & seeded inputs
tests/_test_utils/torch_dist/plugins/megatron_common.py
Extends MegatronModel.__init__ to accept cp_size and tp_group, wires context_parallel_size into initialization, passes tp_group into parallel linear layers, and adds optional seed to get_dummy_input for deterministic inputs.
Megatron quantization tests: DP/TP/CP suites
tests/gpu/torch/quantization/plugins/test_megatron.py
Replaces tensor_parallel_test_helper with data_tensor_context_parallel_test_helper; adds unified _test_parallelism_helper covering TP/DP/CP combinations; updates MegatronModel instantiation to include tp_group/cp_size; increases model attention heads in tests.
Apex tests: seeded inputs & helper rename
tests/gpu/torch/quantization/plugins/test_apex.py
Replaces tensor_parallel_test_helper usage with data_tensor_context_parallel_test_helper; updates ApexModel.get_dummy_input to accept an optional seed for deterministic inputs.
GPU test fixtures
tests/gpu/torch/conftest.py
Adds need_8_gpus pytest fixture that skips tests when fewer than 8 GPUs are available.
Examples / README runbook
examples/nemo_run/qat/README.md
Updates run instructions to reference cloning TensorRT‑Model‑Optimizer first, sets a HuggingFace token export, updates NeMo container tag to 25.09, adjusts docker mount paths, and appends --tensor_parallelism 4 to the QAT invocation.
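For the need_8_gpus fixture mentioned in the test-fixtures row above, a minimal sketch mirroring the existing need_2_gpus guard (the skip message is an assumption):

```python
import pytest
import torch

@pytest.fixture
def need_8_gpus():
    # Skip multi-rank DP/TP/CP tests on machines with fewer than 8 GPUs.
    if torch.cuda.device_count() < 8:
        pytest.skip("Need at least 8 GPUs to run this test")
```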

Sequence Diagram(s)

```mermaid
sequenceDiagram
  autonumber
  participant Calib as Calibration Loop
  participant Mod as Quantized Module
  participant Q as Quantizer
  participant DP as Data-Parallel Group
  Note right of DP #D3E4CD: DP sync & NaN propagation
  Calib->>Mod: run forward / collect stats
  Mod->>Q: compute local amax / act_scale
  alt local NaN detected
    Mod->>DP: all-reduce NaN flag
    DP-->>Mod: consensus NaN (any rank true)
    Mod-->>Calib: mark module disabled due to NaN
  else
    Mod->>DP: all-reduce amax/act_scale (DP sync)
    Q-->>Mod: updated synchronized scales
    Mod-->>Calib: continue calibration
  end
```
```mermaid
sequenceDiagram
  autonumber
  participant Init as Megatron Plugin Init
  participant MPG as megatron.core.parallel_state
  participant PS as ParallelState
  Note right of MPG #F0F4FF: try CP-aware DP group, fallback if uninitialized
  Init->>MPG: get_data_parallel_group(with_context_parallel=true)
  alt context-parallel available
    MPG-->>Init: dp_group (CP-aware)
  else fallback
    Init->>MPG: get_data_parallel_group()
    MPG-->>Init: dp_group (fallback)
    Init->>Init: log warning about fallback
  end
  Init->>MPG: get_tensor_model_parallel_group()
  Init->>PS: __init__(data_parallel_group=dp_group, tensor_parallel_group=tp_group, ...)
  PS-->>Init: ParallelState ready
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I nibble at scales and hop through the code,
syncing amax across ranks on my road.
Megatron finds groups, seeds tests to run,
docker mounts and tokens set — work done.
A quantized carrot crunches — hop, celebrate, friend 🐇✨

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The provided title clearly and concisely summarizes the primary change, which is the synchronization of quantizer amax and AWQ-Lite act_scale across context and data parallel groups, and it directly reflects the main objective of the pull request without unnecessary detail.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 80.00%, which is sufficient; the required threshold is 80.00%.

@codecov

codecov bot commented Sep 24, 2025

Codecov Report

❌ Patch coverage is 88.88889% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 73.36%. Comparing base (5b02483) to head (afe6f34).
⚠️ Report is 5 commits behind head on main.

Files with missing lines:
  • modelopt/torch/quantization/model_calib.py: patch 88.88%, 1 line missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #359      +/-   ##
==========================================
- Coverage   73.36%   73.36%   -0.01%     
==========================================
  Files         180      180              
  Lines       17919    17925       +6     
==========================================
+ Hits        13147    13151       +4     
- Misses       4772     4774       +2     

☔ View full report in Codecov by Sentry.

@realAsma
Contributor

@jenchen13 could you please add unit tests for context parallel quantization (similar to tensor parallel) to here -

def test_tensor_parallel(need_2_gpus, config):

basically the TP test checks whether amax is similar across the TP group. see

def tensor_parallel_test_helper(model, config, tp_group, dp_group):
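For context, the check referenced here boils down to an all-reduce-then-compare pattern; roughly (assuming the process group is already initialized):

```python
import torch
import torch.distributed as dist

def assert_amax_synced(quantizer, group):
    # If amax is already identical on every rank, a MAX all-reduce is a no-op,
    # so the reduced copy must match the local value.
    amax = quantizer.amax.clone()
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    assert torch.allclose(amax, quantizer.amax)
```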

@copy-pr-bot

copy-pr-bot bot commented Sep 24, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
@jenchen13 jenchen13 force-pushed the jennifchen/cp_amax_sync branch from e764e79 to 42519cc on September 25, 2025 18:57
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
@jenchen13 jenchen13 force-pushed the jennifchen/cp_amax_sync branch from aa5b8fd to 264adbb on September 25, 2025 21:46
@jenchen13 jenchen13 changed the title sync amax in context parallel Sync amax & AWQ-Lite in context parallel/data parallel Sep 25, 2025
@jenchen13 jenchen13 changed the title Sync amax & AWQ-Lite in context parallel/data parallel Sync amax & AWQ-Lite act_scale in context parallel/data parallel Sep 25, 2025
@jenchen13 jenchen13 marked this pull request as ready for review September 25, 2025 23:23
@jenchen13 jenchen13 requested review from a team as code owners September 25, 2025 23:23
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/_test_utils/torch_quantization/quantize_common.py (1)

132-149: Replace config-based guards with attribute-presence checks and sync AWQ pre_quant_scale
• For amax:

if model.fc2.input_quantizer.amax is not None:
    activation_amax = model.fc2.input_quantizer.amax.clone()
    dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group)
    assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax)

• For scales (SmoothQuant, AWQ/AWQ-Lite):

# input scale
if (scale := model.fc1.input_quantizer.pre_quant_scale) is not None:
    scale_clone = scale.clone()
    dist.all_reduce(scale_clone, op=dist.ReduceOp.MAX, group=tp_group)
    assert torch.allclose(scale_clone, scale)
# weight scale (AWQ-Lite)
if (wscale := model.fc1.weight_quantizer.pre_quant_scale) is not None:
    wscale_clone = wscale.clone()
    dist.all_reduce(wscale_clone, op=dist.ReduceOp.MAX, group=tp_group)
    assert torch.allclose(wscale_clone, wscale)

Drop all if config in […] checks.

🧹 Nitpick comments (2)
modelopt/torch/quantization/model_calib.py (1)

628-651: Weight DP/CP averages by token counts

Right now we average act_scale equally across ranks. In mixed-workload runs (e.g., MoE routing) we can see uneven num_tokens, so the lighter ranks end up pulling the mean down. Since we already track num_tokens, we can switch to a weighted reduction (sum of scale * tokens and sum of tokens) before normalizing.

A sketch:

-    module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
-    sync_act_scale_across_dp_cp(...)
+    module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
+    token_count = torch.tensor(
+        [module.awq_lite.num_tokens],
+        device=module.awq_lite.act_scale.device,
+        dtype=module.awq_lite.act_scale.dtype,
+    )
+    scale_sum = module.awq_lite.act_scale * token_count.item()
+    sync_reduction_across_dp_cp(scale_sum, token_count, module.parallel_state)
+    module.awq_lite.act_scale = scale_sum / token_count

(sync_reduction_across_dp_cp would all-reduce both tensors across DP/CP groups.)

tests/_test_utils/torch_quantization/quantize_common.py (1)

215-237: 3D helper: sync input pre_quant_scale across TP/CP/DP

  • Replace the config-based presence check with an attribute check for input amax (e.g. if getattr(model.fc1.input_quantizer, "amax", None) is not None) in tests/_test_utils/torch_quantization/quantize_common.py → data_tensor_context_parallel_test_helper.
  • If input_quantizer.pre_quant_scale is present, clone it, all_reduce (MAX) across tp_group, cp_group, dp_group and assert torch.allclose(reduced, input_quantizer.pre_quant_scale).
  • AWQ-Lite’s activation scale is stored on module.awq_lite.act_scale and model_calib syncs it using AVG across DP/CP — if you want to validate AWQ-Lite end-to-end, also check module.awq_lite.act_scale is synchronized (use the same group ops as model_calib).
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b4d6ced and 7cbe5b9.

📒 Files selected for processing (8)
  • examples/nemo_run/qat/README.md (1 hunks)
  • modelopt/torch/quantization/model_calib.py (2 hunks)
  • modelopt/torch/quantization/plugins/megatron.py (2 hunks)
  • modelopt/torch/utils/distributed.py (1 hunks)
  • tests/_test_utils/torch_dist/plugins/megatron_common.py (2 hunks)
  • tests/_test_utils/torch_quantization/quantize_common.py (2 hunks)
  • tests/gpu/torch/conftest.py (1 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/torch/quantization/plugins/megatron.py (2)
modelopt/torch/opt/dynamic.py (2)
  • parallel_state (876-878)
  • parallel_state (881-886)
modelopt/torch/utils/distributed.py (1)
  • ParallelState (232-256)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_quantization/quantize_common.py (4)
  • context_parallel_test_helper (179-202)
  • data_parallel_test_helper (153-176)
  • data_tensor_context_parallel_test_helper (205-237)
  • tensor_parallel_test_helper (119-150)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
  • initialize_for_megatron (386-398)
  • MegatronModel (85-131)
tests/gpu/torch/conftest.py (2)
  • need_2_gpus (32-34)
  • need_8_gpus (38-40)
tests/_test_utils/torch_quantization/quantize_common.py (2)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
  • get_dummy_input (130-131)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • amax (231-236)
  • amax (239-250)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
🔇 Additional comments (7)
modelopt/torch/quantization/model_calib.py (1)

82-98: Nice DP/CP amax sync integration

The recursive helper cleanly reuses the existing SequentialQuantizer handling while adding the CP hop, so both DP and CP replicas end up aligned. 👍

tests/gpu/torch/conftest.py (1)

37-40: need_8_gpus fixture looks good

The skip guard mirrors need_2_gpus, so multi-rank tests will short-circuit cleanly when the hardware isn’t there.

tests/gpu/torch/quantization/plugins/test_megatron.py (1)

176-205: Great coverage for the 2×2×2 scenario

Spinning up the combined DP/TP/CP path ensures the new sync logic is exercised end-to-end; thanks for wiring the groups explicitly.

tests/_test_utils/torch_quantization/quantize_common.py (2)

129-131: Good addition: post-quantization sanity forward.

Running a forward pass after quantize helps catch latent issues. Looks good.


205-212: LGTM on the sequential TP→CP→DP all-reduce pattern.

This correctly propagates maxima across orthogonal groups in 3D parallelism.

Please confirm group construction matches orthogonal decomposition (i.e., each rank belongs to exactly one group per dimension). If not, propagation may be incomplete.

examples/nemo_run/qat/README.md (1)

95-95: Doc command update looks good.

Nice to see the example reflecting the new tensor-parallel setup.

modelopt/torch/utils/distributed.py (1)

244-256: Context group wiring looks solid.

ParallelState now mirrors DP/TP/CP consistently, so downstream logging/debugging will show the full layout.

Comment on lines 221 to 230
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
data_parallel_group = get_data_parallel_group()
self.parallel_state = ParallelState(
getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
mcore_parallel.get_context_parallel_group(),
)
Contributor

⚠️ Potential issue | 🟠 Major

Guard get_context_parallel_group() when CP is disabled

get_context_parallel_group() asserts that context parallelism was initialized. When we run TP/DP-only (the default in plenty of setups), that assertion fires and _MegatronParallelLinear._setup() will crash. Please mirror the DP guard and fall back to -1 (unused) when the call raises.

Something along these lines keeps the DP-only path working:

-        self.parallel_state = ParallelState(
-            data_parallel_group,
-            mcore_parallel.get_tensor_model_parallel_group(),
-            mcore_parallel.get_context_parallel_group(),
-        )
+        try:
+            context_parallel_group = mcore_parallel.get_context_parallel_group()
+        except AssertionError:
+            context_parallel_group = -1
+        self.parallel_state = ParallelState(
+            data_parallel_group,
+            mcore_parallel.get_tensor_model_parallel_group(),
+            context_parallel_group,
+        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
data_parallel_group = get_data_parallel_group()
self.parallel_state = ParallelState(
getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
mcore_parallel.get_context_parallel_group(),
)
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
data_parallel_group = get_data_parallel_group()
try:
context_parallel_group = mcore_parallel.get_context_parallel_group()
except AssertionError:
context_parallel_group = -1
self.parallel_state = ParallelState(
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
context_parallel_group,
)
🤖 Prompt for AI Agents
In modelopt/torch/quantization/plugins/megatron.py around lines 221 to 230, the
call to mcore_parallel.get_context_parallel_group() is unguarded and will assert
(and crash) when context-parallelism is disabled; mirror the data-parallel
guard: try to call get_context_parallel_group() and if it raises
(AssertionError) set the context group to -1 (or the sentinel used for
"unused"), then pass that value into ParallelState so TP/DP-only setups won't
fail. Ensure you only catch the assertion from the context-group call and keep
the existing fallback for get_data_parallel_group() unchanged.

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7cbe5b9 and 1f7d17e.

📒 Files selected for processing (2)
  • tests/_test_utils/torch_quantization/quantize_common.py (3 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/_test_utils/torch_quantization/quantize_common.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • SequentialQuantizer (1146-1254)
  • amax (231-236)
  • amax (239-250)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_quantization/quantize_common.py (3)
  • data_tensor_context_parallel_test_helper (185-217)
  • dp_cp_parallel_test_helper (154-182)
  • tensor_parallel_test_helper (120-151)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
  • initialize_for_megatron (386-398)
  • MegatronModel (85-131)
tests/gpu/torch/conftest.py (2)
  • need_2_gpus (32-34)
  • need_8_gpus (38-40)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (2)
tests/_test_utils/torch_quantization/quantize_common.py (2)

151-151: Remove process group teardown from helper

We’re still tearing down the global process group inside the helper, which the test harness owns. Subsequent jobs in the same worker will hit a destroyed NCCL group. Please drop this call.

-    dist.destroy_process_group()

162-182: Guard quantizer sync for AWQ-Lite and validate act_scale

AWQ-Lite quantizers set is_mx_format, so TensorQuantizer.amax returns None (see modelopt/torch/quantization/nn/modules/tensor_quantizer.py lines 230-235). The current helper unconditionally clones .amax, which will raise AttributeError as soon as we run the AWQ-Lite DP/CP test. Even when it doesn’t explode, we never exercise the newly synchronized act_scale, so regressions slide through. Please guard for missing amax and also check pre_quant_scale/act_scale when present.

-    def reduce_amax(quantizer):
-        amax = quantizer.amax.clone()
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
-        assert torch.allclose(amax, quantizer.amax)
+    def reduce_quant_attrs(quantizer):
+        validated = False
+        for attr in ("amax", "pre_quant_scale", "act_scale"):
+            tensor = getattr(quantizer, attr, None)
+            if tensor is None:
+                continue
+            synced = tensor.detach().clone()
+            dist.all_reduce(synced, op=dist.ReduceOp.MAX, group=group)
+            assert torch.allclose(synced, tensor), f"{attr} mismatch on {quantizer}"
+            validated = True
+        if not validated:
+            pytest.fail(f"{quantizer} exposed no distributed attribute to validate")
@@
-        reduce_amax(model.fc1.input_quantizer)
-        reduce_amax(model.fc2.input_quantizer)
+        reduce_quant_attrs(model.fc1.input_quantizer)
+        reduce_quant_attrs(model.fc2.input_quantizer)
@@
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
@@
-        reduce_amax(model.fc1.weight_quantizer)
+        reduce_quant_attrs(model.fc1.weight_quantizer)
@@
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
@@
-        reduce_amax(model.fc2.weight_quantizer)
+        reduce_quant_attrs(model.fc2.weight_quantizer)

Signed-off-by: Jenny Chen <jennifchen@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
modelopt/torch/quantization/model_calib.py (1)

602-612: Consider adding a docstring for consistency.

The act_scale synchronization logic is correct (using AVG for statistics is appropriate), but the function lacks a docstring.

Apply this diff to add a docstring:

 def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group):
+    """Synchronize the act_scale across all ranks in the data parallel and context parallel groups."""
     # Sync across Data Parallel (DP)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1f7d17e and 71a9f7a.

📒 Files selected for processing (2)
  • modelopt/torch/quantization/model_calib.py (2 hunks)
  • modelopt/torch/quantization/plugins/megatron.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/plugins/megatron.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/model_calib.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • SequentialQuantizer (1114-1222)
  • sync_amax_across_distributed_group (1071-1083)
  • TensorQuantizer (65-1111)
  • is_enabled (395-397)
modelopt/torch/utils/distributed.py (2)
  • is_initialized (49-51)
  • is_initialized (196-198)
modelopt/torch/quantization/utils.py (1)
  • is_quantized_linear (246-256)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (2)
modelopt/torch/quantization/model_calib.py (2)

83-91: LGTM! CP support properly added to amax synchronization.

The extension from DP-only to DP+CP synchronization is implemented correctly. The function handles both SequentialQuantizer recursion and direct TensorQuantizer sync, and the docstring accurately reflects the new behavior.


98-98: Correct update to use CP-aware synchronization.

The call site properly updated to invoke the new CP-aware sync function.

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 71a9f7a and d02365c.

📒 Files selected for processing (5)
  • modelopt/torch/quantization/model_calib.py (2 hunks)
  • tests/_test_utils/torch_dist/plugins/megatron_common.py (2 hunks)
  • tests/_test_utils/torch_quantization/quantize_common.py (3 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (4 hunks)
  • tests/gpu/torch/quantization/test_model_calib.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/_test_utils/torch_dist/plugins/megatron_common.py
🧰 Additional context used
🧬 Code graph analysis (4)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_quantization/quantize_common.py (3)
  • data_tensor_context_parallel_test_helper (185-221)
  • dp_cp_parallel_test_helper (154-182)
  • tensor_parallel_test_helper (120-151)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
  • initialize_for_megatron (386-401)
  • MegatronModel (85-131)
tests/gpu/torch/conftest.py (2)
  • need_2_gpus (32-34)
  • need_8_gpus (38-40)
tests/gpu/torch/quantization/test_model_calib.py (2)
tests/_test_utils/torch_dist/plugins/megatron_common.py (3)
  • MegatronModel (85-131)
  • initialize_for_megatron (386-401)
  • get_dummy_input (130-131)
modelopt/torch/quantization/model_calib.py (1)
  • awq_lite (424-677)
tests/_test_utils/torch_quantization/quantize_common.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • SequentialQuantizer (1114-1222)
  • amax (236-241)
  • amax (244-255)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
  • get_dummy_input (130-131)
modelopt/torch/quantization/model_calib.py (4)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • SequentialQuantizer (1114-1222)
  • sync_amax_across_distributed_group (1071-1083)
  • TensorQuantizer (65-1111)
  • is_enabled (395-397)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/quantization/nn/modules/quant_module.py (1)
  • QuantModule (37-96)
modelopt/torch/utils/distributed.py (2)
  • is_initialized (49-51)
  • is_initialized (196-198)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (13)
modelopt/torch/quantization/model_calib.py (2)

83-98: LGTM! Well-structured CP-aware amax synchronization.

The implementation correctly extends amax synchronization to context parallel groups while maintaining backward compatibility with data parallel. The recursive handling of SequentialQuantizer is appropriate, and using ReduceOp.MAX for both DP and CP groups ensures correct semantics (maximum amax across all ranks).


602-613: LGTM! Correct use of averaging for act_scale synchronization.

The function properly synchronizes act_scale using ReduceOp.AVG across both DP and CP groups, which is the correct reduction operation for averaging activation scales. The guard checks ensure synchronization only occurs when the groups are initialized.
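Based on that description, the helper plausibly looks something like the sketch below (group handling in ModelOpt's ParallelState is simplified here to plain process groups, and the guards may differ from the actual code):

```python
import torch.distributed as dist

def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group):
    """Sketch: average AWQ-Lite act_scale over the DP and CP groups when present."""
    for group in (data_parallel_group, context_parallel_group):
        if group is None:  # only reduce over groups that were actually initialized
            continue
        dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=group)
```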

tests/gpu/torch/quantization/test_model_calib.py (1)

32-33: LGTM! Test setup is correct.

The test properly uses spawn_multiprocess_job with 2 GPUs and NCCL backend for distributed testing.

tests/gpu/torch/quantization/plugins/test_megatron.py (5)

34-35: LGTM!

The new imports are necessary for the DP/CP test helpers and context-parallel group retrieval used in the tests below.

Also applies to: 45-45


101-103: LGTM!

The explicit tp_size keyword argument improves clarity, and removing dp_group from the tensor_parallel_test_helper call aligns with the updated signature in quantize_common.py.


124-130: Per-rank seed is overridden; test won't catch broken DP sync.

Passing SEED + rank to initialize_for_megatron is overridden by the internal call to model_parallel_cuda_manual_seed(seed) (see tests/_test_utils/torch_dist/plugins/megatron_common.py, lines 385-400), so all ranks still produce identical get_dummy_input() activations. The test will pass even if DP synchronization is broken. Introduce a rank-dependent perturbation after initialization—e.g., reseed or add a small offset before calling dp_cp_parallel_test_helper.

Based on past review comments.


148-156: Per-rank seed is overridden; test won't catch broken CP sync.

The same issue from the DP test applies here: initialize_for_megatron internally calls model_parallel_cuda_manual_seed(seed) with the provided seed, overriding the per-rank divergence you intended with SEED + rank. All CP ranks will produce identical calibration data, so the test won't fail if CP synchronization regresses. Add a rank-dependent perturbation after initialization.

Based on past review comments.


176-187: Fixed seed produces identical calibration data; test won't catch broken DP/CP sync.

Line 178 uses SEED without rank-dependent divergence. Since initialize_for_megatron calls model_parallel_cuda_manual_seed(SEED) uniformly across all 8 ranks, every rank will produce identical get_dummy_input() activations, so the assertions in data_tensor_context_parallel_test_helper will pass even if DP or CP synchronization is broken. Introduce rank-dependent perturbation (e.g., SEED + rank + 1) or add a small offset after initialization to ensure different calibration data per DP/CP rank.

Based on past review comments.
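One way to introduce that divergence, as these comments suggest (SEED, model, and the offset are placeholders; the real fix may live elsewhere in the test):

```python
import torch.distributed as dist

# Hypothetical sketch: seed dummy inputs with the rank so DP/CP replicas
# calibrate on different data and a missing sync surfaces as an amax mismatch.
rank = dist.get_rank()
calib_input = model.get_dummy_input(seed=SEED + rank + 1).cuda()
```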

tests/_test_utils/torch_quantization/quantize_common.py (5)

26-26: LGTM!

The SequentialQuantizer import is necessary for the new helpers to handle multi-format weight quantization correctly.


120-120: LGTM!

Removing the unused dp_group parameter simplifies the signature; the function only validates tensor-parallel synchronization.


151-151: Remove global process group destruction from helper.

This call unconditionally tears down the global process group, breaking any subsequent DP/CP/TP tests that run in the same process. The test harness owns the process group lifecycle. Remove this line.

Based on past review comments.

Apply this diff:

-    dist.destroy_process_group()

154-183: Guard amax access and add AWQ-Lite scale validation.

Line 163 unconditionally accesses quantizer.amax, which returns None for MX formats (see modelopt/torch/quantization/nn/modules/tensor_quantizer.py, line 237) and will crash on .clone(). The config equality check at line 168 is brittle and misses future configs. Additionally, the PR objective includes synchronizing AWQ-Lite act_scale, but this helper doesn't validate it.

Based on past review comments.

Replace config checks with attribute guards and add scale validation:

-    def reduce_amax(quantizer):
-        amax = quantizer.amax.clone()
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
-        assert torch.allclose(amax, quantizer.amax)
+    def reduce_quant_attrs(quantizer):
+        for attr in ("amax", "pre_quant_scale", "act_scale"):
+            tensor = getattr(quantizer, attr, None)
+            if tensor is not None:
+                synced = tensor.clone()
+                dist.all_reduce(synced, op=dist.ReduceOp.MAX, group=group)
+                assert torch.allclose(synced, tensor), f"{attr} mismatch"
 
     # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        reduce_amax(model.fc1.input_quantizer)
-        reduce_amax(model.fc2.input_quantizer)
+    if hasattr(model.fc1, "input_quantizer"):
+        reduce_quant_attrs(model.fc1.input_quantizer)
+    if hasattr(model.fc2, "input_quantizer"):
+        reduce_quant_attrs(model.fc2.input_quantizer)
 
     # Weight quantizer amax
     if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc1.weight_quantizer:
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
     else:
-        reduce_amax(model.fc1.weight_quantizer)
+        reduce_quant_attrs(model.fc1.weight_quantizer)
     if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc2.weight_quantizer:
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
     else:
-        reduce_amax(model.fc2.weight_quantizer)
+        reduce_quant_attrs(model.fc2.weight_quantizer)

185-222: Guard amax access, add AWQ-Lite scale validation, and remove debug prints.

Lines 196, 197-198, 202-203 have the same issues as dp_cp_parallel_test_helper: unconditional .amax.clone() crashes for MX formats, config equality checks are brittle, and AWQ-Lite act_scale/pre_quant_scale aren't validated. Additionally, the print statements at lines 197-198 and 202-203 should be removed or converted to logging for production.

Based on past review comments.

Apply the same attribute-guard pattern from dp_cp_parallel_test_helper, replicate for all three groups (dp_group, cp_group, tp_group), and remove print statements:

-    def reduce_amax(quantizer):
-        amax = quantizer.amax.clone()
-        print("amax before reduce", amax)
-        print("quantizer.amax before reduce", quantizer.amax)
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group)
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group)
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group)
-        print("amax after reduce", amax)
-        print("quantizer.amax after reduce", quantizer.amax)
-        assert torch.allclose(amax, quantizer.amax)
+    def reduce_quant_attrs(quantizer):
+        for attr in ("amax", "pre_quant_scale", "act_scale"):
+            tensor = getattr(quantizer, attr, None)
+            if tensor is not None:
+                synced = tensor.clone()
+                for g in (dp_group, cp_group, tp_group):
+                    dist.all_reduce(synced, op=dist.ReduceOp.MAX, group=g)
+                assert torch.allclose(synced, tensor), f"{attr} mismatch"
 
     # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        reduce_amax(model.fc1.input_quantizer)
-        reduce_amax(model.fc2.input_quantizer)
+    if hasattr(model.fc1, "input_quantizer"):
+        reduce_quant_attrs(model.fc1.input_quantizer)
+    if hasattr(model.fc2, "input_quantizer"):
+        reduce_quant_attrs(model.fc2.input_quantizer)
 
     if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc1.weight_quantizer:
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
     else:
-        reduce_amax(model.fc1.weight_quantizer)
+        reduce_quant_attrs(model.fc1.weight_quantizer)
 
     if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc2.weight_quantizer:
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
     else:
-        reduce_amax(model.fc2.weight_quantizer)
+        reduce_quant_attrs(model.fc2.weight_quantizer)

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Contributor

Looks like we don't need separate methods tensor_parallel_test_helper, dp_cp_parallel_test_helper and data_tensor_context_parallel_test_helper for testing out all the combinations. Can we merge them into one and do data_tensor_context_parallel_test_helper(..., tp_group=None, dp_group=None, cp_group=None)?

Contributor

Same as https://github.com/NVIDIA/TensorRT-Model-Optimizer/pull/359/files#r2410523895
Can we combine the DP, TP, and CP tests by parameterizing them?

Contributor

@realAsma realAsma left a comment

Looks very good!

I left a few more comments on consolidating the unit tests. Let's try to consolidate the tests as much as we can to reduce code maintenance.

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3f857a3 and 93bfd52.

📒 Files selected for processing (4)
  • modelopt/torch/quantization/model_calib.py (3 hunks)
  • tests/_test_utils/torch_dist/plugins/megatron_common.py (5 hunks)
  • tests/_test_utils/torch_quantization/quantize_common.py (3 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/model_calib.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_quantization/quantize_common.py (1)
  • data_tensor_context_parallel_test_helper (181-248)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
  • initialize_for_megatron (394-409)
  • MegatronModel (85-139)
tests/gpu/torch/conftest.py (2)
  • need_2_gpus (32-34)
  • need_8_gpus (38-40)
tests/_test_utils/torch_quantization/quantize_common.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • SequentialQuantizer (1114-1222)
modelopt/torch/quantization/model_calib.py (1)
  • awq_lite (424-680)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
  • get_dummy_input (134-139)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality

Comment on lines 122 to 127
def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=[]):
for group in groups:
if group is not None:
quantizer_attr = getattr(quantizer, attr).clone()
dist.all_reduce(quantizer_attr, op=op, group=group)
assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
Contributor

⚠️ Potential issue | 🟠 Major

Fix distributed attribute check to cover all groups

_distributed_attr_check clones and reduces for every group, but the sole assertion after the loop only evaluates the final non‑None group. When we pass [dp_group, tp_group] this lets a DP regression slip through silently because the reduced tensor from the first group is discarded. Please assert inside the loop (or otherwise accumulate checks) so every provided group is actually validated.

-def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=[]):
-    for group in groups:
-        if group is not None:
-            quantizer_attr = getattr(quantizer, attr).clone()
-            dist.all_reduce(quantizer_attr, op=op, group=group)
-    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
+def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=()):
+    attr_value = getattr(quantizer, attr)
+    checked = False
+    for group in groups:
+        if group is None:
+            continue
+        reduced = attr_value.clone()
+        dist.all_reduce(reduced, op=op, group=group)
+        assert torch.allclose(reduced, attr_value)
+        checked = True
+    assert checked, "expected at least one distributed group to validate"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=[]):
for group in groups:
if group is not None:
quantizer_attr = getattr(quantizer, attr).clone()
dist.all_reduce(quantizer_attr, op=op, group=group)
assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=()):
attr_value = getattr(quantizer, attr)
checked = False
for group in groups:
if group is None:
continue
reduced = attr_value.clone()
dist.all_reduce(reduced, op=op, group=group)
assert torch.allclose(reduced, attr_value)
checked = True
assert checked, "expected at least one distributed group to validate"
🤖 Prompt for AI Agents
In tests/_test_utils/torch_quantization/quantize_common.py around lines 122-127,
the distributed attribute check clones and reduces for each group but only
asserts once after the loop, so only the last group's reduction is validated;
modify the function to perform the assertion inside the loop (i.e., after
dist.all_reduce for each non-None group, compare the reduced tensor to the
original attribute and assert they match) or alternatively accumulate/compare
reductions for every group before exiting so every provided group is actually
validated.

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (2)
tests/_test_utils/torch_quantization/quantize_common.py (2)

122-127: Assert per group in _distributed_attr_check

We still only compare the reduction result from the final non-None group against the local tensor, so any earlier group (e.g., DP) can drift without being caught. Please assert after each all_reduce (and switch the mutable default) so every provided group is validated.

-def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=[]):
-    for group in groups:
-        if group is not None:
-            quantizer_attr = getattr(quantizer, attr).clone()
-            dist.all_reduce(quantizer_attr, op=op, group=group)
-    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
+def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=()):
+    attr_value = getattr(quantizer, attr)
+    for group in groups:
+        if group is None:
+            continue
+        reduced = attr_value.clone()
+        dist.all_reduce(reduced, op=op, group=group)
+        assert torch.allclose(reduced, attr_value)

203-206: Also validate fc2.awq_lite.act_scale

Row-parallel AWQ-Lite sync can regress silently because we never touch fc2.awq_lite.act_scale. Mirror the fc1 check (guarded with getattr if needed) so both halves of the layer are exercised.

     if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
         _distributed_attr_check(
             model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, groups=[dp_group, tp_group]
         )
+        fc2_awq = getattr(model.fc2, "awq_lite", None)
+        if fc2_awq is not None:
+            _distributed_attr_check(
+                fc2_awq, "act_scale", dist.ReduceOp.AVG, groups=[dp_group, tp_group]
+            )
🧹 Nitpick comments (1)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)

375-378: Limit inference debug prints to a single rank

These prints fire on every rank for every test run, flooding logs. Gate them behind a rank check (or drop them) so only rank 0 emits the shapes.
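If the prints are kept, a simple rank guard along these lines would quiet the other ranks (sketch; `inference_input` comes from the surrounding test helper):

```python
import torch.distributed as dist

# Only rank 0 reports tensor shapes, so multi-rank test logs stay readable.
if not dist.is_initialized() or dist.get_rank() == 0:
    print("inference_input size", inference_input["tokens"].shape)
```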

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 93bfd52 and 6761109.

📒 Files selected for processing (3)
  • tests/_test_utils/torch_dist/plugins/megatron_common.py (6 hunks)
  • tests/_test_utils/torch_quantization/quantize_common.py (3 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/_test_utils/torch_quantization/quantize_common.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • SequentialQuantizer (1114-1222)
modelopt/torch/quantization/model_calib.py (1)
  • awq_lite (424-680)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
  • get_dummy_input (134-139)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_quantization/quantize_common.py (1)
  • data_tensor_context_parallel_test_helper (139-206)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
  • initialize_for_megatron (396-411)
  • MegatronModel (85-139)
tests/gpu/torch/conftest.py (2)
  • need_2_gpus (32-34)
  • need_8_gpus (38-40)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
tests/gpu/torch/quantization/plugins/test_megatron.py (2)

100-100: Clarify the purpose of use_rank_in_seed.

The parameter use_rank_in_seed controls whether the global seed passed to initialize_for_megatron includes the rank offset. However, the actual calibration data divergence across DP/CP ranks is already handled by data_tensor_context_parallel_test_helper seeding get_dummy_input with dp_rank (see quantize_common.py line 143). Consider documenting why this parameter affects initialize_for_megatron's seed—perhaps it ensures model initialization RNG states differ across ranks, which could matter for edge cases. If it's not needed, consider removing it to reduce confusion.


159-224: Consider parameterizing tests to reduce duplication.

As noted by @realAsma, the four test functions (test_tensor_parallel, test_data_parallel, test_context_parallel, test_data_tensor_context_parallel) share nearly identical structure. You could combine them into a single parameterized test:

@pytest.mark.parametrize(
    "size,tp_size,cp_size,use_rank_in_seed,fixture",
    [
        (2, 2, 1, False, "need_2_gpus"),  # TP
        (2, 1, 1, True, "need_2_gpus"),   # DP
        (2, 1, 2, True, "need_2_gpus"),   # CP
        (8, 2, 2, True, "need_8_gpus"),   # DP+TP+CP
    ],
)
@pytest.mark.parametrize("config", [...])
def test_parallelism(request, size, tp_size, cp_size, use_rank_in_seed, fixture, config):
    request.getfixturevalue(fixture)
    spawn_multiprocess_job(
        size=size,
        job=partial(
            _test_parallelism_helper,
            config,
            tensor_model_parallel_size=tp_size,
            context_parallel_size=cp_size,
            use_rank_in_seed=use_rank_in_seed,
        ),
        backend="nccl",
    )

This would improve maintainability by centralizing the test logic.

Based on learnings

tests/_test_utils/torch_dist/plugins/megatron_common.py (1)

382-384: Consider replacing print with logging.

Debug statements using print work for test utilities, but migrating to logging.debug() would allow better control over verbosity in CI/local runs.

Apply this change if you want to use logging:

+import logging
+
+logger = logging.getLogger(__name__)
...
-    print("inference_input size", inference_input["tokens"].shape)
+    logger.debug("inference_input size %s", inference_input["tokens"].shape)
     logits = wrapped_model.run_one_forward_step(inference_input)
-    print("logits size", logits.shape)
+    logger.debug("logits size %s", logits.shape)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6761109 and 291cfa3.

📒 Files selected for processing (2)
  • tests/_test_utils/torch_dist/plugins/megatron_common.py (6 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_quantization/quantize_common.py (1)
  • data_tensor_context_parallel_test_helper (139-206)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
  • initialize_for_megatron (403-418)
  • MegatronModel (86-140)
tests/gpu/torch/conftest.py (2)
  • need_2_gpus (32-34)
  • need_8_gpus (38-40)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (3)
tests/gpu/torch/quantization/plugins/test_megatron.py (1)

93-136: LGTM: Unified helper consolidates TP/DP/CP test logic.

The refactor to a single _test_parallelism_helper reduces duplication and correctly wires parallelism groups. The logic properly:

  • Determines tp_group and dp_group based on parallelism config
  • Passes them to data_tensor_context_parallel_test_helper for validation
  • Uses rank-based seeding when use_rank_in_seed=True for DP/CP scenarios
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)

87-140: LGTM: MegatronModel extensions support CP and seeded inputs.

The additions properly support context parallelism and enable per-rank calibration data divergence:

  • cp_size and tp_group parameters thread through to TransformerConfig and parallel layers
  • get_dummy_input(seed) allows deterministic, per-rank data generation when seed is provided
  • Backward compatibility maintained by defaulting new parameters

403-418: LGTM: Parameter order maintains backward compatibility.

The parameter order (tensor_model_parallel_size, pipeline_model_parallel_size, seed, context_parallel_size) ensures existing positional calls passing seed continue to work, addressing the concern raised in previous reviews.

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)

401-416: LGTM! Backward compatibility correctly restored.

The parameter ordering now correctly places seed in the third position, maintaining backward compatibility with existing code that passes it positionally. The new context_parallel_size parameter is appropriately placed at the end and correctly passed to initialize_model_parallel.

Consider updating the docstring to document the new context_parallel_size parameter:

 def initialize_for_megatron(
     tensor_model_parallel_size=1,
     pipeline_model_parallel_size=1,
     seed=1234,
     context_parallel_size=1,
 ):
     """Initialize Megatron model parallelism.
+
+    Args:
+        tensor_model_parallel_size: Size of tensor model parallel group.
+        pipeline_model_parallel_size: Size of pipeline model parallel group.
+        seed: Random seed for model parallel initialization.
+        context_parallel_size: Size of context parallel group.
 
     NOTE: If used in a non-spawned process, make sure to call `megatron.core.parallel_state.destroy_model_parallel()`.
     """
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 291cfa3 and a106dd9.

📒 Files selected for processing (3)
  • examples/nemo_run/qat/README.md (2 hunks)
  • tests/_test_utils/torch_dist/plugins/megatron_common.py (5 hunks)
  • tests/gpu/torch/quantization/plugins/test_apex.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/nemo_run/qat/README.md
🧰 Additional context used
🧬 Code graph analysis (2)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
tests/gpu/torch/quantization/plugins/test_apex.py (1)
  • get_dummy_input (61-66)
tests/gpu/torch/quantization/plugins/test_apex.py (2)
tests/_test_utils/torch_quantization/quantize_common.py (1)
  • data_tensor_context_parallel_test_helper (139-206)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
  • get_dummy_input (135-140)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (5)
tests/_test_utils/torch_dist/plugins/megatron_common.py (3)

87-92: LGTM! Context parallel and tensor parallel group parameters added correctly.

The additions of cp_size and tp_group parameters enable fine-grained control over parallel configurations for testing. Default values maintain backward compatibility for existing callers.


110-110: LGTM! Tensor parallel group correctly propagated to parallel layers.

Passing tp_group to both ColumnParallelLinear and RowParallelLinear enables explicit control over the tensor parallel group, which is essential for testing scenarios with multiple parallel groups.

Also applies to: 125-125


135-140: LGTM! Deterministic input generation implemented correctly.

The optional seed parameter enables reproducible test inputs when needed, while maintaining non-seeded behavior by default. The implementation correctly uses torch.Generator with manual_seed and follows the established pattern from test_apex.py.
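
A minimal sketch of the pattern described above, assuming an illustrative tensor shape (the real helper lives in megatron_common.py and matches its test model's input):

```python
import torch

def get_dummy_input(seed: int | None = None) -> torch.Tensor:
    """Return a dummy activation batch; seeding makes it reproducible per DP rank."""
    generator = None
    if seed is not None:
        # Use a dedicated generator so seeding does not perturb global RNG state.
        generator = torch.Generator()
        generator.manual_seed(seed)
    # Shape is illustrative only.
    return torch.randn(2, 4, 32, generator=generator)
```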

tests/gpu/torch/quantization/plugins/test_apex.py (2)

26-26: LGTM!

The import name change from tensor_parallel_test_helper to data_tensor_context_parallel_test_helper correctly reflects the updated test utility that now supports combined DP/TP/CP testing.


61-66: LGTM!

The seeded input generation implementation is correct and matches the pattern used in the Megatron common test utilities (see tests/_test_utils/torch_dist/plugins/megatron_common.py lines 134-139). This enables deterministic data generation per DP rank for proper distributed testing.

@jenchen13 jenchen13 enabled auto-merge (squash) October 9, 2025 20:50
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 50000dd and 440ca48.

📒 Files selected for processing (1)
  • examples/nemo_run/qat/README.md (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: wait-checks / wait
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs

Comment on lines 66 to 70
Run docker command:

```diff
-docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash
+docker run -v /home/user/:/home/user/ -v /home/user/TensorRT-Model-Optimizer/:/opt/TensorRT-Model-Optimizer/ --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.09 bash
```
Contributor


⚠️ Potential issue | 🟠 Major

Pass HF_TOKEN into the container

You tell users to export HF_TOKEN, but the subsequent docker run does not propagate that variable, so the container still can’t access the token and downloads will fail. Add -e HF_TOKEN=$HF_TOKEN (or equivalent) to the command.

🤖 Prompt for AI Agents
In examples/nemo_run/qat/README.md around lines 66 to 70, the docker run command
does not propagate the HF_TOKEN environment variable despite instructing users
to export it; update the docker run invocation to pass the token into the
container (for example by adding -e HF_TOKEN=$HF_TOKEN or an equivalent --env
specification) so the container can access the token for downloads.
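
A possible fix along those lines (only the `-e` flag is new relative to the README's command; the token value is whatever the user exported on the host):

```bash
# Forward the exported HF_TOKEN into the container so Hugging Face downloads can authenticate.
docker run -e HF_TOKEN=$HF_TOKEN \
  -v /home/user/:/home/user/ \
  -v /home/user/TensorRT-Model-Optimizer/:/opt/TensorRT-Model-Optimizer/ \
  --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.09 bash
```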

```python
        mtq.NVFP4_DEFAULT_CFG,
    ],
)
def test_data_parallel(need_2_gpus, config):
```
Collaborator


For data parallel and context parallel, do we really need to test all configs? Or is testing one config sufficient, given that we have extensive tensor parallel tests?
Thoughts @realAsma @jenchen13

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
```diff
             module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
-            if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
+            has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
```
Contributor


Good catch on this.

Can we use

`DistributedProcessGroup.get_dist_syncd_obj(has_nan_local, module.parallel_state.data_parallel_group, lambda objs: any(objs))`?

Comment on lines 621 to 630
```python
            )
            has_nan = torch.tensor(int(has_nan_local), device=module.awq_lite.act_scale.device)
            if module.parallel_state.data_parallel_group.is_initialized():
                dist.all_reduce(
                    has_nan,
                    op=dist.ReduceOp.MAX,
                    group=module.parallel_state.data_parallel_group.group,
                )

            if has_nan.item() > 0:
```
Contributor

@realAsma realAsma Oct 10, 2025


Can we use `get_dist_syncd_obj` to sync `has_nan`?

Suggested change
```diff
             )
-            has_nan = torch.tensor(int(has_nan_local), device=module.awq_lite.act_scale.device)
-            if module.parallel_state.data_parallel_group.is_initialized():
-                dist.all_reduce(
-                    has_nan,
-                    op=dist.ReduceOp.MAX,
-                    group=module.parallel_state.data_parallel_group.group,
-                )
-            if has_nan.item() > 0:
+            has_nan = DistributedProcessGroup.get_dist_syncd_obj(has_nan_local, module.parallel_state.data_parallel_group, lambda objs: any(objs))
+            if has_nan:
```
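
For context, the object-sync pattern being suggested can be approximated with plain torch.distributed primitives. The sketch below is illustrative only; it does not reproduce the actual `DistributedProcessGroup.get_dist_syncd_obj` implementation, and the helper and variable names are made up:

```python
import torch.distributed as dist

def sync_obj_across_group(local_obj, group, reduce_fn):
    """Gather a picklable object from every rank in `group` and reduce the gathered list."""
    if group is None or not dist.is_initialized():
        # Single-process run or uninitialized group: nothing to sync.
        return local_obj
    gathered = [None] * dist.get_world_size(group=group)
    dist.all_gather_object(gathered, local_obj, group=group)
    return reduce_fn(gathered)

# Usage mirroring the suggestion: if any DP rank saw NaNs, every rank disables AWQ-Lite.
# has_nan = sync_obj_across_group(has_nan_local, dp_group, lambda objs: any(objs))
```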

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2e8ef58 and afe6f34.

📒 Files selected for processing (2)
  • examples/nemo_run/qat/README.md (2 hunks)
  • modelopt/torch/quantization/model_calib.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/nemo_run/qat/README.md
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/model_calib.py (3)
modelopt/torch/utils/distributed.py (5)
  • DistributedProcessGroup (189-229)
  • ParallelState (232-253)
  • is_initialized (49-51)
  • is_initialized (196-198)
  • get_dist_syncd_obj (212-229)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/quantization/utils.py (1)
  • is_quantized_linear (246-256)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality

Comment on lines +619 to +621
```diff
             has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
                 torch.isnan(module.awq_lite.weight_scale)
-            ):
+            )
```
Contributor


⚠️ Potential issue | 🔴 Critical

Fix tensor boolean evaluation before distributed sync.

torch.any(...) returns a 0‑D tensor, so the Python or tries to convert that tensor to bool, triggering RuntimeError: Boolean value of Tensor with more than one value is ambiguous at runtime. AWQ-Lite calibration will crash the first time it hits this branch. Convert the NaN checks to Python booleans before combining them.

```diff
-            has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
-                torch.isnan(module.awq_lite.weight_scale)
-            )
+            act_nan = torch.isnan(module.awq_lite.act_scale).any().item()
+            weight_nan = torch.isnan(module.awq_lite.weight_scale).any().item()
+            has_nan_local = bool(act_nan or weight_nan)
```
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
```diff
@@ modelopt/torch/quantization/model_calib.py:619
-            has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
-                torch.isnan(module.awq_lite.weight_scale)
-            )
+            act_nan = torch.isnan(module.awq_lite.act_scale).any().item()
+            weight_nan = torch.isnan(module.awq_lite.weight_scale).any().item()
+            has_nan_local = bool(act_nan or weight_nan)
```
🤖 Prompt for AI Agents
In modelopt/torch/quantization/model_calib.py around lines 619 to 621, the code
uses torch.any(...) twice and combines them with the Python "or", which attempts
to convert 0-D tensors to bools and raises a RuntimeError; change the checks to
produce Python booleans before combining (for example call .any().item() or wrap
each check with bool(...).cpu().item() as appropriate) so has_nan_local becomes
a plain Python bool, then use that to drive the subsequent logic.
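
A standalone illustration of the conversion being asked for, using made-up tensors rather than the calibration state itself:

```python
import torch

act_scale = torch.tensor([0.5, float("nan")])
weight_scale = torch.tensor([1.0, 2.0])

# Collapse each NaN check to a plain Python bool before combining and syncing across ranks.
act_nan = torch.isnan(act_scale).any().item()
weight_nan = torch.isnan(weight_scale).any().item()
has_nan_local = bool(act_nan or weight_nan)
print(has_nan_local)  # True, since act_scale contains a NaN
```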

@jenchen13 jenchen13 merged commit 99c44d3 into main Oct 10, 2025
25 of 27 checks passed
@jenchen13 jenchen13 deleted the jennifchen/cp_amax_sync branch October 10, 2025 23:16