Skip to content

Commit

Permalink
Move is_float method
Browse files Browse the repository at this point in the history
  • Loading branch information
l-bat committed Sep 27, 2024
1 parent 6b83591 commit c23fe73
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 19 deletions.
20 changes: 2 additions & 18 deletions nncf/quantization/algorithms/weight_compression/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from nncf.scopes import IgnoredScope
from nncf.scopes import get_ignored_node_names_from_ignored_scope
from nncf.tensor import Tensor
from nncf.tensor.definitions import TensorDataType

TModel = TypeVar("TModel")
TTensor = TypeVar("TTensor")
Expand Down Expand Up @@ -322,21 +321,6 @@ def _get_bitwidth_distribution_str(
pretty_string = f"Statistics of the bitwidth distribution:\n{table}"
return pretty_string

@staticmethod
def _is_float(dtype):
    """
    Determine whether a tensor data type is floating point.

    :param dtype: The data type to check.
    :return: True if the data type is one of the floating-point types, False otherwise.
    """
    float_dtypes = (
        TensorDataType.float16,
        TensorDataType.bfloat16,
        TensorDataType.float32,
        TensorDataType.float64,
    )
    return dtype in float_dtypes

def _get_ignored_scope_weight_statistics(self, model: TModel, graph: NNCFGraph) -> List[int]:
"""
Collect the weight statistics for nodes in the ignored scope.
Expand All @@ -359,7 +343,7 @@ def _get_ignored_scope_weight_statistics(self, model: TModel, graph: NNCFGraph)
continue
for _, weight_port_id in self._backend_entity.get_weight_names_and_port_ids(node, graph):
weight_dtype = self._backend_entity.get_weight_dtype(node, weight_port_id, model, graph)
if not WeightCompression._is_float(weight_dtype):
if not weight_dtype.is_float():
continue
weight_shape = self._backend_entity.get_weight_shape(node, weight_port_id, graph)
weight_size = reduce(operator.mul, weight_shape, 1)
Expand Down Expand Up @@ -391,7 +375,7 @@ def apply(
continue

weight_dtype = self._backend_entity.get_weight_dtype(node, weight_port_id, model, graph)
if not WeightCompression._is_float(weight_dtype):
if not weight_dtype.is_float():
continue
weight_shape = self._backend_entity.get_weight_shape(node, weight_port_id, graph)
weight_size = reduce(operator.mul, weight_shape, 1)
Expand Down
8 changes: 7 additions & 1 deletion nncf/tensor/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class TensorBackendType(Enum):

class TensorDataType(Enum):
"""
Enum representing the different tensor data types.
The class represents the different tensor data types.
"""

float16 = auto()
Expand All @@ -37,6 +37,12 @@ class TensorDataType(Enum):
int64 = auto()
uint8 = auto()

def is_float(self):
    """
    Report whether this data type is one of the floating-point variants.

    :return: True if the tensor data type is a floating-point type, else False.
    """
    float_dtypes = (
        TensorDataType.float16,
        TensorDataType.bfloat16,
        TensorDataType.float32,
        TensorDataType.float64,
    )
    return self in float_dtypes


class TensorDeviceType(Enum):
"""
Expand Down

0 comments on commit c23fe73

Please sign in to comment.