Skip to content

Commit

Permalink
Add support for float16 buffer in python extension (halide#7060)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevesuzuki-arm authored and ardier committed Mar 3, 2024
1 parent 039432f commit 6fc1215
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 11 deletions.
18 changes: 11 additions & 7 deletions python_bindings/test/correctness/addconstant_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def test(addconstant_impl_func, offset):
input_i64 = numpy.array([0, -4294967296, 8589934592], dtype=numpy.int64)
input_float = numpy.array([3.14, 2.718, 1.618], dtype=numpy.float32)
input_double = numpy.array([3.14, 2.718, 1.618], dtype=numpy.float64)
input_half = numpy.array([3.14, 2.718, 1.618], dtype=numpy.float16)
input_2d = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=numpy.int8, order='F')
input_3d = numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=numpy.int8)

Expand All @@ -51,6 +52,7 @@ def test(addconstant_impl_func, offset):
output_i64 = numpy.zeros((3,), dtype=numpy.int64)
output_float = numpy.zeros((3,), dtype=numpy.float32)
output_double = numpy.zeros((3,), dtype=numpy.float64)
output_half = numpy.zeros((3,), dtype=numpy.float16)
output_2d = numpy.zeros((2, 3), dtype=numpy.int8, order='F')
output_3d = numpy.zeros((2, 2, 2), dtype=numpy.int8)

Expand All @@ -61,10 +63,10 @@ def test(addconstant_impl_func, offset):
scalar_float, scalar_double,
input_u8, input_u16, input_u32, input_u64,
input_i8, input_i16, input_i32, input_i64,
input_float, input_double, input_2d, input_3d,
input_float, input_double, input_half, input_2d, input_3d,
output_u8, output_u16, output_u32, output_u64,
output_i8, output_i16, output_i32, output_i64,
output_float, output_double, output_2d, output_3d,
output_float, output_double, output_half, output_2d, output_3d,
)

combinations = [
Expand All @@ -78,11 +80,13 @@ def test(addconstant_impl_func, offset):
("i64", input_i64, output_i64, scalar_i64),
("float", input_float, output_float, scalar_float),
("double", input_double, output_double, scalar_double),
("half", input_half, output_half, scalar_float),
]

for _, input, output, scalar in combinations:
for i, o in zip(input, output):
assert abs(o - (i + scalar)) < ERROR_THRESHOLD
scalar_as_numpy = numpy.array(scalar).astype(input.dtype)
assert abs(o - (i + scalar_as_numpy)) < ERROR_THRESHOLD

for x in range(input_2d.shape[0]):
for y in range(input_2d.shape[1]):
Expand All @@ -103,10 +107,10 @@ def test(addconstant_impl_func, offset):
scalar_float, scalar_double,
input_u8, input_u16, input_u32, input_u64,
input_i8, input_i16, input_i32, input_i64,
input_float, input_double, input_2d, input_3d,
input_float, input_double, input_half, input_2d, input_3d,
output_u8, output_u16, output_u32, output_u64,
output_i8, output_i16, output_i32, output_i64,
output_float, output_double, output_2d, output_3d,
output_float, output_double, output_half, output_2d, output_3d,
)
except RuntimeError as e:
assert str(e) == "Halide Runtime Error: -27", e
Expand All @@ -125,10 +129,10 @@ def test(addconstant_impl_func, offset):
scalar_float, scalar_double,
input_u8, input_u16, input_u32, input_u64,
input_i8, input_i16, input_i32, input_i64,
input_float, input_double, input_2d, input_3d,
input_float, input_double, input_half, input_2d, input_3d,
output_u8, output_u16, output_u32, output_u64,
output_i8, output_i16, output_i32, output_i64,
output_float, output_double, output_2d, output_3d,
output_float, output_double, output_half, output_2d, output_3d,
)
except RuntimeError as e:
assert str(e) == "Halide Runtime Error: -27", e
Expand Down
3 changes: 3 additions & 0 deletions python_bindings/test/generators/addconstantcpp_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class AddConstantGenerator : public Halide::Generator<AddConstantGenerator> {
Input<Buffer<int64_t, 1>> input_int64{"input_int64"};
Input<Buffer<float, 1>> input_float{"input_float"};
Input<Buffer<double, 1>> input_double{"input_double"};
Input<Buffer<float16_t, 1>> input_half{"input_half"};
Input<Buffer<int8_t, 2>> input_2d{"input_2d"};
Input<Buffer<int8_t, 3>> input_3d{"input_3d"};

Expand All @@ -42,6 +43,7 @@ class AddConstantGenerator : public Halide::Generator<AddConstantGenerator> {
Output<Buffer<int64_t, 1>> output_int64{"output_int64"};
Output<Buffer<float, 1>> output_float{"output_float"};
Output<Buffer<double, 1>> output_double{"output_double"};
Output<Buffer<float16_t, 1>> output_half{"output_half"};
Output<Buffer<int8_t, 2>> output_2d{"buffer_2d"};
Output<Buffer<int8_t, 3>> output_3d{"buffer_3d"};

Expand All @@ -61,6 +63,7 @@ class AddConstantGenerator : public Halide::Generator<AddConstantGenerator> {
output_int64(x) = input_int64(x) + scalar_int64;
output_float(x) = input_float(x) + scalar_float;
output_double(x) = input_double(x) + scalar_double;
output_half(x) = input_half(x) + cast(Float(16), scalar_float);
output_2d(x, y) = input_2d(x, y) + scalar_int8;
output_3d(x, y, z) = input_3d(x, y, z) + scalar_int8 + extra_int;
}
Expand Down
3 changes: 3 additions & 0 deletions python_bindings/test/generators/addconstantpy_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class AddConstantGenerator:
input_int64 = hl.InputBuffer(hl.Int(64), 1)
input_float = hl.InputBuffer(hl.Float(32), 1)
input_double = hl.InputBuffer(hl.Float(64), 1)
input_half = hl.InputBuffer(hl.Float(16), 1)
input_2d = hl.InputBuffer(hl.Int(8), 2)
input_3d = hl.InputBuffer(hl.Int(8), 3)

Expand All @@ -47,6 +48,7 @@ class AddConstantGenerator:
output_int64 = hl.OutputBuffer(hl.Int(64), 1)
output_float = hl.OutputBuffer(hl.Float(32), 1)
output_double = hl.OutputBuffer(hl.Float(64), 1)
output_half = hl.OutputBuffer(hl.Float(16), 1)
output_2d = hl.OutputBuffer(hl.Int(8), 2)
output_3d = hl.OutputBuffer(hl.Int(8), 3)

Expand All @@ -65,6 +67,7 @@ def generate(self):
g.output_int64[x] = g.input_int64[x] + g.scalar_int64
g.output_float[x] = g.input_float[x] + g.scalar_float
g.output_double[x] = g.input_double[x] + g.scalar_double
g.output_half[x] = g.input_half[x] + hl.cast(hl.Float(16), g.scalar_float)
g.output_2d[x, y] = g.input_2d[x, y] + g.scalar_int8
g.output_3d[x, y, z] = g.input_3d[x, y, z] + g.scalar_int8 + g.extra_int

Expand Down
10 changes: 6 additions & 4 deletions src/PythonExtensionGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ bool can_convert(const LoweredArgument *arg) {
if (arg->type.is_vector()) {
return false;
}
if (arg->type.is_float() && arg->type.bits() != 32 && arg->type.bits() != 64) {
if (arg->type.is_float() && arg->type.bits() != 32 && arg->type.bits() != 64 && arg->type.bits() != 16) {
return false;
}
if ((arg->type.is_int() || arg->type.is_uint()) &&
Expand All @@ -77,6 +77,8 @@ std::pair<string, string> print_type(const LoweredArgument *arg) {
return std::make_pair("f", "float");
} else if (arg->type.is_float() && arg->type.bits() == 64) {
return std::make_pair("d", "double");
// } else if (arg->type.is_float() && arg->type.bits() == 16) {
// TODO: can't pass scalar float16 type
} else if (arg->type.bits() == 1) {
// "b" expects an unsigned char, so we assume that bool == uint8.
return std::make_pair("b", "bool");
Expand Down Expand Up @@ -221,8 +223,8 @@ bool unpack_buffer(PyObject *py_obj,
while (strchr("@<>!=", *p)) {
p++; // ignore little/bit endian (and alignment)
}
if (*p == 'f' || *p == 'd') {
// 'f' and 'd' are float and double, respectively.
if (*p == 'f' || *p == 'd' || *p == 'e') {
// 'f', 'd', and 'e' are float, double, and half, respectively.
halide_buf.type.code = halide_type_float;
} else if (*p >= 'a' && *p <= 'z') {
// lowercase is signed int.
Expand All @@ -231,7 +233,7 @@ bool unpack_buffer(PyObject *py_obj,
// uppercase is unsigned int.
halide_buf.type.code = halide_type_uint;
}
const char *type_codes = "bBhHiIlLqQfd"; // integers and floats
const char *type_codes = "bBhHiIlLqQfde"; // integers and floats
if (*p == '?') {
// Special-case bool, so that it is a distinct type vs uint8_t
// (even though the memory layout is identical)
Expand Down

0 comments on commit 6fc1215

Please sign in to comment.