diff --git a/src/sparsetensors/quantization/lifecycle/apply.py b/src/sparsetensors/quantization/lifecycle/apply.py index 4c78568d..ac238564 100644 --- a/src/sparsetensors/quantization/lifecycle/apply.py +++ b/src/sparsetensors/quantization/lifecycle/apply.py @@ -50,7 +50,7 @@ def apply_quantization_config(model: Module, config: QuantizationConfig): target_to_scheme[target] = scheme # mark appropriate layers for quantization by setting their quantization schemes - for name, submodule in _iter_named_leaf_modules(model): + for name, submodule in iter_named_leaf_modules(model): if _find_first_name_or_class_match(name, submodule, config.ignore): continue # layer matches ignore list, continue target = _find_first_name_or_class_match(name, submodule, target_to_scheme) diff --git a/src/sparsetensors/quantization/quant_args.py b/src/sparsetensors/quantization/quant_args.py index d90fe9bc..76bd61f0 100644 --- a/src/sparsetensors/quantization/quant_args.py +++ b/src/sparsetensors/quantization/quant_args.py @@ -21,7 +21,7 @@ __all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"] -class QuantizationType(Enum): +class QuantizationType(str, Enum): """ Enum storing quantization type options """ @@ -30,7 +30,7 @@ class QuantizationType(Enum): FLOAT = "float" -class QuantizationStrategy(Enum): +class QuantizationStrategy(str, Enum): """ Enum storing quantization strategy options """ diff --git a/src/sparsetensors/quantization/quant_config.py b/src/sparsetensors/quantization/quant_config.py index c70e7c45..2a2b345f 100644 --- a/src/sparsetensors/quantization/quant_config.py +++ b/src/sparsetensors/quantization/quant_config.py @@ -33,7 +33,7 @@ ] -class QuantizationStatus(Enum): +class QuantizationStatus(str, Enum): """ Enum storing the different states a quantized layer can be in diff --git a/src/sparsetensors/quantization/utils/helpers.py b/src/sparsetensors/quantization/utils/helpers.py index 52bebf58..3c00cdbe 100644 --- a/src/sparsetensors/quantization/utils/helpers.py +++ b/src/sparsetensors/quantization/utils/helpers.py @@ -108,8 +108,6 @@ def calculate_compression_ratio(model: Module) -> float: compressed_bits = uncompressed_bits if is_module_quantized(submodule): compressed_bits = submodule.quantization_scheme.weights.num_bits - else: - print(name) num_weights = parameter.numel() total_compressed += compressed_bits * num_weights total_uncompressed += uncompressed_bits * num_weights