
Commit 5481d10

fixed tests for per-channel support
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent 28c8bbf commit 5481d10


3 files changed: +94 additions, −42 deletions


modelopt/torch/quantization/plugins/megatron.py

Lines changed: 9 additions & 3 deletions
@@ -506,9 +506,15 @@ def _setup(self):
                 expert.linear_fc2.parallel_state = self.parallel_state
 
     def sync_moe_local_experts_amax(self):
-        """Sync amax across experts in a SequentialMLP."""
+        """Sync amax across local experts in a SequentialMLP.
+
+        amax across EP and ETP (for RowParallel) is synchronized as part of model_calib.max_calibrate().
+        This function is called to synchronize the amax values across local experts so that all local
+        experts share the same amax.
+        """
+        torch.distributed.barrier()
+        # Collect amax from all local experts
         amax_dict = {}
-        # gather amax values from SequentialMLP experts
         for expert in self.local_experts:
             for name, module in expert.named_modules():
                 if isinstance(module, TensorQuantizer) and module.amax is not None:
@@ -520,7 +526,7 @@ def sync_moe_local_experts_amax(self):
                     else torch.maximum(stored_amax, amax_tensor)
                 )
 
-        # sync amax values across experts in SequentialMLP
+        # Apply synchronized amax values back to all local experts
         for expert in self.local_experts:
             for name, module in expert.named_modules():
                 if isinstance(module, TensorQuantizer) and module.amax is not None:
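For context, the change above follows a collect-then-apply pattern: first reduce the per-quantizer amax across all local experts, then write the reduced value back so every expert agrees. Below is a minimal, self-contained sketch of that pattern. The FakeQuantizer class and sync_local_expert_amax function are hypothetical stand-ins; the real method operates on Megatron SequentialMLP experts and modelopt TensorQuantizer modules, and additionally calls torch.distributed.barrier(), which is omitted here.

# Minimal sketch of the collect-then-apply amax sync, assuming simple stand-in
# modules that expose an `amax` tensor (the real code uses TensorQuantizer).
import torch
from torch import nn


class FakeQuantizer(nn.Module):
    """Hypothetical stand-in for a quantizer that tracks a per-channel amax."""

    def __init__(self, amax):
        super().__init__()
        self.amax = torch.as_tensor(amax, dtype=torch.float32)


def sync_local_expert_amax(local_experts):
    """Make every expert share the element-wise max amax per quantizer name."""
    amax_dict = {}
    # Pass 1: reduce amax per quantizer name across all local experts.
    for expert in local_experts:
        for name, module in expert.named_modules():
            if isinstance(module, FakeQuantizer) and module.amax is not None:
                stored = amax_dict.get(name)
                amax_dict[name] = (
                    module.amax.clone() if stored is None else torch.maximum(stored, module.amax)
                )
    # Pass 2: write the reduced values back so all experts agree.
    for expert in local_experts:
        for name, module in expert.named_modules():
            if isinstance(module, FakeQuantizer) and name in amax_dict:
                module.amax = amax_dict[name].clone()


experts = [
    nn.ModuleDict({"fc1_q": FakeQuantizer([1.0, 3.0])}),
    nn.ModuleDict({"fc1_q": FakeQuantizer([2.0, 0.5])}),
]
sync_local_expert_amax(experts)
assert torch.equal(experts[0]["fc1_q"].amax, torch.tensor([2.0, 3.0]))
assert torch.equal(experts[1]["fc1_q"].amax, torch.tensor([2.0, 3.0]))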

tests/_test_utils/torch_dist/plugins/megatron_common.py

Lines changed: 64 additions & 29 deletions
@@ -14,6 +14,7 @@
 # limitations under the License.
 import copy
 import re
+from collections import defaultdict
 from warnings import warn
 
 import torch
@@ -41,6 +42,7 @@
 from megatron.core.parallel_state import (
     get_expert_model_parallel_group,
     get_expert_tensor_parallel_group,
+    get_expert_tensor_parallel_rank,
     initialize_model_parallel,
     is_pipeline_first_stage,
     is_pipeline_last_stage,
@@ -190,7 +192,7 @@ def squared_relu(x):
         pipeline_model_parallel_size=pipeline_model_parallel_size,
         expert_model_parallel_size=expert_model_parallel_size,
         expert_tensor_parallel_size=expert_tensor_parallel_size,
-        sequence_parallel=expert_model_parallel_size > 1,
+        sequence_parallel=False,
         moe_grouped_gemm=moe_grouped_gemm,
         num_layers=num_layers,
         num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage,
@@ -221,7 +223,12 @@ def squared_relu(x):
     else:
         assert HAS_TE, "Transformer Engine not installed"
         transformer_layer_spec = (
-            get_gpt_modelopt_spec(config, remap_te_layernorm=True)
+            get_gpt_modelopt_spec(
+                config,
+                remap_te_layernorm=True,
+                # TODO: uncomment this when TEGroupedMLP is enabled in Megatron-LM
+                # moe_grouped_gemm=moe_grouped_gemm
+            )
             if transformer_impl == "modelopt"
             else get_gpt_layer_with_transformer_engine_spec()
         )
@@ -565,8 +572,7 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True)
         # Check for both TEGrouped and sequential MoE patterns
         if "local_experts" in name or ("experts" in name and "linear_fc" in name):
             # Convert to scalar only if tensor has a single element
-            amax_val = module.amax.detach().clone().cpu()
-            expert_amax_values[name] = amax_val
+            expert_amax_values[name] = module.amax.detach().clone().cpu()
 
     # Early return if no expert quantizers found
     assert expert_amax_values, "No expert quantizers found"
@@ -577,19 +583,16 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True)
     torch.distributed.all_gather_object(all_amax_values, expert_amax_values)
 
     # Group quantizers by type (ignoring specific expert indices) and check sync
-    expert_quantizers = {}
+    expert_quantizers = defaultdict(dict)
     for rank_idx, rank_amax in enumerate(all_amax_values):
         for name, amax_val in rank_amax.items():
             # Create quantizer type key by normalizing the name
-            if "local_experts" in name:
-                # sequential MoE: replace expert index with wildcard
-                quantizer_type = re.sub(r"local_experts\.\d+", "local_experts.*", name)
-            else:
-                # TEGrouped MoE: use the name as-is since experts are grouped
-                quantizer_type = name
-
-            if quantizer_type not in expert_quantizers:
-                expert_quantizers[quantizer_type] = {}
+            quantizer_type = (
+                re.sub(r"local_experts\.\d+", "local_experts.*", name)
+                if "local_experts" in name
+                else name
+            )
+
             if (
                 quantizer_type in expert_quantizers
                 and rank_idx in expert_quantizers[quantizer_type]
@@ -608,21 +611,53 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True)
                 )
             expert_quantizers[quantizer_type][rank_idx] = amax_val
 
-    # Check synchronization - fail fast on first inconsistency
+    rank_info = {
+        "global_rank": torch.distributed.get_rank(),
+        "etp_rank": get_expert_tensor_parallel_rank(),
+    }
+
+    all_rank_info = [None] * world_size
+    torch.distributed.all_gather_object(all_rank_info, rank_info)
+
+    # Group ranks by ETP rank for fc1 (ColumnParallel: same output channels should match)
+    etp_groups = defaultdict(list)
+    for info in all_rank_info:
+        etp_groups[info["etp_rank"] if info["etp_rank"] else 0].append(info["global_rank"])
+
     for quantizer_type, rank_values in expert_quantizers.items():
-        if len(rank_values) > 1:  # Only check if we have multiple ranks
-            values = list(rank_values.values())
-            # Handle both scalar and tensor comparisons
-            first_val = values[0]
-            if isinstance(first_val, torch.Tensor):
-                # For tensors, check if all values are close to the first one
-                for val in values[1:]:
-                    if not torch.allclose(first_val, val, rtol=1e-6, atol=1e-6):
-                        return False, quantizer_type, rank_values
-            else:
-                # For scalars, use numeric comparison
-                max_diff = max(values) - min(values)
-                if max_diff > 1e-6:  # Allow for small floating point differences
-                    return False, quantizer_type, rank_values
+        # Determine which ranks should share the same amax
+        # (i.e. which ranks to compare against each other):
+        #
+        # fc1: ColumnParallel: X @ [A_1, A_2] (weights split along Cout),
+        #   so per-channel amax should match across ranks with the same ETP rank.
+        #   If EP is 2 and ETP is 2, we have 4 ranks: (EP1, ETP1): 0, (EP1, ETP2): 1, (EP2, ETP1): 2, (EP2, ETP2): 3,
+        #   so for per-channel quantization we compare amax within the same-ETP-rank groups [0, 2] and [1, 3].
+        #
+        # fc2: RowParallel: [X_1, X_2] @ [A_1
+        #                                 A_2] (weights split along Cin),
+        #   so amax should be the same across all ranks.
+
+        rank_groups = (
+            list(etp_groups.values())
+            if "linear_fc1" in quantizer_type and rank_values[0].ndim > 0
+            else [list(range(world_size))]
+        )
+
+        # Check each group independently
+        for group in rank_groups:
+            group_values = [rank_values[r] for r in group if r in rank_values]
+            if len(group_values) > 1:
+                # All values in this group should be identical
+                first_val = group_values[0]
+                for val in group_values[1:]:
+                    if isinstance(first_val, torch.Tensor):
+                        if not torch.allclose(first_val, val, rtol=1e-6, atol=1e-6):
+                            group_rank_values = {
+                                r: rank_values[r] for r in group if r in rank_values
+                            }
+                            return False, f"{quantizer_type} (group {group})", group_rank_values
+                    elif abs(first_val - val) > 1e-6:
+                        group_rank_values = {r: rank_values[r] for r in group if r in rank_values}
+                        return False, f"{quantizer_type} (group {group})", group_rank_values
 
     return True, None, None
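As a standalone illustration of the grouping logic above: for per-channel amax on linear_fc1 (ColumnParallel, weight split along output channels), only ranks with the same expert-tensor-parallel (ETP) rank hold the same output-channel slice and are expected to agree, while for linear_fc2 (RowParallel, weight split along input channels) every rank should agree. The sketch below mirrors the EP=2, ETP=2 layout described in the diff's comment; the rank_to_etp mapping and expected_groups helper are hypothetical, for illustration only.

# Hypothetical EP=2, ETP=2 layout with 4 global ranks, mirroring the diff's comment:
# (EP1, ETP1) -> rank 0, (EP1, ETP2) -> rank 1, (EP2, ETP1) -> rank 2, (EP2, ETP2) -> rank 3.
from collections import defaultdict

world_size = 4
rank_to_etp = {0: 0, 1: 1, 2: 0, 3: 1}

# Group global ranks by ETP rank, as the test helper does after all_gather_object.
etp_groups = defaultdict(list)
for global_rank, etp_rank in rank_to_etp.items():
    etp_groups[etp_rank].append(global_rank)


def expected_groups(quantizer_type: str, per_channel: bool) -> list[list[int]]:
    """Which global ranks should share identical amax for a given quantizer."""
    if "linear_fc1" in quantizer_type and per_channel:
        # ColumnParallel fc1: each ETP rank owns a different Cout slice,
        # so per-channel amax only matches within the same ETP rank.
        return list(etp_groups.values())
    # RowParallel fc2 (or per-tensor amax): every rank should agree.
    return [list(range(world_size))]


print(expected_groups("decoder.mlp.experts.linear_fc1.weight_quantizer", per_channel=True))
# -> [[0, 2], [1, 3]]
print(expected_groups("decoder.mlp.experts.linear_fc2.weight_quantizer", per_channel=True))
# -> [[0, 1, 2, 3]]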

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 21 additions & 10 deletions
@@ -45,6 +45,7 @@
 )
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP
+from megatron.core.transformer.moe.router import TopKRouter
 
 import modelopt
 import modelopt.torch.opt as mto
@@ -240,6 +241,7 @@ def _gpt_model_provider(
     ep_size=1,
     etp_size=None,
     use_te=False,
+    transformer_impl="local",
 ):
     """Build the model."""
 
@@ -253,7 +255,7 @@ def _gpt_model_provider(
         ffn_hidden_size=None,
         num_attention_heads=8,
         activation_func="squared_relu",
-        transformer_impl="local",
+        transformer_impl=transformer_impl,
         hidden_size=hidden_size,
         vocab_size=vocab_size,
         use_cpu_initialization=meta_device,
@@ -270,7 +272,7 @@ def _gpt_model_provider(
         ffn_hidden_size=None,
         num_attention_heads=8,
         activation_func="squared_relu",
-        transformer_impl="local",
+        transformer_impl=transformer_impl,
         hidden_size=hidden_size,
         vocab_size=vocab_size,
         num_moe_experts=num_moe_experts,
@@ -297,6 +299,7 @@ def _test_sharded_state_dict(
     num_moe_experts = moe_config.get("num_moe_experts", None)
     moe_grouped_gemm = moe_config.get("moe_grouped_gemm", False)
    use_te = moe_config.get("use_te", False)
+    transformer_impl = moe_config.get("transformer_impl", "local")
 
     initialize_for_megatron(
         tensor_model_parallel_size=tp_size,
@@ -314,6 +317,7 @@ def _test_sharded_state_dict(
         use_te=use_te,
         ep_size=ep_size,
         etp_size=etp_size,
+        transformer_impl=transformer_impl,
     )
     model_test = _gpt_model_provider(
         tp_size,
@@ -325,6 +329,7 @@ def _test_sharded_state_dict(
         meta_device=meta_device,
         ep_size=ep_size,
         etp_size=etp_size,
+        transformer_impl=transformer_impl,
     )
 
     prompt_tokens = torch.randint(
@@ -531,10 +536,7 @@ def test_fp8_real_quantize():
 
 @pytest.mark.parametrize(
     "config",
-    [
-        mtq.FP8_DEFAULT_CFG,
-        mtq.NVFP4_DEFAULT_CFG,
-    ],
+    [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG],
 )
 @pytest.mark.parametrize("moe_grouped_gemm", [True, False])
 def test_moe_sharded_state_dict(need_4_gpus, tmp_path, config, moe_grouped_gemm):
@@ -549,6 +551,7 @@ def test_moe_sharded_state_dict(need_4_gpus, tmp_path, config, moe_grouped_gemm)
         "num_moe_experts": 4,
         "moe_grouped_gemm": moe_grouped_gemm,
         "use_te": moe_grouped_gemm,
+        "transformer_impl": "modelopt",
     }
     spawn_multiprocess_job(
         size=size,
@@ -606,6 +609,7 @@ def forward_fn(model):
         hidden_size=32,
         moe_grouped_gemm=False,
         num_moe_experts=4,
+        transformer_impl="modelopt",
     )
     num_sequential_mlp = sum(
         isinstance(module, SequentialMLP) for module in sequential_moe_model.modules()
@@ -666,10 +670,16 @@ def _test_expert_model_parallel_amax_sync(
         hidden_size=256,
         moe_grouped_gemm=moe_grouped_gemm,
         use_te=moe_grouped_gemm,
-        num_moe_experts=4,
+        num_moe_experts=8,
+        transformer_impl="modelopt",
     )
     prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
 
+    # force all expert routing
+    for module in model.modules():
+        if isinstance(module, TopKRouter):
+            module.topk = module.num_experts
+
     def forward_fn(model):
         return megatron_prefill(model, prompt_tokens)
 
@@ -701,9 +711,10 @@ def forward_fn(model):
     assert final_sync, f"Inconsistent amax for expert {quantizer_type} across ranks: {rank_values}"
 
 
+@pytest.mark.parametrize("config", [mtq.FP8_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG])
 @pytest.mark.parametrize(("ep_size", "etp_size"), [(1, 2), (2, 1), (2, 2)])
 @pytest.mark.parametrize("moe_grouped_gemm", [True, False])
-def test_expert_parallel_sync(ep_size, etp_size, moe_grouped_gemm):
+def test_expert_parallel_sync(config, ep_size, etp_size, moe_grouped_gemm):
     """Test expert model parallel synchronization."""
     size = torch.cuda.device_count()
     if size < ep_size * etp_size:
@@ -716,11 +727,11 @@ def test_expert_parallel_sync(ep_size, etp_size, moe_grouped_gemm):
         size=size,
         job=partial(
             _test_expert_model_parallel_amax_sync,
-            2,
+            etp_size,  # tp_size
             ep_size,
             etp_size,
             moe_grouped_gemm,
-            mtq.FP8_DEFAULT_CFG,
+            config,
         ),
         backend="nccl",
     )
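The test change above routes every token to every expert during calibration, so all experts receive data and produce comparable amax values. A minimal sketch of that trick follows; the FakeTopKRouter and FakeMoEModel classes are hypothetical stand-ins, whereas the real test patches Megatron's TopKRouter instances inside the GPT model.

# Minimal sketch of forcing all-expert routing during calibration, assuming a
# hypothetical router with `topk` and `num_experts` attributes.
from torch import nn


class FakeTopKRouter(nn.Module):
    def __init__(self, num_experts: int, topk: int):
        super().__init__()
        self.num_experts = num_experts
        self.topk = topk


class FakeMoEModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.router = FakeTopKRouter(num_experts=8, topk=2)


def force_all_expert_routing(model: nn.Module) -> None:
    """Set topk == num_experts so calibration traffic reaches every expert."""
    for module in model.modules():
        if isinstance(module, FakeTopKRouter):
            module.topk = module.num_experts


model = FakeMoEModel()
force_all_expert_routing(model)
assert model.router.topk == model.router.num_experts == 8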
