diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 06475d42..b7b5e55d 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -195,7 +195,14 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): model.apply(initialize_module_for_quantization) if current_status < status >= QuantizationStatus.CALIBRATION > current_status: - model.apply(set_module_for_calibration) + # only quantize weights up front when our end goal state is calibration, + # weight quantization parameters are already loaded for frozen/compressed + quantize_weights_upfront = status == QuantizationStatus.CALIBRATION + model.apply( + lambda module: set_module_for_calibration( + module, quantize_weights_upfront=quantize_weights_upfront + ) + ) if current_status < status >= QuantizationStatus.FROZEN > current_status: model.apply(freeze_module_quantization) diff --git a/src/compressed_tensors/quantization/lifecycle/calibration.py b/src/compressed_tensors/quantization/lifecycle/calibration.py index 7bdfddd0..b1fe2126 100644 --- a/src/compressed_tensors/quantization/lifecycle/calibration.py +++ b/src/compressed_tensors/quantization/lifecycle/calibration.py @@ -28,7 +28,7 @@ _LOGGER = logging.getLogger(__name__) -def set_module_for_calibration(module: Module): +def set_module_for_calibration(module: Module, quantize_weights_upfront: bool = True): """ marks a layer as ready for calibration which activates observers to update scales and zero points on each forward pass @@ -36,6 +36,8 @@ def set_module_for_calibration(module: Module): apply to full model with `model.apply(set_module_for_calibration)` :param module: module to set for calibration + :param quantize_weights_upfront: whether to automatically run weight quantization at the + start of calibration """ if not getattr(module, "quantization_scheme", None): # no quantization scheme nothing to do @@ -49,7 +51,7 @@ def set_module_for_calibration(module: Module): "to re-calibrate a frozen module" ) - if module.quantization_scheme.weights is not None: + if quantize_weights_upfront and module.quantization_scheme.weights is not None: # set weight scale and zero_point up front, calibration data doesn't affect it observer = module.weight_observer