diff --git a/src/onnx_ir/_enums.py b/src/onnx_ir/_enums.py index 26ecaa47..19fa5920 100644 --- a/src/onnx_ir/_enums.py +++ b/src/onnx_ir/_enums.py @@ -142,6 +142,20 @@ def short_name(self) -> str: raise TypeError(f"Short name not available for ONNX data type: {self}") return _DATA_TYPE_TO_SHORT_NAME[self] + def is_floating_point(self) -> bool: + """Returns True if the data type is a floating point type.""" + return self in { + DataType.FLOAT, + DataType.FLOAT16, + DataType.DOUBLE, + DataType.BFLOAT16, + DataType.FLOAT8E4M3FN, + DataType.FLOAT8E4M3FNUZ, + DataType.FLOAT8E5M2, + DataType.FLOAT8E5M2FNUZ, + DataType.FLOAT4E2M1, + } + def __repr__(self) -> str: return self.name