Skip to content

Commit

Permalink
[IR] Attach uint attributes (#283)
Browse files Browse the repository at this point in the history
  • Loading branch information
chhzh123 authored Jan 1, 2025
1 parent 63e83a7 commit 90a0efe
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
17 changes: 14 additions & 3 deletions allo/ir/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,10 @@ def build_shaped_type(ctx, dtype, shape, layout=None):
def build_array(ctx, dtype, shape):
if not ctx.enable_tensor:
memref_type = MemRefType.get(shape, dtype.build())
return memref_d.AllocOp(memref_type, [], [], ip=ctx.get_ip())
alloc_op = memref_d.AllocOp(memref_type, [], [], ip=ctx.get_ip())
if isinstance(dtype, UInt):
alloc_op.attributes["unsigned"] = UnitAttr.get()
return alloc_op
return tensor_d.EmptyOp(shape, dtype.build(), ip=ctx.get_ip())

@staticmethod
Expand Down Expand Up @@ -690,10 +693,14 @@ def build_general_binop(ctx, node, lhs, rhs):
if isinstance(node.op, (ast.LShift, ast.RShift)) and isinstance(
node.dtype, (Fixed, UFixed)
):
return opcls[ty_cls](
op = opcls[ty_cls](
node.dtype.build(), lhs.result, rhs.result, ip=ctx.get_ip()
)
return opcls[ty_cls](lhs.result, rhs.result, ip=ctx.get_ip())
else:
op = opcls[ty_cls](lhs.result, rhs.result, ip=ctx.get_ip())
if isinstance(node.dtype, UInt):
op.attributes["unsigned"] = UnitAttr.get()
return op

@staticmethod
def build_UnaryOp(ctx, node):
Expand Down Expand Up @@ -1122,6 +1129,8 @@ def build_memory_access(ctx, node, val=None, idx=0):
affine_attr,
ip=ctx.get_ip(),
)
if isinstance(node.value.dtype, UInt):
op.attributes["unsigned"] = UnitAttr.get()
else: # ast.Store
op = affine_d.AffineStoreOp(
val.results[idx], value.result, ivs, affine_attr, ip=ctx.get_ip()
Expand All @@ -1131,6 +1140,8 @@ def build_memory_access(ctx, node, val=None, idx=0):
if isinstance(node.ctx, ast.Load):
# pylint: disable=redefined-variable-type
op = memref_d.LoadOp(value.result, new_indices, ip=ctx.get_ip())
if isinstance(node.value.dtype, UInt):
op.attributes["unsigned"] = UnitAttr.get()
else: # ast.Store
op = memref_d.StoreOp(
val.result,
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/Translation/EmitVivadoHLS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1768,7 +1768,9 @@ void ModuleEmitter::emitBitcast(arith::BitcastOp op) {

template <typename CastOpType> void ModuleEmitter::emitCast(CastOpType op) {
indent();
emitValue(op.getResult());
Value result = op.getResult();
fixUnsignedType(result, op->hasAttr("unsigned"));
emitValue(result);
os << " = ";
emitValue(op.getOperand());
os << ";";
Expand Down
15 changes: 15 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,21 @@ def kernel(a: int32) -> float32:
assert mod(1) == kernel(1)


def test_uint():
def casting():
buf1: UInt(17)[16, 16] = 0
buf2: float32[16, 16]

for i, j in allo.grid(16, 16):
buf2[i, j] = float(buf1[i, j] + buf1[j, i])

s = allo.customize(casting)
mod = s.build(target="vhls")
code = mod.hls_code
assert "ap_uint<17>" in code and "ap_uint<18>" in code
assert "ap_int<" not in code


def test_index_fixed_casting():
def test_one_cast(fixed):
def kernel(a: index) -> float32:
Expand Down

0 comments on commit 90a0efe

Please sign in to comment.