Skip to content

Commit 91837c3

Browse files
committed
minor fix
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent 5481d10 commit 91837c3

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

tests/_test_utils/torch_dist/plugins/megatron_common.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -636,10 +636,9 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True)
636636
# fc2: RowParallel: [X_1, X_2] @ [A_1
637637
# A_2] (weights split along Cin)
638638
# amax should be the same across all ranks
639-
640639
rank_groups = (
641640
list(etp_groups.values())
642-
if "linear_fc1" in quantizer_type and rank_values[0].ndim > 0
641+
if "linear_fc1" in quantizer_type and (next(iter(rank_values.values()))).ndim > 0
643642
else [list(range(world_size))]
644643
)
645644

0 commit comments

Comments
 (0)