Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for float16 buffer in python extension #7060

Merged
merged 1 commit into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions python_bindings/test/correctness/addconstant_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def test(addconstant_impl_func, offset):
input_i64 = numpy.array([0, -4294967296, 8589934592], dtype=numpy.int64)
input_float = numpy.array([3.14, 2.718, 1.618], dtype=numpy.float32)
input_double = numpy.array([3.14, 2.718, 1.618], dtype=numpy.float64)
input_half = numpy.array([3.14, 2.718, 1.618], dtype=numpy.float16)
input_2d = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=numpy.int8, order='F')
input_3d = numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=numpy.int8)

Expand All @@ -51,6 +52,7 @@ def test(addconstant_impl_func, offset):
output_i64 = numpy.zeros((3,), dtype=numpy.int64)
output_float = numpy.zeros((3,), dtype=numpy.float32)
output_double = numpy.zeros((3,), dtype=numpy.float64)
output_half = numpy.zeros((3,), dtype=numpy.float16)
output_2d = numpy.zeros((2, 3), dtype=numpy.int8, order='F')
output_3d = numpy.zeros((2, 2, 2), dtype=numpy.int8)

Expand All @@ -61,10 +63,10 @@ def test(addconstant_impl_func, offset):
scalar_float, scalar_double,
input_u8, input_u16, input_u32, input_u64,
input_i8, input_i16, input_i32, input_i64,
input_float, input_double, input_2d, input_3d,
input_float, input_double, input_half, input_2d, input_3d,
output_u8, output_u16, output_u32, output_u64,
output_i8, output_i16, output_i32, output_i64,
output_float, output_double, output_2d, output_3d,
output_float, output_double, output_half, output_2d, output_3d,
)

combinations = [
Expand All @@ -78,11 +80,13 @@ def test(addconstant_impl_func, offset):
("i64", input_i64, output_i64, scalar_i64),
("float", input_float, output_float, scalar_float),
("double", input_double, output_double, scalar_double),
("half", input_half, output_half, scalar_float),
]

for _, input, output, scalar in combinations:
for i, o in zip(input, output):
assert abs(o - (i + scalar)) < ERROR_THRESHOLD
scalar_as_numpy = numpy.array(scalar).astype(input.dtype)
assert abs(o - (i + scalar_as_numpy)) < ERROR_THRESHOLD

for x in range(input_2d.shape[0]):
for y in range(input_2d.shape[1]):
Expand All @@ -103,10 +107,10 @@ def test(addconstant_impl_func, offset):
scalar_float, scalar_double,
input_u8, input_u16, input_u32, input_u64,
input_i8, input_i16, input_i32, input_i64,
input_float, input_double, input_2d, input_3d,
input_float, input_double, input_half, input_2d, input_3d,
output_u8, output_u16, output_u32, output_u64,
output_i8, output_i16, output_i32, output_i64,
output_float, output_double, output_2d, output_3d,
output_float, output_double, output_half, output_2d, output_3d,
)
except RuntimeError as e:
assert str(e) == "Halide Runtime Error: -27", e
Expand All @@ -125,10 +129,10 @@ def test(addconstant_impl_func, offset):
scalar_float, scalar_double,
input_u8, input_u16, input_u32, input_u64,
input_i8, input_i16, input_i32, input_i64,
input_float, input_double, input_2d, input_3d,
input_float, input_double, input_half, input_2d, input_3d,
output_u8, output_u16, output_u32, output_u64,
output_i8, output_i16, output_i32, output_i64,
output_float, output_double, output_2d, output_3d,
output_float, output_double, output_half, output_2d, output_3d,
)
except RuntimeError as e:
assert str(e) == "Halide Runtime Error: -27", e
Expand Down
3 changes: 3 additions & 0 deletions python_bindings/test/generators/addconstantcpp_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class AddConstantGenerator : public Halide::Generator<AddConstantGenerator> {
Input<Buffer<int64_t, 1>> input_int64{"input_int64"};
Input<Buffer<float, 1>> input_float{"input_float"};
Input<Buffer<double, 1>> input_double{"input_double"};
Input<Buffer<float16_t, 1>> input_half{"input_half"};
Input<Buffer<int8_t, 2>> input_2d{"input_2d"};
Input<Buffer<int8_t, 3>> input_3d{"input_3d"};

Expand All @@ -42,6 +43,7 @@ class AddConstantGenerator : public Halide::Generator<AddConstantGenerator> {
Output<Buffer<int64_t, 1>> output_int64{"output_int64"};
Output<Buffer<float, 1>> output_float{"output_float"};
Output<Buffer<double, 1>> output_double{"output_double"};
Output<Buffer<float16_t, 1>> output_half{"output_half"};
Output<Buffer<int8_t, 2>> output_2d{"buffer_2d"};
Output<Buffer<int8_t, 3>> output_3d{"buffer_3d"};

Expand All @@ -61,6 +63,7 @@ class AddConstantGenerator : public Halide::Generator<AddConstantGenerator> {
output_int64(x) = input_int64(x) + scalar_int64;
output_float(x) = input_float(x) + scalar_float;
output_double(x) = input_double(x) + scalar_double;
output_half(x) = input_half(x) + cast(Float(16), scalar_float);
output_2d(x, y) = input_2d(x, y) + scalar_int8;
output_3d(x, y, z) = input_3d(x, y, z) + scalar_int8 + extra_int;
}
Expand Down
3 changes: 3 additions & 0 deletions python_bindings/test/generators/addconstantpy_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class AddConstantGenerator:
input_int64 = hl.InputBuffer(hl.Int(64), 1)
input_float = hl.InputBuffer(hl.Float(32), 1)
input_double = hl.InputBuffer(hl.Float(64), 1)
input_half = hl.InputBuffer(hl.Float(16), 1)
input_2d = hl.InputBuffer(hl.Int(8), 2)
input_3d = hl.InputBuffer(hl.Int(8), 3)

Expand All @@ -47,6 +48,7 @@ class AddConstantGenerator:
output_int64 = hl.OutputBuffer(hl.Int(64), 1)
output_float = hl.OutputBuffer(hl.Float(32), 1)
output_double = hl.OutputBuffer(hl.Float(64), 1)
output_half = hl.OutputBuffer(hl.Float(16), 1)
output_2d = hl.OutputBuffer(hl.Int(8), 2)
output_3d = hl.OutputBuffer(hl.Int(8), 3)

Expand All @@ -65,6 +67,7 @@ def generate(self):
g.output_int64[x] = g.input_int64[x] + g.scalar_int64
g.output_float[x] = g.input_float[x] + g.scalar_float
g.output_double[x] = g.input_double[x] + g.scalar_double
g.output_half[x] = g.input_half[x] + hl.cast(hl.Float(16), g.scalar_float)
g.output_2d[x, y] = g.input_2d[x, y] + g.scalar_int8
g.output_3d[x, y, z] = g.input_3d[x, y, z] + g.scalar_int8 + g.extra_int

Expand Down
10 changes: 6 additions & 4 deletions src/PythonExtensionGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ bool can_convert(const LoweredArgument *arg) {
if (arg->type.is_vector()) {
return false;
}
if (arg->type.is_float() && arg->type.bits() != 32 && arg->type.bits() != 64) {
if (arg->type.is_float() && arg->type.bits() != 32 && arg->type.bits() != 64 && arg->type.bits() != 16) {
return false;
}
if ((arg->type.is_int() || arg->type.is_uint()) &&
Expand All @@ -77,6 +77,8 @@ std::pair<string, string> print_type(const LoweredArgument *arg) {
return std::make_pair("f", "float");
} else if (arg->type.is_float() && arg->type.bits() == 64) {
return std::make_pair("d", "double");
// } else if (arg->type.is_float() && arg->type.bits() == 16) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Formatting nit: de-indent by 4 spaces here so indentation matches

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, clang-format changed it in this way. Should I add // clang-format off here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ugh. In that case, no, leave it as-is.

// TODO: can't pass scalar float16 type
} else if (arg->type.bits() == 1) {
// "b" expects an unsigned char, so we assume that bool == uint8.
return std::make_pair("b", "bool");
Expand Down Expand Up @@ -221,8 +223,8 @@ bool unpack_buffer(PyObject *py_obj,
while (strchr("@<>!=", *p)) {
p++; // ignore little/bit endian (and alignment)
}
if (*p == 'f' || *p == 'd') {
// 'f' and 'd' are float and double, respectively.
if (*p == 'f' || *p == 'd' || *p == 'e') {
// 'f', 'd', and 'e' are float, double, and half, respectively.
halide_buf.type.code = halide_type_float;
} else if (*p >= 'a' && *p <= 'z') {
// lowercase is signed int.
Expand All @@ -231,7 +233,7 @@ bool unpack_buffer(PyObject *py_obj,
// uppercase is unsigned int.
halide_buf.type.code = halide_type_uint;
}
const char *type_codes = "bBhHiIlLqQfd"; // integers and floats
const char *type_codes = "bBhHiIlLqQfde"; // integers and floats
if (*p == '?') {
// Special-case bool, so that it is a distinct type vs uint8_t
// (even though the memory layout is identical)
Expand Down