Commit 2f22bef

dont set quantization data on reload (#123) (#125)
dont set quantization data on reload (#123) (#125)
Sara Adkins authored Aug 8, 2024
1 parent a4c86dc commit 2f22bef
Showing 2 changed files with 12 additions and 3 deletions.
9 changes: 8 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/apply.py
@@ -195,7 +195,14 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
         model.apply(initialize_module_for_quantization)
 
     if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
-        model.apply(set_module_for_calibration)
+        # only quantize weights up front when our end goal state is calibration,
+        # weight quantization parameters are already loaded for frozen/compressed
+        quantize_weights_upfront = status == QuantizationStatus.CALIBRATION
+        model.apply(
+            lambda module: set_module_for_calibration(
+                module, quantize_weights_upfront=quantize_weights_upfront
+            )
+        )
 
     if current_status < status >= QuantizationStatus.FROZEN > current_status:
         model.apply(freeze_module_quantization)
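A note on reading this hunk: `current_status < status >= QuantizationStatus.CALIBRATION > current_status` is a Python chained comparison, equivalent to `current_status < status and status >= QuantizationStatus.CALIBRATION and QuantizationStatus.CALIBRATION > current_status`. Also, `nn.Module.apply` hands each submodule to the callback as its only argument, which is why the new keyword has to be forwarded through a lambda. Below is a minimal standalone sketch of that forwarding pattern; the `mark_for_calibration` helper and the attribute it sets are invented for illustration and are not part of the library.

    from torch import nn

    def mark_for_calibration(module: nn.Module, quantize_weights_upfront: bool = True) -> None:
        # stand-in for set_module_for_calibration: just record the flag on the module
        module.quantize_weights_upfront = quantize_weights_upfront

    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    upfront = False  # e.g. the end goal is FROZEN/COMPRESSED, not CALIBRATION

    # Module.apply passes only the module itself to the callback, so the extra
    # keyword is captured by the lambda, as in the diff above
    model.apply(lambda m: mark_for_calibration(m, quantize_weights_upfront=upfront))

    assert all(m.quantize_weights_upfront is False for m in model.modules())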
6 changes: 4 additions & 2 deletions src/compressed_tensors/quantization/lifecycle/calibration.py
@@ -28,14 +28,16 @@
 _LOGGER = logging.getLogger(__name__)
 
 
-def set_module_for_calibration(module: Module):
+def set_module_for_calibration(module: Module, quantize_weights_upfront: bool = True):
     """
     marks a layer as ready for calibration which activates observers
     to update scales and zero points on each forward pass
 
     apply to full model with `model.apply(set_module_for_calibration)`
 
     :param module: module to set for calibration
+    :param quantize_weights_upfront: whether to automatically run weight quantization at the
+        start of calibration
     """
     if not getattr(module, "quantization_scheme", None):
         # no quantization scheme nothing to do
@@ -49,7 +51,7 @@ def set_module_for_calibration(module: Module):
         "to re-calibrate a frozen module"
     )
 
-    if module.quantization_scheme.weights is not None:
+    if quantize_weights_upfront and module.quantization_scheme.weights is not None:
         # set weight scale and zero_point up front, calibration data doesn't affect it
         observer = module.weight_observer
 
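The net effect of the two files together: the weight observer now computes scales and zero points only when calibration itself is the target status. When a frozen or compressed model is reloaded, those parameters arrive from the checkpoint, and re-running the observer would overwrite them. A standalone sketch of that decision follows; the Status enum is a stand-in with an assumed ordering, not the library's QuantizationStatus.

    from enum import IntEnum

    class Status(IntEnum):
        # stand-in for QuantizationStatus; the increasing values mirror the
        # lifecycle progression assumed by the comparisons in apply.py
        INITIALIZED = 1
        CALIBRATION = 2
        FROZEN = 3
        COMPRESSED = 4

    def quantize_weights_upfront(target: Status) -> bool:
        # mirrors the commit: observe weights only when calibration is the end
        # goal; for FROZEN/COMPRESSED targets the scale/zero_point tensors were
        # already loaded from the checkpoint and must not be recomputed
        return target == Status.CALIBRATION

    assert quantize_weights_upfront(Status.CALIBRATION)
    assert not quantize_weights_upfront(Status.FROZEN)      # reload path
    assert not quantize_weights_upfront(Status.COMPRESSED)  # reload path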
