Skip to content

Commit

Permalink
Add placeholder code for bfloat16 in Python (halide#6849) (halide#6850)
Browse files Browse the repository at this point in the history
* Add placeholder code for bfloat16 in Python (halide#6849)

This is a no-op change; I just want to mark the place(s) in the Python bindings that need attention if/when it becomes possible to support bfloat16 in Python buffers.

* Update PyBinaryOperators.h
  • Loading branch information
steven-johnson authored and ardier committed Mar 3, 2024
1 parent 1278276 commit 3363c52
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 0 deletions.
1 change: 1 addition & 0 deletions python_bindings/src/PyBinaryOperators.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ HANDLE_SCALAR_TYPE(int8_t)
HANDLE_SCALAR_TYPE(int16_t)
HANDLE_SCALAR_TYPE(int32_t)
HANDLE_SCALAR_TYPE(int64_t)
// HANDLE_SCALAR_TYPE(bfloat16_t) TODO: https://github.com/halide/Halide/issues/6849
HANDLE_SCALAR_TYPE(float16_t)
HANDLE_SCALAR_TYPE(float)
HANDLE_SCALAR_TYPE(double)
Expand Down
24 changes: 24 additions & 0 deletions python_bindings/src/PyBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ inline float16_t value_cast<float16_t>(const py::object &value) {
return float16_t(value.cast<double>());
}

// TODO: https://github.com/halide/Halide/issues/6849
// template<>
// inline bfloat16_t value_cast<bfloat16_t>(const py::object &value) {
// return bfloat16_t(value.cast<double>());
// }

template<typename T>
inline std::string format_descriptor() {
return py::format_descriptor<T>::format();
Expand All @@ -67,6 +73,12 @@ inline std::string format_descriptor<float16_t>() {
return "e";
}

// TODO: https://github.com/halide/Halide/issues/6849
// template<>
// inline std::string format_descriptor<bfloat16_t>() {
// return there-is-no-python-buffer-format-descriptor-for-bfloat16;
// }

void call_fill(Buffer<> &b, const py::object &value) {

#define HANDLE_BUFFER_TYPE(TYPE) \
Expand All @@ -84,6 +96,8 @@ void call_fill(Buffer<> &b, const py::object &value) {
HANDLE_BUFFER_TYPE(int16_t)
HANDLE_BUFFER_TYPE(int32_t)
HANDLE_BUFFER_TYPE(int64_t)
// TODO: https://github.com/halide/Halide/issues/6849
// HANDLE_BUFFER_TYPE(bfloat16_t)
HANDLE_BUFFER_TYPE(float16_t)
HANDLE_BUFFER_TYPE(float)
HANDLE_BUFFER_TYPE(double)
Expand All @@ -109,6 +123,8 @@ bool call_all_equal(Buffer<> &b, const py::object &value) {
HANDLE_BUFFER_TYPE(int16_t)
HANDLE_BUFFER_TYPE(int32_t)
HANDLE_BUFFER_TYPE(int64_t)
// TODO: https://github.com/halide/Halide/issues/6849
// HANDLE_BUFFER_TYPE(bfloat16_t)
HANDLE_BUFFER_TYPE(float16_t)
HANDLE_BUFFER_TYPE(float)
HANDLE_BUFFER_TYPE(double)
Expand All @@ -133,6 +149,8 @@ std::string type_to_format_descriptor(const Type &type) {
HANDLE_BUFFER_TYPE(int16_t)
HANDLE_BUFFER_TYPE(int32_t)
HANDLE_BUFFER_TYPE(int64_t)
// TODO: https://github.com/halide/Halide/issues/6849
// HANDLE_BUFFER_TYPE(bfloat16_t)
HANDLE_BUFFER_TYPE(float16_t)
HANDLE_BUFFER_TYPE(float)
HANDLE_BUFFER_TYPE(double)
Expand All @@ -159,6 +177,8 @@ Type format_descriptor_to_type(const std::string &fd) {
HANDLE_BUFFER_TYPE(int16_t)
HANDLE_BUFFER_TYPE(int32_t)
HANDLE_BUFFER_TYPE(int64_t)
// TODO: https://github.com/halide/Halide/issues/6849
// HANDLE_BUFFER_TYPE(bfloat16_t)
HANDLE_BUFFER_TYPE(float16_t)
HANDLE_BUFFER_TYPE(float)
HANDLE_BUFFER_TYPE(double)
Expand Down Expand Up @@ -196,6 +216,8 @@ py::object buffer_getitem_operator(Buffer<> &buf, const std::vector<int> &pos) {
HANDLE_BUFFER_TYPE(int16_t)
HANDLE_BUFFER_TYPE(int32_t)
HANDLE_BUFFER_TYPE(int64_t)
// TODO: https://github.com/halide/Halide/issues/6849
// HANDLE_BUFFER_TYPE(bfloat16_t)
HANDLE_BUFFER_TYPE(float16_t)
HANDLE_BUFFER_TYPE(float)
HANDLE_BUFFER_TYPE(double)
Expand Down Expand Up @@ -224,6 +246,8 @@ py::object buffer_setitem_operator(Buffer<> &buf, const std::vector<int> &pos, c
HANDLE_BUFFER_TYPE(int16_t)
HANDLE_BUFFER_TYPE(int32_t)
HANDLE_BUFFER_TYPE(int64_t)
// TODO: https://github.com/halide/Halide/issues/6849
// HANDLE_BUFFER_TYPE(bfloat16_t)
HANDLE_BUFFER_TYPE(float16_t)
HANDLE_BUFFER_TYPE(float)
HANDLE_BUFFER_TYPE(double)
Expand Down
16 changes: 16 additions & 0 deletions python_bindings/test/correctness/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,20 @@ def test_float16():
hl_img = hl.Buffer(array_in)
array_out = np.array(hl_img, copy = False)

# TODO: https://github.com/halide/Halide/issues/6849
# def test_bfloat16():
# try:
# from tensorflow.python.lib.core import _pywrap_bfloat16
# bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
# array_in = np.zeros((256, 256, 3), dtype=bfloat16, order='F')
# hl_img = hl.Buffer(array_in)
# array_out = np.array(hl_img, copy = False)
# except ModuleNotFoundError as e:
# print("skipping test_bfloat16() because tensorflow was not found: %s" % str(e))
# return
# else:
# assert False, "This should not happen"

def test_int64():
array_in = np.zeros((256, 256, 3), dtype=np.int64, order='F')
hl_img = hl.Buffer(array_in)
Expand Down Expand Up @@ -279,6 +293,8 @@ def test_scalar_buffers():
test_for_each_element()
test_fill_all_equal()
test_bufferinfo_sharing()
# TODO: https://github.com/halide/Halide/issues/6849
# test_bfloat16()
test_float16()
test_int64()
test_reorder()
Expand Down

0 comments on commit 3363c52

Please sign in to comment.