[Type][Backend] Support fp16 (#285)
chhzh123 authored Jan 4, 2025
1 parent 71667ce commit e8b597e
Showing 6 changed files with 92 additions and 6 deletions.
13 changes: 11 additions & 2 deletions allo/backend/llvm.py
@@ -1,6 +1,6 @@
# Copyright Allo authors. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=no-name-in-module, inconsistent-return-statements
# pylint: disable=no-name-in-module, inconsistent-return-statements, too-many-function-args

import os
import ctypes
@@ -155,7 +155,10 @@ def __call__(self, *args):
f"Input type mismatch: {target_in_type} vs f32. Please use NumPy array"
" to wrap the data to avoid possible result mismatch"
).warn()
if target_in_type == "f32":
if target_in_type == "f16":
c_float_p = ctypes.c_int16 * 1
arg = np.float16(arg).view(np.int16)
elif target_in_type == "f32":
c_float_p = ctypes.c_float * 1
else: # f64
c_float_p = ctypes.c_double * 1
@@ -317,6 +320,8 @@ def __call__(self, *args):
ret = struct_array_to_int_array(
ret, bitwidth, result_type[0] == "i"
)
elif result_type == "f16":
ret = np.array(ret, dtype=np.int16).view(np.float16)
elif result_type.startswith("fixed") or result_type.startswith(
"ufixed"
):
@@ -333,6 +338,8 @@
# INVOKE
self.execution_engine.invoke(self.top_func_name, *arg_ptrs, return_ptr)
ret = return_ptr[0]
if result_type == "f16":
ret = np.int16(ret).view(np.float16)
else: # multiple returns, assume all memref
# INVOKE
self.execution_engine.invoke(self.top_func_name, return_ptr, *arg_ptrs)
@@ -349,6 +356,8 @@
ret_i = struct_array_to_int_array(
np_arr, bitwidth, res_type[0] == "i"
)
elif res_type == "f16":
ret_i = np.array(np_arr, dtype=np.int16).view(np.float16)
elif res_type.startswith("fixed") or res_type.startswith("ufixed"):
bitwidth, frac = get_bitwidth_and_frac_from_fixed(res_type)
ret_i = struct_array_to_int_array(
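Since ctypes has no half-precision float type, the LLVM backend above passes fp16 scalars as their raw 16-bit pattern through `ctypes.c_int16` and reinterprets the bits with NumPy views on the way in and out. A minimal standalone sketch of that round trip (plain NumPy/ctypes, not the Allo runtime itself):

```python
import ctypes
import numpy as np

# Pack an fp16 argument: reinterpret its 16-bit pattern as an int16,
# since ctypes has no c_float16 (mirrors the llvm.py change above).
x = np.float16(1.5)
arg_buf = (ctypes.c_int16 * 1)(int(x.view(np.int16)))

# Unpack an fp16 return value: view the int16 bits back as float16.
ret = np.int16(arg_buf[0]).view(np.float16)
assert ret == np.float16(1.5)
```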
7 changes: 7 additions & 0 deletions allo/utils.py
@@ -9,6 +9,7 @@
RankedTensorType,
IntegerType,
IndexType,
F16Type,
F32Type,
F64Type,
)
@@ -18,6 +19,7 @@


np_supported_types = {
"f16": np.float16,
"f32": np.float32,
"f64": np.float64,
"i8": np.int8,
@@ -33,6 +35,9 @@


ctype_map = {
# ctypes.c_float16 does not exist
# similar implementation in _mlir/runtime/np_to_memref.py/F16
"f16": ctypes.c_int16,
"f32": ctypes.c_float,
"f64": ctypes.c_double,
"i8": ctypes.c_int8,
@@ -152,6 +157,8 @@ def get_dtype_and_shape_from_type(dtype):
return "index", tuple()
if IntegerType.isinstance(dtype):
return str(IntegerType(dtype)), tuple()
if F16Type.isinstance(dtype):
return str(F16Type(dtype)), tuple()
if F32Type.isinstance(dtype):
return str(F32Type(dtype)), tuple()
if F64Type.isinstance(dtype):
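The two tables in utils.py serve different paths: `np_supported_types` maps the "f16" type name to `np.float16` for array (memref) data, while `ctype_map` substitutes `ctypes.c_int16` for fp16 scalars because ctypes has no half type. A small sketch of how the lookups could be used (the `cast_array` helper is hypothetical, not code from the repo):

```python
import ctypes
import numpy as np

np_supported_types = {"f16": np.float16, "f32": np.float32, "f64": np.float64}
ctype_map = {"f16": ctypes.c_int16, "f32": ctypes.c_float, "f64": ctypes.c_double}

def cast_array(arr, type_str):
    """Cast a NumPy array to the dtype expected for an Allo type name."""
    return arr.astype(np_supported_types[type_str])

A = cast_array(np.random.rand(4), "f16")   # float16 data for a memref argument
slot = (ctype_map["f16"] * 1)()            # int16 slot standing in for an fp16 scalar
print(A.dtype, slot[0])                    # float16 0
```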
4 changes: 3 additions & 1 deletion mlir/lib/Translation/EmitTapaHLS.cpp
@@ -36,7 +36,9 @@ static SmallString<16> getTypeName(Type valType) {
valType = arrayType.getElementType();

// Handle float types.
if (valType.isa<Float32Type>())
if (valType.isa<Float16Type>())
return SmallString<16>("half");
else if (valType.isa<Float32Type>())
return SmallString<16>("float");
else if (valType.isa<Float64Type>())
return SmallString<16>("double");
6 changes: 5 additions & 1 deletion mlir/lib/Translation/EmitVivadoHLS.cpp
@@ -35,7 +35,11 @@ static SmallString<16> getTypeName(Type valType) {
valType = arrayType.getElementType();

// Handle float types.
if (valType.isa<Float32Type>())
if (valType.isa<Float16Type>())
// Page 222:
// https://www.amd.com/content/dam/xilinx/support/documents/sw_manuals/xilinx2020_2/ug902-vivado-high-level-synthesis.pdf
return SmallString<16>("half");
else if (valType.isa<Float32Type>())
return SmallString<16>("float");
else if (valType.isa<Float64Type>())
return SmallString<16>("double");
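With the emitter changes above, MLIR f16 values are printed as the HLS `half` type in both the TAPA and Vivado HLS backends (UG902, p. 222, linked in the comment). A minimal usage sketch, reusing the float16 kernel from the tests added below:

```python
import allo
from allo.ir.types import float16

def kernel(A: float16[10]) -> float16[10]:
    B: float16[10]
    for i in range(10):
        B[i] = A[i] + 1
    return B

s = allo.customize(kernel)
code = str(s.build(target="vhls"))   # emit Vivado HLS C++
assert "half" in code                # fp16 values are declared as `half`
```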
28 changes: 26 additions & 2 deletions tests/test_bitop.py
@@ -4,7 +4,7 @@
import pytest
import numpy as np
import allo
from allo.ir.types import uint1, uint2, int32, uint8, uint32, UInt, float32
from allo.ir.types import uint1, uint2, int32, uint8, uint32, UInt, float16, float32


def test_scalar():
@@ -125,7 +125,7 @@ def kernel(A: int32, B: int32[11]):
assert bin(1234) == "0b" + "".join([str(np_B[i]) for i in range(10, -1, -1)])


def test_bitcast_uint2float():
def test_bitcast_uint2float32():
def kernel(A: uint32[10, 10]) -> float32[10, 10]:
B: float32[10, 10]
for i, j in allo.grid(10, 10):
@@ -146,6 +146,30 @@ def kernel(A: uint32[10, 10]) -> float32[10, 10]:
print("Passed!")


def test_bitcast_uint2float16():
def kernel(A: int32[10, 10]) -> float16[10, 10]:
B: float16[10, 10]
for i, j in allo.grid(10, 10):
B[i, j] = A[i, j][0:16].bitcast()
return B

s = allo.customize(kernel)
print(s.module)
mod = s.build()

A_np = np.random.randint(100, size=(10, 10)).astype(np.int32)
B_np = mod(A_np)
answer = np.frombuffer(A_np.astype(np.int16).tobytes(), np.float16).reshape(
(10, 10)
)
assert np.array_equal(B_np, answer)

code = str(s.build(target="vhls"))
print(code)
assert "union" in code and "half" in code
print("Passed!")


def test_bitcast_float2uint():
def kernel(A: float32[10, 10]) -> uint32[10, 10]:
B: uint32[10, 10]
40 changes: 40 additions & 0 deletions tests/test_types.py
@@ -13,6 +13,7 @@
bool,
uint1,
int32,
float16,
float32,
index,
)
@@ -322,6 +323,45 @@ def kernel[Ty]() -> int32:
print(s.module)


def test_fp16():
def kernel(a: float16) -> float16:
return a + 1

s = allo.customize(kernel)
assert "f16" in str(s.module)
mod = s.build()
assert mod(1.0) == kernel(1.0)


def test_fp16_array():
def kernel(A: float16[10]) -> float16[10]:
B: float16[10]
for i in range(10):
B[i] = A[i] + 1
return B

s = allo.customize(kernel)
assert "f16" in str(s.module)
mod = s.build()
A = np.random.rand(10).astype(np.float16)
B = mod(A)
np.testing.assert_allclose(B, A + 1, rtol=1e-5)


def test_fp16_array_inplace():
def kernel(A: float16[10]):
for i in range(10):
A[i] += 1

s = allo.customize(kernel)
assert "f16" in str(s.module)
mod = s.build()
A = np.random.rand(10).astype(np.float16)
res = A + 1
mod(A)
np.testing.assert_allclose(A, res, rtol=1e-5)


def test_select_typing():
def kernel(flt: float32, itg: int32) -> float32:
# if correctly typed, the select should have float32 result
