Skip to content

Commit ca55348

Browse files
committed
Addressed MR comments
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent 91837c3 commit ca55348

File tree

3 files changed

+4
-1
lines changed

3 files changed

+4
-1
lines changed

modelopt/torch/quantization/model_calib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def sync_quantizer_amax_across_tp(
118118
# Syncing amax across TP for sequential quantizer
119119
if isinstance(quantizer, SequentialQuantizer):
120120
for _q in quantizer:
121-
"Syncing amax across TP for sequential quantizer"
121+
# Syncing amax across TP for sequential quantizer
122122
sync_quantizer_amax_across_tp(
123123
_q, linear_name, quantizer_type, axes_for_sync, parallel_state
124124
)

modelopt/torch/quantization/plugins/transformer_engine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def _setup(self):
7373
# GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
7474
# initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
7575
# self.weight0 to self.weight to run the quantizer states initialization.
76+
assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
7677
self.weight = self.weight0
7778
# Memorize the original weight.dtype for modelopt_post_restore given that
7879
# the dtype can change later.
@@ -84,6 +85,7 @@ def modelopt_post_restore(self, prefix: str = ""):
8485
# GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
8586
# initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
8687
# self.weight0 to self.weight to run the quantizer states initialization.
88+
assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
8789
self.weight = self.weight0
8890
super().modelopt_post_restore(prefix=prefix)
8991
# Remove self.weight after post_restore.

modelopt/torch/quantization/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ def is_quantized_linear(module):
253253
and hasattr(module, "weight_quantizer")
254254
and (
255255
(getattr(module, "weight", None) is not None and module.weight.dim() == 2)
256+
# module.weight0 check is required to support TEGroupedLinear
256257
or (getattr(module, "weight0", None) is not None and module.weight0.dim() == 2)
257258
)
258259
)

0 commit comments

Comments (0)