From 37df2dd0f0a8018d74d505b0d73b7816a2c6fe2f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 1 Nov 2024 13:17:15 -0400 Subject: [PATCH] Clean up observer defaulting logic, better error message (#200) --- .../quantization/quant_args.py | 24 ++++++++----------- src/compressed_tensors/registry/registry.py | 2 +- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 3259976c..4619d581 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -114,12 +114,6 @@ def get_observer(self): """ :return: torch quantization FakeQuantize built based on these QuantizationArgs """ - - # No observer required for the dynamic case - if self.dynamic: - self.observer = None - return self.observer - return self.observer @field_validator("type", mode="before") @@ -203,6 +197,7 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]: "activation ordering" ) + # infer observer w.r.t. dynamic if dynamic: if strategy not in ( QuantizationStrategy.TOKEN, @@ -214,18 +209,19 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]: "quantization", ) if observer is not None: - warnings.warn( - "No observer is used for dynamic quantization, setting to None" - ) - model.observer = None + if observer != "memoryless": # avoid annoying users with old configs + warnings.warn( + "No observer is used for dynamic quantization, setting to None" + ) + observer = None - # if we have not set an observer and we - # are running static quantization, use minmax - if not observer and not dynamic: - model.observer = "minmax" + elif observer is None: + # default to minmax for non-dynamic cases + observer = "minmax" # write back modified values model.strategy = strategy + model.observer = observer return model def pytorch_dtype(self) -> torch.dtype: diff --git a/src/compressed_tensors/registry/registry.py b/src/compressed_tensors/registry/registry.py index d8d8bc6d..76026313 100644 --- a/src/compressed_tensors/registry/registry.py +++ b/src/compressed_tensors/registry/registry.py @@ -258,7 +258,7 @@ def get_from_registry( retrieved_value = _import_and_get_value_from_module(module_path, value_name) else: # look up name in alias registry - name = _ALIAS_REGISTRY[parent_class].get(name) + name = _ALIAS_REGISTRY[parent_class].get(name, name) # look up name in registry retrieved_value = _REGISTRY[parent_class].get(name) if retrieved_value is None: