From 6fc12152fde4bb7cd5042f3323b6e110388dadec Mon Sep 17 00:00:00 2001 From: Steve Suzuki Date: Tue, 4 Oct 2022 22:25:16 +0100 Subject: [PATCH] Add support for float16 buffer in python extension (#7060) --- .../test/correctness/addconstant_test.py | 18 +++++++++++------- .../generators/addconstantcpp_generator.cpp | 3 +++ .../test/generators/addconstantpy_generator.py | 3 +++ src/PythonExtensionGen.cpp | 10 ++++++---- 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/python_bindings/test/correctness/addconstant_test.py b/python_bindings/test/correctness/addconstant_test.py index 0f1eb3ede059..fb9ab90f06cc 100644 --- a/python_bindings/test/correctness/addconstant_test.py +++ b/python_bindings/test/correctness/addconstant_test.py @@ -38,6 +38,7 @@ def test(addconstant_impl_func, offset): input_i64 = numpy.array([0, -4294967296, 8589934592], dtype=numpy.int64) input_float = numpy.array([3.14, 2.718, 1.618], dtype=numpy.float32) input_double = numpy.array([3.14, 2.718, 1.618], dtype=numpy.float64) + input_half = numpy.array([3.14, 2.718, 1.618], dtype=numpy.float16) input_2d = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=numpy.int8, order='F') input_3d = numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=numpy.int8) @@ -51,6 +52,7 @@ def test(addconstant_impl_func, offset): output_i64 = numpy.zeros((3,), dtype=numpy.int64) output_float = numpy.zeros((3,), dtype=numpy.float32) output_double = numpy.zeros((3,), dtype=numpy.float64) + output_half = numpy.zeros((3,), dtype=numpy.float16) output_2d = numpy.zeros((2, 3), dtype=numpy.int8, order='F') output_3d = numpy.zeros((2, 2, 2), dtype=numpy.int8) @@ -61,10 +63,10 @@ def test(addconstant_impl_func, offset): scalar_float, scalar_double, input_u8, input_u16, input_u32, input_u64, input_i8, input_i16, input_i32, input_i64, - input_float, input_double, input_2d, input_3d, + input_float, input_double, input_half, input_2d, input_3d, output_u8, output_u16, output_u32, output_u64, output_i8, output_i16, output_i32, output_i64, - output_float, output_double, output_2d, output_3d, + output_float, output_double, output_half, output_2d, output_3d, ) combinations = [ @@ -78,11 +80,13 @@ def test(addconstant_impl_func, offset): ("i64", input_i64, output_i64, scalar_i64), ("float", input_float, output_float, scalar_float), ("double", input_double, output_double, scalar_double), + ("half", input_half, output_half, scalar_float), ] for _, input, output, scalar in combinations: for i, o in zip(input, output): - assert abs(o - (i + scalar)) < ERROR_THRESHOLD + scalar_as_numpy = numpy.array(scalar).astype(input.dtype) + assert abs(o - (i + scalar_as_numpy)) < ERROR_THRESHOLD for x in range(input_2d.shape[0]): for y in range(input_2d.shape[1]): @@ -103,10 +107,10 @@ def test(addconstant_impl_func, offset): scalar_float, scalar_double, input_u8, input_u16, input_u32, input_u64, input_i8, input_i16, input_i32, input_i64, - input_float, input_double, input_2d, input_3d, + input_float, input_double, input_half, input_2d, input_3d, output_u8, output_u16, output_u32, output_u64, output_i8, output_i16, output_i32, output_i64, - output_float, output_double, output_2d, output_3d, + output_float, output_double, output_half, output_2d, output_3d, ) except RuntimeError as e: assert str(e) == "Halide Runtime Error: -27", e @@ -125,10 +129,10 @@ def test(addconstant_impl_func, offset): scalar_float, scalar_double, input_u8, input_u16, input_u32, input_u64, input_i8, input_i16, input_i32, input_i64, - input_float, input_double, input_2d, input_3d, + input_float, input_double, input_half, input_2d, input_3d, output_u8, output_u16, output_u32, output_u64, output_i8, output_i16, output_i32, output_i64, - output_float, output_double, output_2d, output_3d, + output_float, output_double, output_half, output_2d, output_3d, ) except RuntimeError as e: assert str(e) == "Halide Runtime Error: -27", e diff --git a/python_bindings/test/generators/addconstantcpp_generator.cpp b/python_bindings/test/generators/addconstantcpp_generator.cpp index 577de4dab890..e6766917ec0f 100644 --- a/python_bindings/test/generators/addconstantcpp_generator.cpp +++ b/python_bindings/test/generators/addconstantcpp_generator.cpp @@ -29,6 +29,7 @@ class AddConstantGenerator : public Halide::Generator { Input> input_int64{"input_int64"}; Input> input_float{"input_float"}; Input> input_double{"input_double"}; + Input> input_half{"input_half"}; Input> input_2d{"input_2d"}; Input> input_3d{"input_3d"}; @@ -42,6 +43,7 @@ class AddConstantGenerator : public Halide::Generator { Output> output_int64{"output_int64"}; Output> output_float{"output_float"}; Output> output_double{"output_double"}; + Output> output_half{"output_half"}; Output> output_2d{"buffer_2d"}; Output> output_3d{"buffer_3d"}; @@ -61,6 +63,7 @@ class AddConstantGenerator : public Halide::Generator { output_int64(x) = input_int64(x) + scalar_int64; output_float(x) = input_float(x) + scalar_float; output_double(x) = input_double(x) + scalar_double; + output_half(x) = input_half(x) + cast(Float(16), scalar_float); output_2d(x, y) = input_2d(x, y) + scalar_int8; output_3d(x, y, z) = input_3d(x, y, z) + scalar_int8 + extra_int; } diff --git a/python_bindings/test/generators/addconstantpy_generator.py b/python_bindings/test/generators/addconstantpy_generator.py index c48476e766a2..77d2b4a21708 100644 --- a/python_bindings/test/generators/addconstantpy_generator.py +++ b/python_bindings/test/generators/addconstantpy_generator.py @@ -34,6 +34,7 @@ class AddConstantGenerator: input_int64 = hl.InputBuffer(hl.Int(64), 1) input_float = hl.InputBuffer(hl.Float(32), 1) input_double = hl.InputBuffer(hl.Float(64), 1) + input_half = hl.InputBuffer(hl.Float(16), 1) input_2d = hl.InputBuffer(hl.Int(8), 2) input_3d = hl.InputBuffer(hl.Int(8), 3) @@ -47,6 +48,7 @@ class AddConstantGenerator: output_int64 = hl.OutputBuffer(hl.Int(64), 1) output_float = hl.OutputBuffer(hl.Float(32), 1) output_double = hl.OutputBuffer(hl.Float(64), 1) + output_half = hl.OutputBuffer(hl.Float(16), 1) output_2d = hl.OutputBuffer(hl.Int(8), 2) output_3d = hl.OutputBuffer(hl.Int(8), 3) @@ -65,6 +67,7 @@ def generate(self): g.output_int64[x] = g.input_int64[x] + g.scalar_int64 g.output_float[x] = g.input_float[x] + g.scalar_float g.output_double[x] = g.input_double[x] + g.scalar_double + g.output_half[x] = g.input_half[x] + hl.cast(hl.Float(16), g.scalar_float) g.output_2d[x, y] = g.input_2d[x, y] + g.scalar_int8 g.output_3d[x, y, z] = g.input_3d[x, y, z] + g.scalar_int8 + g.extra_int diff --git a/src/PythonExtensionGen.cpp b/src/PythonExtensionGen.cpp index d454743df9fa..266ca27c2083 100644 --- a/src/PythonExtensionGen.cpp +++ b/src/PythonExtensionGen.cpp @@ -51,7 +51,7 @@ bool can_convert(const LoweredArgument *arg) { if (arg->type.is_vector()) { return false; } - if (arg->type.is_float() && arg->type.bits() != 32 && arg->type.bits() != 64) { + if (arg->type.is_float() && arg->type.bits() != 32 && arg->type.bits() != 64 && arg->type.bits() != 16) { return false; } if ((arg->type.is_int() || arg->type.is_uint()) && @@ -77,6 +77,8 @@ std::pair print_type(const LoweredArgument *arg) { return std::make_pair("f", "float"); } else if (arg->type.is_float() && arg->type.bits() == 64) { return std::make_pair("d", "double"); + // } else if (arg->type.is_float() && arg->type.bits() == 16) { + // TODO: can't pass scalar float16 type } else if (arg->type.bits() == 1) { // "b" expects an unsigned char, so we assume that bool == uint8. return std::make_pair("b", "bool"); @@ -221,8 +223,8 @@ bool unpack_buffer(PyObject *py_obj, while (strchr("@<>!=", *p)) { p++; // ignore little/bit endian (and alignment) } - if (*p == 'f' || *p == 'd') { - // 'f' and 'd' are float and double, respectively. + if (*p == 'f' || *p == 'd' || *p == 'e') { + // 'f', 'd', and 'e' are float, double, and half, respectively. halide_buf.type.code = halide_type_float; } else if (*p >= 'a' && *p <= 'z') { // lowercase is signed int. @@ -231,7 +233,7 @@ bool unpack_buffer(PyObject *py_obj, // uppercase is unsigned int. halide_buf.type.code = halide_type_uint; } - const char *type_codes = "bBhHiIlLqQfd"; // integers and floats + const char *type_codes = "bBhHiIlLqQfde"; // integers and floats if (*p == '?') { // Special-case bool, so that it is a distinct type vs uint8_t // (even though the memory layout is identical)