diff --git a/src/runtime/HalideBuffer.h b/src/runtime/HalideBuffer.h index 4ac2317278bc..7f914d0a4ff2 100644 --- a/src/runtime/HalideBuffer.h +++ b/src/runtime/HalideBuffer.h @@ -142,8 +142,8 @@ struct AllInts : std::false_type {}; template struct AllInts : std::false_type {}; -// A helper to detect if there are any zeros in a container namespace Internal { +// A helper to detect if there are any zeros in a container template bool any_zero(const Container &c) { for (int i : c) { @@ -153,6 +153,11 @@ bool any_zero(const Container &c) { } return false; } + +struct DefaultAllocatorFns { + static inline void *(*default_allocate_fn)(size_t) = nullptr; + static inline void (*default_deallocate_fn)(void *) = nullptr; +}; } // namespace Internal /** A struct acting as a header for allocations owned by the Buffer @@ -711,6 +716,13 @@ class Buffer { } public: + static void set_default_allocate_fn(void *(*allocate_fn)(size_t)) { + Internal::DefaultAllocatorFns::default_allocate_fn = allocate_fn; + } + static void set_default_deallocate_fn(void (*deallocate_fn)(void *)) { + Internal::DefaultAllocatorFns::default_deallocate_fn = deallocate_fn; + } + /** Determine if a Buffer can be constructed from some other Buffer type. * If this can be determined at compile time, fail with a static assert; otherwise * return a boolean based on runtime typing. */ @@ -893,7 +905,7 @@ class Buffer { #if HALIDE_RUNTIME_BUFFER_USE_ALIGNED_ALLOC // Only use aligned_alloc() if no custom allocators are specified. - if (!allocate_fn && !deallocate_fn) { + if (!allocate_fn && !deallocate_fn && !Internal::DefaultAllocatorFns::default_allocate_fn && !Internal::DefaultAllocatorFns::default_deallocate_fn) { // As a practical matter, sizeof(AllocationHeader) is going to be no more than 16 bytes // on any supported platform, so we will just overallocate by 'alignment' // so that the user storage also starts at an aligned point. This is a bit @@ -908,10 +920,16 @@ class Buffer { // else fall thru #endif if (!allocate_fn) { - allocate_fn = malloc; + allocate_fn = Internal::DefaultAllocatorFns::default_allocate_fn; + if (!allocate_fn) { + allocate_fn = malloc; + } } if (!deallocate_fn) { - deallocate_fn = free; + deallocate_fn = Internal::DefaultAllocatorFns::default_deallocate_fn; + if (!deallocate_fn) { + deallocate_fn = free; + } } static_assert(sizeof(AllocationHeader) <= alignment); diff --git a/test/correctness/halide_buffer.cpp b/test/correctness/halide_buffer.cpp index 6c35f4b7a409..accaf6f6bb3e 100644 --- a/test/correctness/halide_buffer.cpp +++ b/test/correctness/halide_buffer.cpp @@ -6,6 +6,22 @@ using namespace Halide::Runtime; +static void *my_malloced_addr = nullptr; +static int my_malloc_count = 0; +static void *my_freed_addr = nullptr; +static int my_free_count = 0; +void *my_malloc(size_t size) { + void *ptr = malloc(size); + my_malloced_addr = ptr; + my_malloc_count++; + return ptr; +} +void my_free(void *ptr) { + my_freed_addr = ptr; + my_free_count++; + free(ptr); +} + template void check_equal_shape(const Buffer &a, const Buffer &b) { if (a.dimensions() != b.dimensions()) abort(); @@ -515,6 +531,23 @@ int main(int argc, char **argv) { assert(b.dim(3).stride() == b2.dim(3).stride()); } + { + // Test setting default allocate and deallocate functions. + Buffer<>::set_default_allocate_fn(my_malloc); + Buffer<>::set_default_deallocate_fn(my_free); + + assert(my_malloc_count == 0); + assert(my_free_count == 0); + auto b = Buffer(5, 4).fill(1); + assert(my_malloced_addr != nullptr && my_malloced_addr < b.data()); + assert(my_malloc_count == 1); + assert(my_free_count == 0); + b.deallocate(); + assert(my_malloc_count == 1); + assert(my_free_count == 1); + assert(my_malloced_addr == my_freed_addr); + } + printf("Success!\n"); return 0; }