minor fixes
Sara Adkins committed Apr 16, 2024
1 parent edc35a1 commit 7bbeb65
Showing 4 changed files with 4 additions and 6 deletions.
src/sparsetensors/quantization/lifecycle/apply.py (1 addition & 1 deletion)

@@ -50,7 +50,7 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
             target_to_scheme[target] = scheme
 
     # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in _iter_named_leaf_modules(model):
+    for name, submodule in iter_named_leaf_modules(model):
         if _find_first_name_or_class_match(name, submodule, config.ignore):
             continue  # layer matches ignore list, continue
         target = _find_first_name_or_class_match(name, submodule, target_to_scheme)
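
The change here swaps the private _iter_named_leaf_modules helper for its public counterpart iter_named_leaf_modules. As a rough sketch of what such an iterator provides (an assumption based on the name and the call site, not the library's actual implementation):

# Hypothetical sketch of a leaf-module iterator, assuming it mirrors
# what iter_named_leaf_modules yields; not the library's actual code.
from torch.nn import Module

def iter_named_leaf_modules_sketch(model: Module):
    for name, submodule in model.named_modules():
        # a leaf has no child modules (e.g. a Linear layer), so a
        # quantization scheme can be attached to it directly
        if next(submodule.children(), None) is None:
            yield name, submodule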
src/sparsetensors/quantization/quant_args.py (2 additions & 2 deletions)

@@ -21,7 +21,7 @@
 __all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
 
 
-class QuantizationType(Enum):
+class QuantizationType(str, Enum):
     """
     Enum storing quantization type options
     """
@@ -30,7 +30,7 @@ class QuantizationType(Enum):
     FLOAT = "float"
 
 
-class QuantizationStrategy(Enum):
+class QuantizationStrategy(str, Enum):
     """
     Enum storing quantization strategy options
     """
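
Adding str as a base alongside Enum makes every member a real string, so values compare equal to their raw strings and serialize without custom encoders, which is useful for configs that round-trip through JSON or pydantic. A minimal sketch of the pattern (the INT member is assumed by analogy; only FLOAT is visible in the hunk):

import json
from enum import Enum

class QuantizationType(str, Enum):
    INT = "int"      # assumed member, by analogy with FLOAT
    FLOAT = "float"

# members now compare equal to their raw string values...
assert QuantizationType.FLOAT == "float"

# ...and serialize as plain strings with no custom JSON encoder
assert json.dumps({"type": QuantizationType.FLOAT}) == '{"type": "float"}'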
src/sparsetensors/quantization/quant_config.py (1 addition & 1 deletion)

@@ -33,7 +33,7 @@
 ]
 
 
-class QuantizationStatus(Enum):
+class QuantizationStatus(str, Enum):
     """
     Enum storing the different states a quantized layer can be in
src/sparsetensors/quantization/utils/helpers.py (0 additions & 2 deletions)

@@ -108,8 +108,6 @@ def calculate_compression_ratio(model: Module) -> float:
             compressed_bits = uncompressed_bits
             if is_module_quantized(submodule):
                 compressed_bits = submodule.quantization_scheme.weights.num_bits
-            else:
-                print(name)
             num_weights = parameter.numel()
             total_compressed += compressed_bits * num_weights
             total_uncompressed += uncompressed_bits * num_weights
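
The deleted else/print pair was leftover debug output in calculate_compression_ratio. For reference, the accumulators reduce to a weighted ratio of bit widths; a hypothetical standalone sketch (the final uncompressed-over-compressed division is assumed from the variable names, since the return statement sits outside the hunk):

# Hypothetical reduction of the accumulators above, not the library's code.
def compression_ratio(layers: list[tuple[int, int, int]]) -> float:
    """layers: (uncompressed_bits, compressed_bits, num_weights) per layer"""
    total_uncompressed = sum(u * n for u, _, n in layers)
    total_compressed = sum(c * n for _, c, n in layers)
    return total_uncompressed / total_compressed

# two layers of 100 and 200 fp16 weights, the second quantized to 4 bits:
# (16*100 + 16*200) / (16*100 + 4*200) = 4800 / 2400 = 2.0
assert compression_ratio([(16, 16, 100), (16, 4, 200)]) == 2.0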
