Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: Add FixedbitWidthType Interface #2904

Merged
merged 6 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions tests/dialects/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@


def test_FloatType_bitwidths():
assert BFloat16Type().get_bitwidth == 16
assert Float16Type().get_bitwidth == 16
assert Float32Type().get_bitwidth == 32
assert Float64Type().get_bitwidth == 64
assert Float80Type().get_bitwidth == 80
assert Float128Type().get_bitwidth == 128
assert BFloat16Type().bitwidth == 16
assert Float16Type().bitwidth == 16
assert Float32Type().bitwidth == 32
assert Float64Type().bitwidth == 64
assert Float80Type().bitwidth == 80
assert Float128Type().bitwidth == 128


def test_DenseIntOrFPElementsAttr_fp_type_conversion():
Expand Down
39 changes: 7 additions & 32 deletions xdsl/backend/riscv/lowering/convert_memref_to_riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from xdsl.dialects.builtin import (
AnyFloat,
DenseIntOrFPElementsAttr,
FixedBitwidthType,
Float32Type,
Float64Type,
IntegerType,
Expand All @@ -37,42 +38,14 @@
from xdsl.utils.exceptions import DiagnosticException


def bitwidth_of_type(type_attribute: Attribute) -> int:
"""
Returns the width of an element type in bits, or raises DiagnosticException for unknown inputs.
"""
if isinstance(type_attribute, AnyFloat):
return type_attribute.get_bitwidth
elif isinstance(type_attribute, IntegerType):
return type_attribute.width.data
else:
raise NotImplementedError(
f"Unsupported memref element type for riscv lowering: {type_attribute}"
)


def element_size_for_type(type_attribute: Attribute) -> int:
"""
Returns the width of an element type in bytes, or raises DiagnosticException for
unknown inputs, or sizes not divisible by 8.
"""
bitwidth = bitwidth_of_type(type_attribute)
if bitwidth % 8:
raise DiagnosticException(
f"Cannot determine size for element type {type_attribute}"
f" with bitwidth {bitwidth}"
)
bytes_per_element = bitwidth // 8
return bytes_per_element


class ConvertMemrefAllocOp(RewritePattern):

@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.Alloc, rewriter: PatternRewriter) -> None:
assert isinstance(op_memref_type := op.memref.type, memref.MemRefType)
op_memref_type = cast(memref.MemRefType[Any], op_memref_type)
width_in_bytes = bitwidth_of_type(op_memref_type.element_type) // 8
assert isinstance(op_memref_type.element_type, FixedBitwidthType)
width_in_bytes = op_memref_type.element_type.size
size = prod(op_memref_type.get_shape()) * width_in_bytes
rewriter.replace_matched_op(
(
Expand Down Expand Up @@ -117,7 +90,8 @@ def get_strided_pointer(
a new pointer to the element being accessed by the 'indices'.
"""

bytes_per_element = element_size_for_type(memref_type.element_type)
assert isinstance(memref_type.element_type, FixedBitwidthType)
bytes_per_element = memref_type.element_type.size

match memref_type.layout:
case NoneAttr():
Expand Down Expand Up @@ -394,7 +368,8 @@ def match_and_rewrite(self, op: memref.Subview, rewriter: PatternRewriter):

offset = result_layout_attr.get_offset()

factor = element_size_for_type(result_type.element_type)
assert isinstance(result_type.element_type, FixedBitwidthType)
factor = result_type.element_type.size

if offset == 0:
rewriter.replace_matched_op(
Expand Down
55 changes: 41 additions & 14 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,31 @@ def print_parameter(self, printer: Printer) -> None:
raise ValueError(f"Invalid signedness {data}")


class FixedBitwidthType(TypeAttribute, ABC):
"""
A type attribute with a defined bitwidth
"""

name = "abstract.bitwidth_type"

@property
@abstractmethod
def bitwidth(self) -> int:
"""
Contiguous memory footprint in bits
"""
raise NotImplementedError()

@property
def size(self) -> int:
"""
Contiguous memory footprint in bytes, defaults to `ceil(bitwidth / 8)`
"""
return self.bitwidth >> 3 + bool(self.bitwidth % 8)


@irdl_attr_definition
class IntegerType(ParametrizedAttribute, TypeAttribute):
class IntegerType(ParametrizedAttribute, FixedBitwidthType):
name = "integer_type"
width: ParameterDef[IntAttr]
signedness: ParameterDef[SignednessAttr]
Expand All @@ -373,6 +396,10 @@ def __init__(
def value_range(self) -> tuple[int, int]:
return self.signedness.data.value_range(self.width.data)

@property
def bitwidth(self) -> int:
return self.width.data


i64 = IntegerType(64)
i32 = IntegerType(32)
Expand Down Expand Up @@ -494,61 +521,61 @@ def print_without_type(self, printer: Printer):
class _FloatType(ABC):
@property
@abstractmethod
def get_bitwidth(self) -> int:
def bitwidth(self) -> int:
raise NotImplementedError()


@irdl_attr_definition
class BFloat16Type(ParametrizedAttribute, TypeAttribute, _FloatType):
class BFloat16Type(ParametrizedAttribute, FixedBitwidthType, _FloatType):
name = "bf16"

@property
def get_bitwidth(self) -> int:
def bitwidth(self) -> int:
return 16


@irdl_attr_definition
class Float16Type(ParametrizedAttribute, TypeAttribute, _FloatType):
class Float16Type(ParametrizedAttribute, FixedBitwidthType, _FloatType):
name = "f16"

@property
def get_bitwidth(self) -> int:
def bitwidth(self) -> int:
return 16


@irdl_attr_definition
class Float32Type(ParametrizedAttribute, TypeAttribute, _FloatType):
class Float32Type(ParametrizedAttribute, FixedBitwidthType, _FloatType):
name = "f32"

@property
def get_bitwidth(self) -> int:
def bitwidth(self) -> int:
return 32


@irdl_attr_definition
class Float64Type(ParametrizedAttribute, TypeAttribute, _FloatType):
class Float64Type(ParametrizedAttribute, FixedBitwidthType, _FloatType):
name = "f64"

@property
def get_bitwidth(self) -> int:
def bitwidth(self) -> int:
return 64


@irdl_attr_definition
class Float80Type(ParametrizedAttribute, TypeAttribute, _FloatType):
class Float80Type(ParametrizedAttribute, FixedBitwidthType, _FloatType):
name = "f80"

@property
def get_bitwidth(self) -> int:
def bitwidth(self) -> int:
return 80


@irdl_attr_definition
class Float128Type(ParametrizedAttribute, TypeAttribute, _FloatType):
class Float128Type(ParametrizedAttribute, FixedBitwidthType, _FloatType):
name = "f128"

@property
def get_bitwidth(self) -> int:
def bitwidth(self) -> int:
return 128


Expand Down
5 changes: 3 additions & 2 deletions xdsl/transforms/convert_memref_stream_to_snitch_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from functools import reduce
from typing import cast

from xdsl.backend.riscv.lowering.convert_memref_to_riscv import element_size_for_type
from xdsl.backend.riscv.lowering.utils import (
cast_operands_to_regs,
move_to_unallocated_regs,
Expand All @@ -20,6 +19,7 @@
)
from xdsl.dialects.builtin import (
ArrayAttr,
FixedBitwidthType,
IntAttr,
MemRefType,
ModuleOp,
Expand Down Expand Up @@ -203,7 +203,8 @@ def strides_map_from_memref_type(memref_type: MemRefType[AttributeCovT]) -> Affi
f"Unsupported empty shape in memref of type {memref_type}"
)

factor = element_size_for_type(memref_type.element_type)
assert isinstance(memref_type.element_type, FixedBitwidthType)
factor = memref_type.element_type.size

return AffineMap(
len(strides),
Expand Down
Loading