diff --git a/allo/ir/builder.py b/allo/ir/builder.py
index 7bc7f768..685fa213 100644
--- a/allo/ir/builder.py
+++ b/allo/ir/builder.py
@@ -1775,6 +1775,21 @@ def build_Call(ctx, node):
                         [],
                         ip=ctx.get_ip(),
                     )
+                if node.func.attr == "bitcast":
+                    # Reinterpret the operand's bits as the inferred result type.
+                    val = build_stmt(ctx, node.func.value)
+                    op = arith_d.BitcastOp(
+                        node.dtype.build(),
+                        val.result,
+                        ip=ctx.get_ip(),
+                    )
+                    # Flag the op when either side is unsigned; the HLS emitters
+                    # forward this attribute to fixUnsignedType.
+                    if isinstance(node.func.value.dtype, UInt) or isinstance(
+                        node.dtype, UInt
+                    ):
+                        op.attributes["unsigned"] = UnitAttr.get()
+                    return op
 
             if node.func.id in {"float", "int"}:
                 # Python-Builtin functions
diff --git a/allo/ir/infer.py b/allo/ir/infer.py
index 9f6b8d8f..4076d142 100644
--- a/allo/ir/infer.py
+++ b/allo/ir/infer.py
@@ -17,6 +17,7 @@
     AlloType,
     Int,
     UInt,
+    Float,
     Fixed,
     UFixed,
     Index,
@@ -777,6 +778,17 @@ def visit_Call(ctx, node):
                         # stream type itself
                         node.func.value.shape = tuple()
                         node.func.value.dtype = ctx.buffers[vid].dtype
+                elif node.func.attr == "bitcast":
+                    visit_stmt(ctx, node.func.value)
+                    # single-element operation
+                    node.shape = tuple()
+                    if isinstance(node.func.value.dtype, (UInt, Int)):
+                        node.dtype = Float(node.func.value.dtype.bits)
+                    else:
+                        # casting between signed and unsigned types in C/C++
+                        # does not modify the underlying bit representation,
+                        # but only the interpretation.
+                        node.dtype = UInt(node.func.value.dtype.bits)
                 else:
                     raise RuntimeError(
                         f"Unsupported function call or attribute method `.{node.func.attr}`"
@@ -826,6 +838,7 @@ def visit_Call(ctx, node):
             # No argument
             if fn_name == "get_pid":
                 node.shape = (tuple(), tuple())
+                # pylint: disable=redefined-variable-type
                 node.dtype = (Index(), Index())
             else:
                 node.shape = None
diff --git a/mlir/lib/Translation/EmitTapaHLS.cpp b/mlir/lib/Translation/EmitTapaHLS.cpp
index 10b68c0a..42f97a9b 100644
--- a/mlir/lib/Translation/EmitTapaHLS.cpp
+++ b/mlir/lib/Translation/EmitTapaHLS.cpp
@@ -1793,6 +1793,10 @@ void ModuleEmitter::emitConstant(arith::ConstantOp op) {
 
 void ModuleEmitter::emitBitcast(arith::BitcastOp op) {
   indent();
+  Value result = op.getResult();
+  fixUnsignedType(result, op->hasAttr("unsigned"));
+  Value operand = op.getOperand();
+  fixUnsignedType(operand, op->hasAttr("unsigned"));
   emitValue(op.getResult());
   os << ";\n";
   indent();
diff --git a/mlir/lib/Translation/EmitVivadoHLS.cpp b/mlir/lib/Translation/EmitVivadoHLS.cpp
index fd56deb8..976cfb4a 100644
--- a/mlir/lib/Translation/EmitVivadoHLS.cpp
+++ b/mlir/lib/Translation/EmitVivadoHLS.cpp
@@ -1743,6 +1743,10 @@ void ModuleEmitter::emitConstant(arith::ConstantOp op) {
 
 void ModuleEmitter::emitBitcast(arith::BitcastOp op) {
   indent();
+  Value result = op.getResult();
+  fixUnsignedType(result, op->hasAttr("unsigned"));
+  Value operand = op.getOperand();
+  fixUnsignedType(operand, op->hasAttr("unsigned"));
   emitValue(op.getResult());
   os << ";\n";
   indent();
diff --git a/tests/test_bitop.py b/tests/test_bitop.py
index d964f79b..5fed1a9d 100644
--- a/tests/test_bitop.py
+++ b/tests/test_bitop.py
@@ -4,7 +4,7 @@
 import pytest
 import numpy as np
 import allo
-from allo.ir.types import uint1, uint2, int32, uint8, UInt
+from allo.ir.types import uint1, uint2, int32, uint8, uint32, UInt, float32
 
 
 def test_scalar():
@@ -125,6 +125,73 @@ def kernel(A: int32, B: int32[11]):
     assert bin(1234) == "0b" + "".join([str(np_B[i]) for i in range(10, -1, -1)])
 
 
+def test_bitcast_uint2float():
+    """Bitcast uint32 elements to float32 and compare against the raw bytes."""
+    def kernel(A: uint32[10, 10]) -> float32[10, 10]:
+        B: float32[10, 10]
+        for i, j in allo.grid(10, 10):
+            B[i, j] = A[i, j].bitcast()
+        return B
+
+    s = allo.customize(kernel)
+    print(s.module)
+    mod = s.build()
+
+    A_np = np.random.randint(100, size=(10, 10)).astype(np.uint32)
+    B_np = mod(A_np)
+    answer = np.frombuffer(A_np.tobytes(), np.float32).reshape((10, 10))
+    assert np.array_equal(B_np, answer)
+
+    code = str(s.build(target="vhls"))
+    assert "union" in code and "uint32" in code
+    print("Passed!")
+
+
+def test_bitcast_float2uint():
+    """Bitcast float32 elements to uint32 and compare against the raw bytes."""
+    def kernel(A: float32[10, 10]) -> uint32[10, 10]:
+        B: uint32[10, 10]
+        for i, j in allo.grid(10, 10):
+            B[i, j] = A[i, j].bitcast()
+        return B
+
+    s = allo.customize(kernel)
+    print(s.module)
+    mod = s.build()
+
+    A_np = np.random.rand(10, 10).astype(np.float32)
+    B_np = mod(A_np)
+    answer = np.frombuffer(A_np.tobytes(), np.uint32).reshape((10, 10))
+    assert np.array_equal(B_np, answer)
+
+    code = str(s.build(target="vhls"))
+    assert "union" in code and "uint32" in code
+    print("Passed!")
+
+
+def test_bitcast_float2int():
+    """Bitcast float32 elements into an int32 buffer and compare the raw bytes."""
+    def kernel(A: float32[10, 10]) -> int32[10, 10]:
+        B: int32[10, 10]
+        for i, j in allo.grid(10, 10):
+            B[i, j] = A[i, j].bitcast()
+        return B
+
+    s = allo.customize(kernel)
+    print(s.module)
+    mod = s.build()
+
+    A_np = np.random.rand(10, 10).astype(np.float32)
+    B_np = mod(A_np)
+    answer = np.frombuffer(A_np.tobytes(), np.int32).reshape((10, 10))
+    assert np.array_equal(B_np, answer)
+
+    code = str(s.build(target="vhls"))
+    assert "union" in code and "int32" in code
+    print(code)
+    print("Passed!")
+
+
 def test_packed_bconv2D_nchw():
     bs = 4
     ic, oc = 16, 32