Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: Print FloatAttr special values #3377

Merged
merged 11 commits into from
Nov 4, 2024
24 changes: 24 additions & 0 deletions tests/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from xdsl.dialects import test
from xdsl.dialects.arith import Addi, Arith, Constant
from xdsl.dialects.builtin import (
AnyFloatAttr,
Builtin,
FloatAttr,
FunctionType,
IntAttr,
IntegerType,
Expand Down Expand Up @@ -761,6 +763,28 @@ def test_densearray_attr():
assert_print_op(parsed, prog, None)


def test_float_attr_specials():
printer = Printer()

def _test_attr_print(expected: str, attr: AnyFloatAttr):
io = StringIO()
printer.stream = io
printer.print_attribute(attr)
assert io.getvalue() == expected

_test_attr_print("0x7e00 : f16", FloatAttr(float("nan"), 16))
_test_attr_print("0x7c00 : f16", FloatAttr(float("inf"), 16))
_test_attr_print("0xfc00 : f16", FloatAttr(float("-inf"), 16))

_test_attr_print("0x7fc00000 : f32", FloatAttr(float("nan"), 32))
_test_attr_print("0x7f800000 : f32", FloatAttr(float("inf"), 32))
_test_attr_print("0xff800000 : f32", FloatAttr(float("-inf"), 32))

_test_attr_print("0x7ff8000000000000 : f64", FloatAttr(float("nan"), 64))
_test_attr_print("0x7ff0000000000000 : f64", FloatAttr(float("inf"), 64))
_test_attr_print("0xfff0000000000000 : f64", FloatAttr(float("-inf"), 64))


def test_print_function_type():
io = StringIO()
printer = Printer(stream=io)
Expand Down
48 changes: 32 additions & 16 deletions xdsl/printer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
import math
from collections.abc import Callable, Iterable, Sequence
from contextlib import contextmanager
from dataclasses import dataclass, field
Expand Down Expand Up @@ -69,6 +70,11 @@
TypeAttribute,
)
from xdsl.traits import IsolatedFromAbove, IsTerminator
from xdsl.utils.bitwise_casts import (
convert_f16_to_u16,
convert_f32_to_u32,
convert_f64_to_u64,
)
from xdsl.utils.diagnostic import Diagnostic
from xdsl.utils.lexer import Lexer

Expand Down Expand Up @@ -460,6 +466,27 @@ def print_bytes_literal(self, bytestring: bytes):
self.print_string(chr(byte))
self.print_string('"')

def print_float(self, attribute: AnyFloatAttr):
value = attribute.value
if math.isnan(value.data) or math.isinf(value.data):
if isinstance(attribute.type, Float16Type):
self.print_string(f"{hex(convert_f16_to_u16(value.data))}")
elif isinstance(attribute.type, Float32Type):
self.print_string(f"{hex(convert_f32_to_u32(value.data))}")
elif isinstance(attribute.type, Float64Type):
self.print_string(f"{hex(convert_f64_to_u64(value.data))}")
else:
raise NotImplementedError(
f"Cannot print '{value.data}' value for float type {str(attribute.type)}"
)
else:
# to mirror mlir-opt, attempt to print scientific notation iff the value parses losslessly
float_str = f"{value.data:.6e}"
if float(float_str) == value.data:
self.print_string(float_str)
else:
self.print_string(f"{repr(value.data)}")

def print_attribute(self, attribute: Attribute) -> None:
if isinstance(attribute, UnitAttr):
self.print_string("unit")
Expand Down Expand Up @@ -533,17 +560,10 @@ def print_attribute(self, attribute: Attribute) -> None:
return

if isinstance(attribute, FloatAttr):
value = attribute.value
attr_type = cast(
FloatAttr[Float16Type | Float32Type | Float64Type], attribute
).type
# to mirror mlir-opt, attempt to print scientific notation iff the value parses losslessly
float_str = f"{value.data:.6e}"
if float(float_str) == value.data:
self.print_string(f"{float_str} : ")
else:
self.print_string(f"{repr(value.data)} : ")
self.print_attribute(attr_type)
attr = cast(AnyFloatAttr, attribute)
self.print_float(attr)
self.print_string(" : ")
self.print_attribute(attr.type)
return

# Complex types have MLIR shorthands but XDSL does not.
Expand Down Expand Up @@ -603,11 +623,7 @@ def print_one_elem(val: Attribute):
if isinstance(val, IntegerAttr):
self.print_string(f"{val.value.data}")
elif isinstance(val, FloatAttr):
float_str = f"{val.value.data:.6e}"
if float(float_str) == val.value.data:
self.print_string(float_str)
else:
self.print_string(f"{repr(val.value.data)}")
self.print_float(cast(AnyFloatAttr, val))
else:
raise Exception(
"unexpected attribute type "
Expand Down
31 changes: 29 additions & 2 deletions xdsl/utils/bitwise_casts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,23 @@
"""

import ctypes
import struct


def convert_f16_to_u16(value: float) -> int:
"""
Convert an IEEE 754 float to a raw unsigned integer representation.
"""
# using struct library as ctypes does not support half-precision floats
return struct.unpack("<H", struct.pack("<e", value))[0]


def convert_f32_to_u32(value: float) -> int:
"""
Convert an IEEE 754 float to a raw unsigned integer representation.
"""
raw_float = ctypes.c_float(value)
raw_int = ctypes.c_uint32.from_address(ctypes.addressof(raw_float)).value
raw_int = ctypes.c_uint32.from_buffer(raw_float).value
return raw_int


Expand All @@ -20,7 +29,25 @@ def convert_u32_to_f32(value: int) -> float:
Convert a raw 32-bit unsigned integer to IEEE 754 float representation.
"""
raw_int = ctypes.c_uint32(value)
raw_float = ctypes.c_float.from_address(ctypes.addressof(raw_int)).value
raw_float = ctypes.c_float.from_buffer(raw_int).value
return raw_float


def convert_f64_to_u64(value: float) -> int:
"""
Convert an IEEE 754 float to a raw unsigned integer representation.
"""
raw_float = ctypes.c_double(value)
raw_int = ctypes.c_uint64.from_buffer(raw_float).value
return raw_int


def convert_u64_to_f64(value: int) -> float:
"""
Convert a raw 32-bit unsigned integer to IEEE 754 float representation.
"""
raw_int = ctypes.c_uint64(value)
raw_float = ctypes.c_double.from_buffer(raw_int).value
return raw_float


Expand Down
Loading