Merged

Changes from all commits (28 commits)
f17131f
sync amax in context parallel and awq act scale
jenchen13 Sep 24, 2025
42519cc
lint
jenchen13 Sep 25, 2025
264adbb
test weight quantizer too
jenchen13 Sep 25, 2025
7cbe5b9
Merge branch 'main' into jennifchen/cp_amax_sync
jenchen13 Sep 25, 2025
1f7d17e
fix test
jenchen13 Sep 26, 2025
71a9f7a
Merge branch 'main' into jennifchen/cp_amax_sync
jenchen13 Sep 29, 2025
d02365c
awq test
jenchen13 Sep 29, 2025
5a572da
move awq test inside megatron tests
jenchen13 Sep 29, 2025
fc0bb88
fix amax tests
jenchen13 Sep 30, 2025
95da832
fix awq lite param
jenchen13 Sep 30, 2025
34c11ef
fix test
jenchen13 Sep 30, 2025
10e3e2b
Merge branch 'main' into jennifchen/cp_amax_sync
jenchen13 Sep 30, 2025
9f0691f
uncomment test
jenchen13 Sep 30, 2025
fa8f4c8
add print
jenchen13 Oct 1, 2025
d1fac44
docstring
jenchen13 Oct 1, 2025
22b8b73
fix tests
jenchen13 Oct 2, 2025
ca7c0e8
Merge branch 'main' into jennifchen/cp_amax_sync
jenchen13 Oct 2, 2025
3f857a3
fix multiprocess size
jenchen13 Oct 2, 2025
93bfd52
fix tests
jenchen13 Oct 8, 2025
6761109
consolidate tests
jenchen13 Oct 9, 2025
291cfa3
Merge branch 'main' into jennifchen/cp_amax_sync
jenchen13 Oct 9, 2025
a106dd9
fix test
jenchen13 Oct 9, 2025
50000dd
fix bug
jenchen13 Oct 10, 2025
2664563
Merge branch 'main' into jennifchen/cp_amax_sync
jenchen13 Oct 10, 2025
440ca48
update qat readme
jenchen13 Oct 10, 2025
2e8ef58
update readme
jenchen13 Oct 10, 2025
5cb380c
Merge branch 'main' into jennifchen/cp_amax_sync
jenchen13 Oct 10, 2025
afe6f34
fix dist has_nan
jenchen13 Oct 10, 2025
15 changes: 9 additions & 6 deletions examples/nemo_run/qat/README.md
@@ -56,18 +56,21 @@ The resulting exported checkpoint also is much smaller in memory at 6.4GB compar

You can run the example either locally or on a [Slurm cluster](ADVANCED.md).

To run the example locally, launch a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.07 or higher. Clone the `TensorRT-Model-Optimizer` repository and `NeMo` repository (checkout a specific commit for NeMo), then mount it onto your docker container.
To run the example locally, first clone the `TensorRT-Model-Optimizer` repository, then mount the repository to a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.09. After launching the Docker container, make sure to also set your HuggingFace token for dataset/model downloading.

Set up repo:

- `git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git`
- `git clone https://github.com/NVIDIA-NeMo/NeMo.git && cd NeMo && git checkout 676ed1a`

Example docker command:
Run docker command (modify with your paths) and export the HuggingFace token:

```bash
docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash
docker run -v /home/user/:/home/user/ -v /home/user/TensorRT-Model-Optimizer/:/opt/TensorRT-Model-Optimizer/ --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.09 bash

export HF_TOKEN=<your-token>
```

You will also need to set your Huggingface token with `export HF_TOKEN=<your-token>`. You may also need to enable write access to the docker container to the `examples/nemo_run` folder by doing `chmod 777 nemo_run` so that logs can be written.
You may also need to give the docker container write access to the `examples/nemo_run` folder by running `chmod 777 nemo_run` so that logs can be written.

### Running the Flow Locally

@@ -92,7 +95,7 @@ In order to train using QAD, launch the example with `python qat/nemo_qat_flow.p
To perform QAD training, run:

```bash
python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment
python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment --tensor_parallelism 4
```

## Supported models
2 changes: 1 addition & 1 deletion examples/nemo_run/qat/nemo_qat_flow.py
@@ -228,7 +228,7 @@ def main(args):
global_batch_size=GBS,
micro_batch_size=MBS,
use_hf_tokenizer_chat_template=True,
num_workers=2,
num_workers=1,
persistent_workers=True,
)
if args.distill:
31 changes: 25 additions & 6 deletions modelopt/torch/quantization/model_calib.py
@@ -26,7 +26,7 @@

from modelopt.torch.opt.searcher import ForwardLoop
from modelopt.torch.utils import print_rank_0
from modelopt.torch.utils.distributed import ParallelState
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method

from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
@@ -81,6 +81,7 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
return

def sync_quantizer_amax_across_dp(quantizer, parallel_state):
"""Synchronize the amax across all ranks in the data parallel group."""
if isinstance(quantizer, SequentialQuantizer):
for _q in quantizer:
sync_quantizer_amax_across_dp(_q, parallel_state)
Expand All @@ -94,7 +95,6 @@ def sync_quantizer_amax_across_dp(quantizer, parallel_state):
for child in module.children():
if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
sync_quantizer_amax_across_dp(child, module.parallel_state)

# TP sync:
# Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same

@@ -114,6 +114,7 @@ def sync_quantizer_amax_across_tp(
axes_for_sync: list,
parallel_state: ParallelState,
):
# Syncing amax across TP for sequential quantizer
if isinstance(quantizer, SequentialQuantizer):
for _q in quantizer:
sync_quantizer_amax_across_tp(
@@ -598,19 +599,37 @@ def forward(self, input, *args, **kwargs):
# This will also perform distributed amax sync for input_quantizers
max_calibrate(model, lambda model: None)

def sync_act_scale_across_dp(module, data_parallel_group):
"""Sync activation scale across Data Parallel (DP)."""
if data_parallel_group.is_initialized():
dist.all_reduce(
module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group
)

for name, module in model.named_modules():
if (
is_quantized_linear(module)
and hasattr(module, "awq_lite")
and module.awq_lite.num_cache_steps > 0
):
# Hack: MoEs forward all tokens through all experts if _if_calib is True
module._if_calib = True
module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
torch.isnan(module.awq_lite.weight_scale)
):
)

Contributor comment:

Good catch on this.

Can we use

DistributedProcessGroup.get_dist_syncd_obj(has_nan_local, module.parallel_state.data_parallel_group, lambda objs: any(objs))

Comment on lines +619 to +621
Contributor comment:

⚠️ Potential issue | 🔴 Critical

Fix tensor boolean evaluation before distributed sync.

torch.any(...) returns a 0-D tensor, and Python's `or` returns one of its operands, so `has_nan_local` ends up as a tensor rather than a plain bool; the distributed object sync then gathers tensors instead of simple booleans. Convert the NaN checks to Python booleans before combining them.

-            has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
-                torch.isnan(module.awq_lite.weight_scale)
-            )
+            act_nan = torch.isnan(module.awq_lite.act_scale).any().item()
+            weight_nan = torch.isnan(module.awq_lite.weight_scale).any().item()
+            has_nan_local = bool(act_nan or weight_nan)
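A quick standalone illustration of the point above (illustrative snippet, not part of the patch): `or` applied to two 0-D tensors returns one of the tensors, while the `.any().item()` form yields a plain Python bool.

```python
import torch

a = torch.tensor([1.0, float("nan")])
b = torch.tensor([2.0, 3.0])

# `or` short-circuits on truthiness but returns an operand, so this stays a tensor.
flag = torch.any(torch.isnan(a)) or torch.any(torch.isnan(b))
print(type(flag))  # <class 'torch.Tensor'> (0-D)

# Converting each check first yields a plain bool that is safe to gather as an object.
flag_bool = bool(torch.isnan(a).any().item() or torch.isnan(b).any().item())
print(type(flag_bool))  # <class 'bool'>
```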

has_nan = DistributedProcessGroup.get_dist_syncd_obj(
has_nan_local, module.parallel_state.data_parallel_group, lambda objs: any(objs)
)

if has_nan:
module.awq_lite.is_enabled = False
# Hack: MoEs forward all tokens through all experts if _if_calib is True
module._if_calib = True
else:
sync_act_scale_across_dp(
module,
module.parallel_state.data_parallel_group,
)

AWQLiteHelper.cache_mode = False
print_rank_0("awq_lite: Searching parameters...")
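The hunk above averages `awq_lite.act_scale` across the data-parallel group and skips the averaging when any rank produced NaNs, combining per-rank flags through `DistributedProcessGroup.get_dist_syncd_obj`. A minimal sketch of that pattern built from plain `torch.distributed` primitives follows; the `sync_awq_stats` name is illustrative, and the object all-gather plus merge is an assumption about how the helper behaves, not ModelOpt's exact implementation.

```python
import torch
import torch.distributed as dist


def sync_awq_stats(act_scale: torch.Tensor, dp_group=None):
    """Return (DP-averaged act_scale, has_nan), mirroring the logic in the hunk above."""
    has_nan_local = bool(torch.isnan(act_scale).any().item())

    if dist.is_available() and dist.is_initialized():
        # Combine plain Python bools across ranks with an object all-gather,
        # roughly what get_dist_syncd_obj is used for here.
        flags = [None] * dist.get_world_size(group=dp_group)
        dist.all_gather_object(flags, has_nan_local, group=dp_group)
        has_nan = any(flags)
        if not has_nan:
            # Only average when every rank produced finite statistics (NCCL AVG).
            dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=dp_group)
        return act_scale, has_nan

    return act_scale, has_nan_local
```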
12 changes: 11 additions & 1 deletion modelopt/torch/quantization/plugins/megatron.py
@@ -15,13 +15,15 @@

"""Support quantization for megatron linear layers."""

import logging
import warnings
from typing import Any

import megatron.core.parallel_state as mcore_parallel
import megatron.core.tensor_parallel.layers as megatron_parallel
import megatron.core.transformer.mlp as megatron_mlp
import torch
from megatron.core.parallel_state import get_data_parallel_group
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
@@ -38,6 +40,8 @@
from ..qtensor import QTensorWrapper
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear

logger = logging.getLogger(__name__)

__all__ = []


@@ -217,8 +221,14 @@ class _MegatronParallelLinear(_ParallelLinear):
]

def _setup(self):
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
logger.warning("Context parallel group is not initialized, using data parallel group")
data_parallel_group = get_data_parallel_group()
self.parallel_state = ParallelState(
getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
)
super()._setup()
5 changes: 4 additions & 1 deletion modelopt/torch/utils/distributed.py
@@ -247,7 +247,10 @@ def __init__(
self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)

def __repr__(self) -> str:
return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}"
return (
f"data_parallel_group: {self.data_parallel_group}, "
f"tensor_parallel_group: {self.tensor_parallel_group}, "
)


def get_group(ranks: list[int]):
24 changes: 20 additions & 4 deletions tests/_test_utils/torch_dist/plugins/megatron_common.py
@@ -84,9 +84,12 @@


class MegatronModel(MegatronModule):
def __init__(self, tp_size: int = 1, use_te_norm: bool = False):
def __init__(
self, tp_size: int = 1, cp_size: int = 1, use_te_norm: bool = False, tp_group=None
):
config = TransformerConfig(
tensor_model_parallel_size=tp_size,
context_parallel_size=cp_size,
pipeline_model_parallel_size=1,
normalization="LayerNorm",
# Unused parameters below are set to avoid ZeroDivisionError in __post_init__
@@ -104,6 +107,7 @@ def __init__(self, tp_size: int = 1, use_te_norm: bool = False):
gather_output=False,
skip_bias_add=True,
is_expert=False,
tp_group=tp_group,
)
self.activation = nn.ReLU()
if use_te_norm:
@@ -118,6 +122,7 @@ def __init__(self, tp_size: int = 1, use_te_norm: bool = False):
skip_bias_add=True,
input_is_parallel=True,
is_expert=False,
tp_group=tp_group,
)

def forward(self, x):
@@ -127,7 +132,11 @@ def forward(self, x):
x = x[0]
return x

def get_dummy_input(self) -> torch.Tensor:
def get_dummy_input(self, seed: int | None = None) -> torch.Tensor:
if seed is not None:
gen = torch.Generator()
gen.manual_seed(seed)
return torch.randn(1, 4, 32, generator=gen)
return torch.randn(1, 4, 32)


@@ -390,13 +399,20 @@ def run_mcore_inference_with_dummy_input(


def initialize_for_megatron(
tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
seed=1234,
context_parallel_size=1,
):
"""Initialize Megatron model parallelism.

NOTE: If used in a non-spawned process, make sure to call `megatron.core.parallel_state.destroy_model_parallel()`.
"""
initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size)
initialize_model_parallel(
tensor_model_parallel_size,
pipeline_model_parallel_size,
context_parallel_size=context_parallel_size,
)
model_parallel_cuda_manual_seed(seed)


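The seeded `get_dummy_input` above is what lets the quantization test helper feed each data-parallel rank different but reproducible calibration data (the helper passes `seed=dp_rank`). A tiny standalone sketch of the generator pattern; the `make_dummy` name is illustrative only.

```python
import torch


def make_dummy(seed: int) -> torch.Tensor:
    """Deterministic dummy batch: the same seed always yields the same tensor."""
    gen = torch.Generator()
    gen.manual_seed(seed)
    return torch.randn(1, 4, 32, generator=gen)


# Same seed reproduces the data; different seeds (e.g. different DP ranks) differ.
assert torch.equal(make_dummy(0), make_dummy(0))
assert not torch.equal(make_dummy(0), make_dummy(1))
```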
110 changes: 85 additions & 25 deletions tests/_test_utils/torch_quantization/quantize_common.py
Contributor comment:

Looks like we don't need separate methods tensor_parallel_test_helper, dp_cp_parallel_test_helper, and data_tensor_context_parallel_test_helper for testing out all the combinations. Can we merge them into one and do data_tensor_context_parallel_test_helper(..., tp_group=None, dp_group=None, cp_group=None)?

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from unittest.mock import patch

import pytest
import torch
@@ -22,7 +23,9 @@

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
import modelopt.torch.quantization.model_calib as model_calib_module # needed for patching awq_lite
from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm
from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer
from modelopt.torch.quantization.utils import is_quantized_linear
from modelopt.torch.utils import torch_to

@@ -116,38 +119,95 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
mto.restore_from_modelopt_state(model_ref, state_dict)


def tensor_parallel_test_helper(model, config, tp_group, dp_group):
# The input to fist layer, the column parallel should be the same across all tp ranks
calib_data = model.get_dummy_input().cuda()
dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)
def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=[]):
quantizer_attr = getattr(quantizer, attr).clone()
for group in groups:
if group is not None:
dist.all_reduce(quantizer_attr, op=op, group=group)
assert torch.allclose(quantizer_attr, getattr(quantizer, attr))

def forward_loop(model):
model(calib_data)

model = mtq.quantize(model, config, forward_loop)
original_awq_lite = model_calib_module.awq_lite

# Sanity check
forward_loop(model)

if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]:
# Lets check the amax for row parallel input quantizer; it should be the same across all tp ranks
activation_amax = model.fc2.input_quantizer.amax.clone()
dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group)
assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax)
def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs):
"""Function to mock awq_lite function to always use debug=True for testing"""
return original_awq_lite(model, forward_loop, alpha_step, debug=True, **kwargs)

# Lets check the row parallel weight amax; it should be the same across all tp ranks
weight_amax = model.fc2.weight_quantizer.amax.clone()
dist.all_reduce(weight_amax, op=dist.ReduceOp.MAX, group=tp_group)
assert torch.allclose(weight_amax, model.fc2.weight_quantizer.amax)

if config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
# Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks
input_quantizer = model.fc1.input_quantizer
pre_quant_scale = input_quantizer.pre_quant_scale.clone()
dist.all_reduce(pre_quant_scale, op=dist.ReduceOp.MAX, group=tp_group)
assert torch.allclose(pre_quant_scale, input_quantizer.pre_quant_scale)
@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
def data_tensor_context_parallel_test_helper(
model, config, mock_awq_lite, dp_group=None, tp_group=None, test_pre_quant_scale=True
):
# Calib data should be different across each DP rank
dp_rank = dist.get_rank(group=dp_group)
calib_data = model.get_dummy_input(seed=dp_rank).cuda()

if tp_group is not None:
# The input to first layer, the column parallel should be the same across all tp ranks
dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)

dist.destroy_process_group()
def forward_loop(model):
model(calib_data)

model = mtq.quantize(model, config, forward_loop)

# Input quantizer amax
if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
_distributed_attr_check(
model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
)
_distributed_attr_check(
model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
)

# Per-tensor quantization (FP8/NVFP4) expects same amax across row and column parallel ranks
# Channel-wise (INT8) only expects same amax across row parallel ranks
# Block-wise quantization does not expect same amax across row and column parallel ranks
if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]:
if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
for quantizer in model.fc1.weight_quantizer:
_distributed_attr_check(
quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
)
else:
_distributed_attr_check(
model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
)

if config in [
mtq.FP8_DEFAULT_CFG,
mtq.NVFP4_DEFAULT_CFG,
mtq.INT8_DEFAULT_CFG,
mtq.INT8_SMOOTHQUANT_CFG,
]:
if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
for quantizer in model.fc2.weight_quantizer:
_distributed_attr_check(
quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
)
else:
_distributed_attr_check(
model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
)

# Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks
# It is different across DP/CP ranks since the input is different
if (
test_pre_quant_scale
and tp_group
and config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]
):
input_quantizer = model.fc1.input_quantizer
_distributed_attr_check(
input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
)

# Check act scale
if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
_distributed_attr_check(
model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, groups=[dp_group, tp_group]
)


def auto_quantize_helper(model):
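Taken together, the helpers in this file and in `megatron_common.py` are meant to be driven from a spawned multi-GPU test. A hedged sketch of one rank's body for a 4-GPU run with TP=2 and CP=2 follows; the `per_rank_body` name and the surrounding bootstrap (torch.distributed initialization, the spawning utility) are assumptions not shown in this diff.

```python
import megatron.core.parallel_state as mcore_parallel
import torch

import modelopt.torch.quantization as mtq
from _test_utils.torch_dist.plugins.megatron_common import (
    MegatronModel,
    initialize_for_megatron,
)
from _test_utils.torch_quantization.quantize_common import (
    data_tensor_context_parallel_test_helper,
)


def per_rank_body(rank: int):
    # Assumes torch.distributed is already initialized by the spawning utility.
    torch.cuda.set_device(rank)
    initialize_for_megatron(
        tensor_model_parallel_size=2,
        context_parallel_size=2,
        seed=1234,
    )
    try:
        model = MegatronModel(tp_size=2, cp_size=2).cuda()
        # The @patch decorator on the helper injects mock_awq_lite, so callers
        # pass only the model, the config, and the groups to check against.
        data_tensor_context_parallel_test_helper(
            model,
            mtq.FP8_DEFAULT_CFG,
            dp_group=mcore_parallel.get_data_parallel_group(with_context_parallel=True),
            tp_group=mcore_parallel.get_tensor_model_parallel_group(),
        )
    finally:
        mcore_parallel.destroy_model_parallel()
```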
6 changes: 6 additions & 0 deletions tests/gpu/torch/conftest.py
@@ -34,6 +34,12 @@ def need_2_gpus():
pytest.skip("Need at least 2 GPUs to run this test")


@pytest.fixture
def need_8_gpus():
if torch.cuda.device_count() < 8:
pytest.skip("Need at least 8 GPUs to run this test")


@pytest.fixture(scope="module")
def set_torch_dtype(request):
orig_dtype = torch.get_default_dtype()