diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index 1d0da998605..0190de36471 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -43,7 +43,6 @@ from nncf.scopes import IgnoredScope from nncf.scopes import get_ignored_node_names_from_ignored_scope from nncf.tensor import Tensor -from nncf.tensor.definitions import TensorDataType TModel = TypeVar("TModel") TTensor = TypeVar("TTensor") @@ -322,21 +321,6 @@ def _get_bitwidth_distribution_str( pretty_string = f"Statistics of the bitwidth distribution:\n{table}" return pretty_string - @staticmethod - def _is_float(dtype): - """ - Check if the given data type is a floating-point type. - - :param dtype: The data type to check. - :return: True if the data type is one of the floating-point types, False otherwise. - """ - return dtype in [ - TensorDataType.float16, - TensorDataType.bfloat16, - TensorDataType.float32, - TensorDataType.float64, - ] - def _get_ignored_scope_weight_statistics(self, model: TModel, graph: NNCFGraph) -> List[int]: """ Collect the weight statistics for nodes in the ignored scope. 
@@ -359,7 +343,7 @@ def _get_ignored_scope_weight_statistics(self, model: TModel, graph: NNCFGraph) continue for _, weight_port_id in self._backend_entity.get_weight_names_and_port_ids(node, graph): weight_dtype = self._backend_entity.get_weight_dtype(node, weight_port_id, model, graph) - if not WeightCompression._is_float(weight_dtype): + if not weight_dtype.is_float(): continue weight_shape = self._backend_entity.get_weight_shape(node, weight_port_id, graph) weight_size = reduce(operator.mul, weight_shape, 1) @@ -391,7 +375,7 @@ def apply( continue weight_dtype = self._backend_entity.get_weight_dtype(node, weight_port_id, model, graph) - if not WeightCompression._is_float(weight_dtype): + if not weight_dtype.is_float(): continue weight_shape = self._backend_entity.get_weight_shape(node, weight_port_id, graph) weight_size = reduce(operator.mul, weight_shape, 1) diff --git a/nncf/tensor/definitions.py b/nncf/tensor/definitions.py index 447a6dd8bb5..2c0ac24f8ae 100644 --- a/nncf/tensor/definitions.py +++ b/nncf/tensor/definitions.py @@ -25,7 +25,7 @@ class TensorBackendType(Enum): class TensorDataType(Enum): """ - Enum representing the different tensor data types. + The class represents the different tensor data types. """ float16 = auto() @@ -37,6 +37,12 @@ class TensorDataType(Enum): int64 = auto() uint8 = auto() + def is_float(self): + """ + :return: True if the tensor data type is a floating-point type, else False. + """ + return self in [TensorDataType.float16, TensorDataType.bfloat16, TensorDataType.float32, TensorDataType.float64] + class TensorDeviceType(Enum): """