Skip to content

Commit afe6f34

Browse files
committed
fix dist has_nan
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
1 parent 5cb380c commit afe6f34

File tree

2 files changed

+6
-11
lines changed

2 files changed

+6
-11
lines changed

examples/nemo_run/qat/README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,11 @@ To run the example locally, first clone the `TensorRT-Model-Optimizer` repositor
6161
Set up repo:
6262

6363
- `git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git`
64-
- `git clone https://github.com/NVIDIA-NeMo/NeMo.git`
6564

6665
Run docker command (modify with your paths) and export the HuggingFace token:
6766

6867
```bash
69-
docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/:/opt/TensorRT-Model-Optimizer/ --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.09 bash
68+
docker run -v /home/user/:/home/user/ -v /home/user/TensorRT-Model-Optimizer/:/opt/TensorRT-Model-Optimizer/ --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.09 bash
7069

7170
export HF_TOKEN=<your-token>
7271
```

modelopt/torch/quantization/model_calib.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from modelopt.torch.opt.searcher import ForwardLoop
2828
from modelopt.torch.utils import print_rank_0
29-
from modelopt.torch.utils.distributed import ParallelState
29+
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
3030
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
3131

3232
from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
@@ -619,15 +619,11 @@ def sync_act_scale_across_dp(module, data_parallel_group):
619619
has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
620620
torch.isnan(module.awq_lite.weight_scale)
621621
)
622-
has_nan = torch.tensor(int(has_nan_local), device=module.awq_lite.act_scale.device)
623-
if module.parallel_state.data_parallel_group.is_initialized():
624-
dist.all_reduce(
625-
has_nan,
626-
op=dist.ReduceOp.MAX,
627-
group=module.parallel_state.data_parallel_group.group,
628-
)
622+
has_nan = DistributedProcessGroup.get_dist_syncd_obj(
623+
has_nan_local, module.parallel_state.data_parallel_group, lambda objs: any(objs)
624+
)
629625

630-
if has_nan.item() > 0:
626+
if has_nan:
631627
module.awq_lite.is_enabled = False
632628
else:
633629
sync_act_scale_across_dp(

0 commit comments

Comments
 (0)