|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 | import copy |
| 16 | +from unittest.mock import patch |
16 | 17 |
|
17 | 18 | import pytest |
18 | 19 | import torch |
|
22 | 23 |
|
23 | 24 | import modelopt.torch.opt as mto |
24 | 25 | import modelopt.torch.quantization as mtq |
| 26 | +import modelopt.torch.quantization.model_calib as model_calib_module # needed for patching awq_lite |
25 | 27 | from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm |
| 28 | +from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer |
26 | 29 | from modelopt.torch.quantization.utils import is_quantized_linear |
27 | 30 | from modelopt.torch.utils import torch_to |
28 | 31 |
|
@@ -116,38 +119,95 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N |
116 | 119 | mto.restore_from_modelopt_state(model_ref, state_dict) |
117 | 120 |
|
118 | 121 |
|
119 | | -def tensor_parallel_test_helper(model, config, tp_group, dp_group): |
120 | | - # The input to fist layer, the column parallel should be the same across all tp ranks |
121 | | - calib_data = model.get_dummy_input().cuda() |
122 | | - dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) |
| 122 | +def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=()): |
| 123 | + quantizer_attr = getattr(quantizer, attr).clone() |
| 124 | + for group in groups: |
| 125 | + if group is not None: |
| 126 | + dist.all_reduce(quantizer_attr, op=op, group=group) |
| 127 | + assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) |
123 | 128 |
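The helper above relies on a simple invariant: a MAX/AVG all-reduce over a tensor that is already identical on every rank of a group leaves it unchanged, so the `allclose` assertion passes only when the attribute was synchronized during calibration. A minimal standalone sketch of that check (hypothetical `_is_synced` name, assuming an already-initialized process group):

```python
import torch
import torch.distributed as dist


def _is_synced(t: torch.Tensor, group=None, op=dist.ReduceOp.MAX) -> bool:
    """True on every rank only when `t` holds identical values across `group`."""
    reduced = t.clone()
    dist.all_reduce(reduced, op=op, group=group)  # no-op if all ranks already agree
    return torch.allclose(reduced, t)
```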
|
124 | | - def forward_loop(model): |
125 | | - model(calib_data) |
126 | 129 |
|
127 | | - model = mtq.quantize(model, config, forward_loop) |
| 130 | +original_awq_lite = model_calib_module.awq_lite |
128 | 131 |
|
129 | | - # Sanity check |
130 | | - forward_loop(model) |
131 | 132 |
|
132 | | - if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]: |
133 | | - # Lets check the amax for row parallel input quantizer; it should be the same across all tp ranks |
134 | | - activation_amax = model.fc2.input_quantizer.amax.clone() |
135 | | - dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group) |
136 | | - assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax) |
| 133 | +def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs): |
| 134 | +    """Mock for awq_lite that forces debug=True so the test can inspect per-module calibration stats.""" |
| 135 | + return original_awq_lite(model, forward_loop, alpha_step, debug=True, **kwargs) |
137 | 136 |
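For context on the patching pattern used below: `unittest.mock.patch(..., side_effect=...)` replaces the target with a `MagicMock` that forwards every call to the given function, which is how `_debug_awq_lite` forces `debug=True` without changing the production default. A tiny self-contained demonstration of the forwarding behavior (hypothetical names, unrelated to modelopt):

```python
import math
from unittest.mock import patch


def _fake_sqrt(x):
    # Stand-in side_effect, analogous to _debug_awq_lite forwarding to the real awq_lite
    return 4.0


with patch("math.sqrt", side_effect=_fake_sqrt) as mocked_sqrt:
    assert math.sqrt(9.0) == 4.0  # calls are routed through the side_effect
    assert mocked_sqrt.call_count == 1
assert math.sqrt(9.0) == 3.0  # the original function is restored after the patch exits
```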
|
138 | | - # Lets check the row parallel weight amax; it should be the same across all tp ranks |
139 | | - weight_amax = model.fc2.weight_quantizer.amax.clone() |
140 | | - dist.all_reduce(weight_amax, op=dist.ReduceOp.MAX, group=tp_group) |
141 | | - assert torch.allclose(weight_amax, model.fc2.weight_quantizer.amax) |
142 | 137 |
|
143 | | - if config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: |
144 | | - # Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks |
145 | | - input_quantizer = model.fc1.input_quantizer |
146 | | - pre_quant_scale = input_quantizer.pre_quant_scale.clone() |
147 | | - dist.all_reduce(pre_quant_scale, op=dist.ReduceOp.MAX, group=tp_group) |
148 | | - assert torch.allclose(pre_quant_scale, input_quantizer.pre_quant_scale) |
| 138 | +@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) |
| 139 | +def data_tensor_context_parallel_test_helper( |
| 140 | + model, config, mock_awq_lite, dp_group=None, tp_group=None, test_pre_quant_scale=True |
| 141 | +): |
| 142 | +    # Calibration data should differ across DP ranks |
| 143 | + dp_rank = dist.get_rank(group=dp_group) |
| 144 | + calib_data = model.get_dummy_input(seed=dp_rank).cuda() |
| 145 | + |
| 146 | + if tp_group is not None: |
| 147 | +        # The input to the first (column-parallel) layer should be the same across all TP ranks |
| 148 | + dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) |
149 | 149 |
|
150 | | - dist.destroy_process_group() |
| 150 | + def forward_loop(model): |
| 151 | + model(calib_data) |
| 152 | + |
| 153 | + model = mtq.quantize(model, config, forward_loop) |
| 154 | + |
| 155 | +    # Input quantizer amax should match across DP and TP ranks |
| 156 | + if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: |
| 157 | + _distributed_attr_check( |
| 158 | + model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group] |
| 159 | + ) |
| 160 | + _distributed_attr_check( |
| 161 | + model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group] |
| 162 | + ) |
| 163 | + |
| 164 | +    # Per-tensor quantization (FP8/NVFP4) expects the same weight amax across row- and column-parallel ranks |
| 165 | +    # Channel-wise quantization (INT8) expects the same weight amax only across row-parallel ranks |
| 166 | +    # Block-wise quantization does not expect the same weight amax across either row- or column-parallel ranks |
| 167 | + if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]: |
| 168 | + if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): |
| 169 | + for quantizer in model.fc1.weight_quantizer: |
| 170 | + _distributed_attr_check( |
| 171 | + quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group] |
| 172 | + ) |
| 173 | + else: |
| 174 | + _distributed_attr_check( |
| 175 | + model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group] |
| 176 | + ) |
| 177 | + |
| 178 | + if config in [ |
| 179 | + mtq.FP8_DEFAULT_CFG, |
| 180 | + mtq.NVFP4_DEFAULT_CFG, |
| 181 | + mtq.INT8_DEFAULT_CFG, |
| 182 | + mtq.INT8_SMOOTHQUANT_CFG, |
| 183 | + ]: |
| 184 | + if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): |
| 185 | + for quantizer in model.fc2.weight_quantizer: |
| 186 | + _distributed_attr_check( |
| 187 | + quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group] |
| 188 | + ) |
| 189 | + else: |
| 190 | + _distributed_attr_check( |
| 191 | + model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group] |
| 192 | + ) |
| 193 | + |
| 194 | +    # Let's check the column-parallel pre_quant_scale; it should be the same across all TP ranks |
| 195 | +    # It differs across DP/CP ranks since the calibration input differs |
| 196 | + if ( |
| 197 | + test_pre_quant_scale |
| 198 | + and tp_group |
| 199 | + and config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG] |
| 200 | + ): |
| 201 | + input_quantizer = model.fc1.input_quantizer |
| 202 | + _distributed_attr_check( |
| 203 | + input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, groups=[dp_group, tp_group] |
| 204 | + ) |
| 205 | + |
| 206 | + # Check act scale |
| 207 | + if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: |
| 208 | + _distributed_attr_check( |
| 209 | + model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, groups=[dp_group, tp_group] |
| 210 | + ) |
151 | 211 |
|
152 | 212 |
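A hypothetical end-to-end invocation of the new helper (not part of this PR): a stand-in two-layer model that exposes the `fc1`/`fc2`/`get_dummy_input` interface the helper expects, calibrated with a data-parallel group only (`tp_group=None`). Assumes the helper is importable from this module, one CUDA device per rank, the NCCL backend, and a launch such as `torchrun --nproc_per_node=2 ...`:

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn

import modelopt.torch.quantization as mtq


class _ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(32, 64)
        self.fc2 = nn.Linear(64, 32)

    def forward(self, x):
        return self.fc2(self.fc1(x))

    def get_dummy_input(self, seed: int = 0) -> torch.Tensor:
        # A different seed per DP rank yields a different calibration batch per rank
        generator = torch.Generator().manual_seed(seed)
        return torch.randn(4, 32, generator=generator)


if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    torch.manual_seed(0)  # identical weights on every rank, as for real data-parallel replicas
    model = _ToyModel().cuda()
    data_tensor_context_parallel_test_helper(
        model,
        mtq.FP8_DEFAULT_CFG,
        dp_group=dist.group.WORLD,  # every rank acts as a data-parallel replica
        tp_group=None,  # no tensor parallelism in this sketch
    )
    dist.destroy_process_group()
```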
|
153 | 213 | def auto_quantize_helper(model): |
|