From 4f573bf438b9ca664bb763ab89c913dc35a27de6 Mon Sep 17 00:00:00 2001 From: Volodymyr Kysenko Date: Wed, 27 Oct 2021 19:05:29 -0700 Subject: [PATCH 1/4] Add missing widening_absd patterns (#6359) * Add missing widening_absd patterns * Add a comment --- src/CodeGen_ARM.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/CodeGen_ARM.cpp b/src/CodeGen_ARM.cpp index 0c6208cbb7dd..02b2ebac4bf2 100644 --- a/src/CodeGen_ARM.cpp +++ b/src/CodeGen_ARM.cpp @@ -580,15 +580,19 @@ const ArmIntrinsic intrinsic_defs[] = { {nullptr, "udot.v4i32.v16i8", UInt(32, 4), "dot_product", {UInt(32, 4), UInt(8, 16), UInt(8, 16)}, ArmIntrinsic::NoMangle}, // ABDL - Widening absolute difference - // Need to be able to handle both signed and unsigned outputs for signed inputs. + // The ARM backend folds both signed and unsigned widening casts of absd to a widening_absd, so we need to handle both signed and + // unsigned input and return types. {"vabdl_i8x8", "vabdl_i8x8", Int(16, 8), "widening_absd", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, {"vabdl_i8x8", "vabdl_i8x8", UInt(16, 8), "widening_absd", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, + {"vabdl_u8x8", "vabdl_u8x8", Int(16, 8), "widening_absd", {UInt(8, 8), UInt(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, {"vabdl_u8x8", "vabdl_u8x8", UInt(16, 8), "widening_absd", {UInt(8, 8), UInt(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, {"vabdl_i16x4", "vabdl_i16x4", Int(32, 4), "widening_absd", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, {"vabdl_i16x4", "vabdl_i16x4", UInt(32, 4), "widening_absd", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, + {"vabdl_u16x4", "vabdl_u16x4", Int(32, 4), "widening_absd", {UInt(16, 4), UInt(16, 4)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, {"vabdl_u16x4", "vabdl_u16x4", UInt(32, 4), "widening_absd", {UInt(16, 4), UInt(16, 4)}, 
ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, {"vabdl_i32x2", "vabdl_i32x2", Int(64, 2), "widening_absd", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, {"vabdl_i32x2", "vabdl_i32x2", UInt(64, 2), "widening_absd", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, + {"vabdl_u32x2", "vabdl_u32x2", Int(64, 2), "widening_absd", {UInt(32, 2), UInt(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, {"vabdl_u32x2", "vabdl_u32x2", UInt(64, 2), "widening_absd", {UInt(32, 2), UInt(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, }; From 1c7388a5fd470339c390049589d587c90b95030c Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 28 Oct 2021 10:25:58 -0700 Subject: [PATCH 2/4] Allow users to use their own cuda contexts and streams in JIT mode (#6345) * Deprecate JIT runtime override methods that take void * * Make it possible to use custom cuda contexts and streams in JIT mode * Clean up comments * Tolerate null handlers in the JITUserContext These can come up if a JITUserContext is passed to something like copy_to_device before getting fully populated by passing it to a call to realize. * Remove reliance on dlsym in test and reuse the runtime's name resolution mechanism instead * Handle case where cuda and cuda-debug runtime modules both exist This change means we'll only ever create one built-in cuda context in this circumstance. 
* Slight simplification * Improve comments --- src/JITModule.cpp | 116 ++++++++++++--- src/JITModule.h | 16 ++ src/Pipeline.cpp | 2 +- src/runtime/HalideRuntimeCuda.h | 17 +++ src/runtime/cuda.cpp | 86 +++++++++-- src/runtime/cuda_functions.h | 5 + test/correctness/CMakeLists.txt | 1 + test/correctness/custom_cuda_context.cpp | 178 +++++++++++++++++++++++ 8 files changed, 387 insertions(+), 34 deletions(-) create mode 100644 test/correctness/custom_cuda_context.cpp diff --git a/src/JITModule.cpp b/src/JITModule.cpp index 0033cb212c16..acb8be5da8c7 100644 --- a/src/JITModule.cpp +++ b/src/JITModule.cpp @@ -506,63 +506,72 @@ void merge_handlers(JITHandlers &base, const JITHandlers &addins) { if (addins.custom_get_library_symbol) { base.custom_get_library_symbol = addins.custom_get_library_symbol; } + if (addins.custom_cuda_acquire_context) { + base.custom_cuda_acquire_context = addins.custom_cuda_acquire_context; + } + if (addins.custom_cuda_release_context) { + base.custom_cuda_release_context = addins.custom_cuda_release_context; + } + if (addins.custom_cuda_get_stream) { + base.custom_cuda_get_stream = addins.custom_cuda_get_stream; + } } void print_handler(JITUserContext *context, const char *msg) { - if (context) { - (*context->handlers.custom_print)(context, msg); + if (context && context->handlers.custom_print) { + context->handlers.custom_print(context, msg); } else { - return (*active_handlers.custom_print)(context, msg); + return active_handlers.custom_print(context, msg); } } void *malloc_handler(JITUserContext *context, size_t x) { - if (context) { - return (*context->handlers.custom_malloc)(context, x); + if (context && context->handlers.custom_malloc) { + return context->handlers.custom_malloc(context, x); } else { - return (*active_handlers.custom_malloc)(context, x); + return active_handlers.custom_malloc(context, x); } } void free_handler(JITUserContext *context, void *ptr) { - if (context) { - (*context->handlers.custom_free)(context, ptr); + if 
(context && context->handlers.custom_free) { + context->handlers.custom_free(context, ptr); } else { - (*active_handlers.custom_free)(context, ptr); + active_handlers.custom_free(context, ptr); } } int do_task_handler(JITUserContext *context, int (*f)(JITUserContext *, int, uint8_t *), int idx, uint8_t *closure) { - if (context) { - return (*context->handlers.custom_do_task)(context, f, idx, closure); + if (context && context->handlers.custom_do_task) { + return context->handlers.custom_do_task(context, f, idx, closure); } else { - return (*active_handlers.custom_do_task)(context, f, idx, closure); + return active_handlers.custom_do_task(context, f, idx, closure); } } int do_par_for_handler(JITUserContext *context, int (*f)(JITUserContext *, int, uint8_t *), int min, int size, uint8_t *closure) { - if (context) { - return (*context->handlers.custom_do_par_for)(context, f, min, size, closure); + if (context && context->handlers.custom_do_par_for) { + return context->handlers.custom_do_par_for(context, f, min, size, closure); } else { - return (*active_handlers.custom_do_par_for)(context, f, min, size, closure); + return active_handlers.custom_do_par_for(context, f, min, size, closure); } } void error_handler_handler(JITUserContext *context, const char *msg) { - if (context) { - (*context->handlers.custom_error)(context, msg); + if (context && context->handlers.custom_error) { + context->handlers.custom_error(context, msg); } else { - (*active_handlers.custom_error)(context, msg); + active_handlers.custom_error(context, msg); } } int32_t trace_handler(JITUserContext *context, const halide_trace_event_t *e) { - if (context) { - return (*context->handlers.custom_trace)(context, e); + if (context && context->handlers.custom_trace) { + return context->handlers.custom_trace(context, e); } else { - return (*active_handlers.custom_trace)(context, e); + return active_handlers.custom_trace(context, e); } } @@ -578,6 +587,30 @@ void *get_library_symbol_handler(void *lib, const 
char *name) { return (*active_handlers.custom_get_library_symbol)(lib, name); } +int cuda_acquire_context_handler(JITUserContext *context, void **cuda_context_ptr, bool create) { + if (context && context->handlers.custom_cuda_acquire_context) { + return context->handlers.custom_cuda_acquire_context(context, cuda_context_ptr, create); + } else { + return active_handlers.custom_cuda_acquire_context(context, cuda_context_ptr, create); + } +} + +int cuda_release_context_handler(JITUserContext *context) { + if (context && context->handlers.custom_cuda_release_context) { + return context->handlers.custom_cuda_release_context(context); + } else { + return active_handlers.custom_cuda_release_context(context); + } +} + +int cuda_get_stream_handler(JITUserContext *context, void *cuda_context, void **cuda_stream_ptr) { + if (context && context->handlers.custom_cuda_get_stream) { + return context->handlers.custom_cuda_get_stream(context, cuda_context, cuda_stream_ptr); + } else { + return active_handlers.custom_cuda_get_stream(context, cuda_context, cuda_stream_ptr); + } +} + template function_t hook_function(const std::map &exports, const char *hook_name, function_t hook) { auto iter = exports.find(hook_name); @@ -776,13 +809,13 @@ JITModule &make_module(llvm::Module *for_module, Target target, hook_function(runtime.exports(), "halide_set_custom_trace", trace_handler); runtime_internal_handlers.custom_get_symbol = - hook_function(shared_runtimes(MainShared).exports(), "halide_set_custom_get_symbol", get_symbol_handler); + hook_function(runtime.exports(), "halide_set_custom_get_symbol", get_symbol_handler); runtime_internal_handlers.custom_load_library = - hook_function(shared_runtimes(MainShared).exports(), "halide_set_custom_load_library", load_library_handler); + hook_function(runtime.exports(), "halide_set_custom_load_library", load_library_handler); runtime_internal_handlers.custom_get_library_symbol = - hook_function(shared_runtimes(MainShared).exports(), 
"halide_set_custom_get_library_symbol", get_library_symbol_handler); + hook_function(runtime.exports(), "halide_set_custom_get_library_symbol", get_library_symbol_handler); active_handlers = runtime_internal_handlers; merge_handlers(active_handlers, default_handlers); @@ -794,6 +827,41 @@ JITModule &make_module(llvm::Module *for_module, Target target, runtime.jit_module->name = "MainShared"; } else { runtime.jit_module->name = "GPU"; + + // There are two versions of these cuda context + // management handlers we could use - one in the cuda + // module, and one in the cuda-debug module. If both + // modules are in use, we'll just want to use one of + // them, so that we don't needlessly create two cuda + // contexts. We'll use whichever was first + // created. The second one will then declare a + // dependency on the first one, to make sure things + // are destroyed in the correct order. + + if (runtime_kind == CUDA || runtime_kind == CUDADebug) { + if (!runtime_internal_handlers.custom_cuda_acquire_context) { + // Neither module has been created. + runtime_internal_handlers.custom_cuda_acquire_context = + hook_function(runtime.exports(), "halide_set_cuda_acquire_context", cuda_acquire_context_handler); + + runtime_internal_handlers.custom_cuda_release_context = + hook_function(runtime.exports(), "halide_set_cuda_release_context", cuda_release_context_handler); + + runtime_internal_handlers.custom_cuda_get_stream = + hook_function(runtime.exports(), "halide_set_cuda_get_stream", cuda_get_stream_handler); + + active_handlers = runtime_internal_handlers; + merge_handlers(active_handlers, default_handlers); + } else if (runtime_kind == CUDA) { + // The CUDADebug module has already been created. + // Use the context in the CUDADebug module and add + // a dependence edge from the CUDA module to it. + runtime.add_dependency(shared_runtimes(CUDADebug)); + } else { + // The CUDA module has already been created. 
+ runtime.add_dependency(shared_runtimes(CUDA)); + } + } } uint64_t arg_addr = diff --git a/src/JITModule.h b/src/JITModule.h index 5208d31698ac..ee27fe99216b 100644 --- a/src/JITModule.h +++ b/src/JITModule.h @@ -106,6 +106,22 @@ struct JITHandlers { * an opened library. Equivalent to dlsym. Takes a handle * returned by custom_load_library as the first argument. */ void *(*custom_get_library_symbol)(void *lib, const char *name){nullptr}; + + /** A custom method for the Halide runtime to acquire a cuda + * context. The cuda context is treated as a void * to avoid a + * dependence on the cuda headers. If the create argument is set + * to true, a context should be created if one does not already + * exist. */ + int32_t (*custom_cuda_acquire_context)(JITUserContext *user_context, void **cuda_context_ptr, bool create){nullptr}; + + /** The Halide runtime calls this when it is done with a cuda + * context. The default implementation does nothing. */ + int32_t (*custom_cuda_release_context)(JITUserContext *user_context){nullptr}; + + /** A custom method for the Halide runtime to acquire a cuda + * stream to use. The cuda context and stream are both modelled + * as a void *, to avoid a dependence on the cuda headers. */ + int32_t (*custom_cuda_get_stream)(JITUserContext *user_context, void *cuda_context, void **stream_ptr){nullptr}; }; namespace Internal { diff --git a/src/Pipeline.cpp b/src/Pipeline.cpp index 06fa9eb34003..eebeb86dea80 100644 --- a/src/Pipeline.cpp +++ b/src/Pipeline.cpp @@ -757,7 +757,7 @@ Realization Pipeline::realize(JITUserContext *context, if (needs_crop) { r[i].crop(crop); } - r[i].copy_to_host(); + r[i].copy_to_host(context); } return r; } diff --git a/src/runtime/HalideRuntimeCuda.h b/src/runtime/HalideRuntimeCuda.h index 3eff7d834543..239f6e698561 100644 --- a/src/runtime/HalideRuntimeCuda.h +++ b/src/runtime/HalideRuntimeCuda.h @@ -65,6 +65,23 @@ extern uintptr_t halide_cuda_get_device_ptr(void *user_context, struct halide_bu * driver. 
See halide_reuse_device_allocations. */ extern int halide_cuda_release_unused_device_allocations(void *user_context); +// These typedefs treat both a CUcontext and a CUstream as a void *, +// to avoid dependencies on cuda headers. +typedef int (*halide_cuda_acquire_context_t)(void *, // user_context + void **, // cuda context out parameter + bool); // should create a context if none exist +typedef int (*halide_cuda_release_context_t)(void * /* user_context */); +typedef int (*halide_cuda_get_stream_t)(void *, // user_context + void *, // context + void **); // stream out parameter + +/** Set custom methods to acquire and release cuda contexts and streams */ +// @{ +extern halide_cuda_acquire_context_t halide_set_cuda_acquire_context(halide_cuda_acquire_context_t handler); +extern halide_cuda_release_context_t halide_set_cuda_release_context(halide_cuda_release_context_t handler); +extern halide_cuda_get_stream_t halide_set_cuda_get_stream(halide_cuda_get_stream_t handler); +// @} + #ifdef __cplusplus } // End extern "C" #endif diff --git a/src/runtime/cuda.cpp b/src/runtime/cuda.cpp index 11cc906bf316..eee52c1e4804 100644 --- a/src/runtime/cuda.cpp +++ b/src/runtime/cuda.cpp @@ -136,7 +136,7 @@ extern "C" { // - A call to halide_cuda_acquire_context is followed by a matching call to // halide_cuda_release_context. halide_cuda_acquire_context should block while a // previous call (if any) has not yet been released via halide_cuda_release_context. -WEAK int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create = true) { +WEAK int halide_default_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create = true) { // TODO: Should we use a more "assertive" assert? these asserts do // not block execution on failure. 
halide_assert(user_context, ctx != nullptr); @@ -179,7 +179,7 @@ WEAK int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool cr return 0; } -WEAK int halide_cuda_release_context(void *user_context) { +WEAK int halide_default_cuda_release_context(void *user_context) { return 0; } @@ -188,7 +188,7 @@ WEAK int halide_cuda_release_context(void *user_context) { // for the context (nullptr stream). The context is passed in for convenience, but // any sort of scoping must be handled by that of the // halide_cuda_acquire_context/halide_cuda_release_context pair, not this call. -WEAK int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream) { +WEAK int halide_default_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream) { // There are two default streams we could use. stream 0 is fully // synchronous. stream 2 gives a separate non-blocking stream per // thread. @@ -198,6 +198,53 @@ WEAK int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *str } // extern "C" +namespace Halide { +namespace Runtime { +namespace Internal { +namespace CUDA { + +WEAK halide_cuda_acquire_context_t acquire_context = (halide_cuda_acquire_context_t)halide_default_cuda_acquire_context; +WEAK halide_cuda_release_context_t release_context = (halide_cuda_release_context_t)halide_default_cuda_release_context; +WEAK halide_cuda_get_stream_t get_stream = (halide_cuda_get_stream_t)halide_default_cuda_get_stream; + +} // namespace CUDA +} // namespace Internal +} // namespace Runtime +} // namespace Halide + +extern "C" { + +WEAK int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create = true) { + return CUDA::acquire_context(user_context, (void **)ctx, create); +} + +WEAK halide_cuda_acquire_context_t halide_set_cuda_acquire_context(halide_cuda_acquire_context_t handler) { + halide_cuda_acquire_context_t result = CUDA::acquire_context; + CUDA::acquire_context = handler; + return result; +} + +WEAK int 
halide_cuda_release_context(void *user_context) { + return CUDA::release_context(user_context); +} + +WEAK halide_cuda_release_context_t halide_set_cuda_release_context(halide_cuda_release_context_t handler) { + halide_cuda_release_context_t result = CUDA::release_context; + CUDA::release_context = handler; + return result; +} + +WEAK int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream) { + return CUDA::get_stream(user_context, (void *)ctx, (void **)stream); +} + +WEAK halide_cuda_get_stream_t halide_set_cuda_get_stream(halide_cuda_get_stream_t handler) { + halide_cuda_get_stream_t result = CUDA::get_stream; + CUDA::get_stream = handler; + return result; +} +} + namespace Halide { namespace Runtime { namespace Internal { @@ -845,7 +892,8 @@ WEAK int halide_cuda_device_malloc(void *user_context, halide_buffer_t *buf) { namespace { WEAK int cuda_do_multidimensional_copy(void *user_context, const device_copy &c, - uint64_t src, uint64_t dst, int d, bool from_host, bool to_host) { + uint64_t src, uint64_t dst, int d, bool from_host, bool to_host, + CUstream stream) { if (d > MAX_COPY_DIMS) { error(user_context) << "Buffer has too many dimensions to copy to/from GPU\n"; return -1; @@ -858,15 +906,27 @@ WEAK int cuda_do_multidimensional_copy(void *user_context, const device_copy &c, if (!from_host && to_host) { debug(user_context) << "cuMemcpyDtoH(" << (void *)dst << ", " << (void *)src << ", " << c.chunk_size << ")\n"; copy_name = "cuMemcpyDtoH"; - err = cuMemcpyDtoH((void *)dst, (CUdeviceptr)src, c.chunk_size); + if (stream) { + err = cuMemcpyDtoHAsync((void *)dst, (CUdeviceptr)src, c.chunk_size, stream); + } else { + err = cuMemcpyDtoH((void *)dst, (CUdeviceptr)src, c.chunk_size); + } } else if (from_host && !to_host) { debug(user_context) << "cuMemcpyHtoD(" << (void *)dst << ", " << (void *)src << ", " << c.chunk_size << ")\n"; copy_name = "cuMemcpyHtoD"; - err = cuMemcpyHtoD((CUdeviceptr)dst, (void *)src, c.chunk_size); + if (stream) { + 
err = cuMemcpyHtoDAsync((CUdeviceptr)dst, (void *)src, c.chunk_size, stream); + } else { + err = cuMemcpyHtoD((CUdeviceptr)dst, (void *)src, c.chunk_size); + } } else if (!from_host && !to_host) { debug(user_context) << "cuMemcpyDtoD(" << (void *)dst << ", " << (void *)src << ", " << c.chunk_size << ")\n"; copy_name = "cuMemcpyDtoD"; - err = cuMemcpyDtoD((CUdeviceptr)dst, (CUdeviceptr)src, c.chunk_size); + if (stream) { + err = cuMemcpyDtoDAsync((CUdeviceptr)dst, (CUdeviceptr)src, c.chunk_size, stream); + } else { + err = cuMemcpyDtoD((CUdeviceptr)dst, (CUdeviceptr)src, c.chunk_size); + } } else if (dst != src) { debug(user_context) << "memcpy(" << (void *)dst << ", " << (void *)src << ", " << c.chunk_size << ")\n"; // Could reach here if a user called directly into the @@ -881,7 +941,7 @@ WEAK int cuda_do_multidimensional_copy(void *user_context, const device_copy &c, } else { ssize_t src_off = 0, dst_off = 0; for (int i = 0; i < (int)c.extent[d - 1]; i++) { - int err = cuda_do_multidimensional_copy(user_context, c, src + src_off, dst + dst_off, d - 1, from_host, to_host); + int err = cuda_do_multidimensional_copy(user_context, c, src + src_off, dst + dst_off, d - 1, from_host, to_host, stream); dst_off += c.dst_stride_bytes[d - 1]; src_off += c.src_stride_bytes[d - 1]; if (err) { @@ -938,7 +998,15 @@ WEAK int halide_cuda_buffer_copy(void *user_context, struct halide_buffer_t *src } #endif - err = cuda_do_multidimensional_copy(user_context, c, c.src + c.src_begin, c.dst, dst->dimensions, from_host, to_host); + CUstream stream = nullptr; + if (cuStreamSynchronize != nullptr) { + int result = halide_cuda_get_stream(user_context, ctx.context, &stream); + if (result != 0) { + error(user_context) << "CUDA: In cuda_do_multidimensional_copy, halide_cuda_get_stream returned " << result << "\n"; + } + } + + err = cuda_do_multidimensional_copy(user_context, c, c.src + c.src_begin, c.dst, dst->dimensions, from_host, to_host, stream); #ifdef DEBUG_RUNTIME uint64_t t_after = 
halide_current_time_ns(user_context); diff --git a/src/runtime/cuda_functions.h b/src/runtime/cuda_functions.h index 61614ea89805..5242e931686a 100644 --- a/src/runtime/cuda_functions.h +++ b/src/runtime/cuda_functions.h @@ -36,6 +36,11 @@ CUDA_FN_3020(CUresult, cuMemFree, cuMemFree_v2, (CUdeviceptr dptr)); CUDA_FN_3020(CUresult, cuMemcpyHtoD, cuMemcpyHtoD_v2, (CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount)); CUDA_FN_3020(CUresult, cuMemcpyDtoH, cuMemcpyDtoH_v2, (void *dstHost, CUdeviceptr srcDevice, size_t ByteCount)); CUDA_FN_3020(CUresult, cuMemcpyDtoD, cuMemcpyDtoD_v2, (CUdeviceptr dstHost, CUdeviceptr srcDevice, size_t ByteCount)); + +CUDA_FN_3020(CUresult, cuMemcpyHtoDAsync, cuMemcpyHtoDAsync_v2, (CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream stream)); +CUDA_FN_3020(CUresult, cuMemcpyDtoHAsync, cuMemcpyDtoHAsync_v2, (void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream stream)); +CUDA_FN_3020(CUresult, cuMemcpyDtoDAsync, cuMemcpyDtoDAsync_v2, (CUdeviceptr dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream stream)); + CUDA_FN_3020(CUresult, cuMemcpy3D, cuMemcpy3D_v2, (const CUDA_MEMCPY3D *pCopy)); CUDA_FN(CUresult, cuLaunchKernel, (CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra)); CUDA_FN(CUresult, cuCtxSynchronize, ()); diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index bb4bb8908f8e..0cf9332376ad 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -66,6 +66,7 @@ tests(GROUPS correctness cuda_8_bit_dot_product.cpp custom_allocator.cpp custom_auto_scheduler.cpp + custom_cuda_context.cpp custom_error_reporter.cpp custom_jit_context.cpp custom_lowering_pass.cpp diff --git a/test/correctness/custom_cuda_context.cpp 
b/test/correctness/custom_cuda_context.cpp new file mode 100644 index 000000000000..d9aaa76c96cd --- /dev/null +++ b/test/correctness/custom_cuda_context.cpp @@ -0,0 +1,178 @@ +#include "Halide.h" + +using namespace Halide; + +int (*cuStreamCreate)(void **, uint32_t) = nullptr; +int (*cuCtxCreate)(void **, uint32_t, int) = nullptr; +int (*cuCtxDestroy)(void *) = nullptr; +int (*cuMemAlloc)(void **, size_t) = nullptr; +int (*cuMemFree)(void *) = nullptr; +int (*cuCtxSetCurrent)(void *) = nullptr; + +struct CudaState : public Halide::JITUserContext { + void *cuda_context = nullptr, *cuda_stream = nullptr; + std::atomic acquires = 0, releases = 0; + + static int my_cuda_acquire_context(JITUserContext *ctx, void **cuda_ctx, bool create) { + CudaState *state = (CudaState *)ctx; + *cuda_ctx = state->cuda_context; + state->acquires++; + return 0; + } + + static int my_cuda_release_context(JITUserContext *ctx) { + CudaState *state = (CudaState *)ctx; + state->releases++; + return 0; + } + + static int my_cuda_get_stream(JITUserContext *ctx, void *cuda_ctx, void **stream) { + CudaState *state = (CudaState *)ctx; + *stream = state->cuda_stream; + return 0; + } + + CudaState() { + handlers.custom_cuda_acquire_context = my_cuda_acquire_context; + handlers.custom_cuda_release_context = my_cuda_release_context; + handlers.custom_cuda_get_stream = my_cuda_get_stream; + } +}; + +int main(int argc, char **argv) { + Target target = get_jit_target_from_environment(); + if (!target.has_feature(Target::CUDA)) { + printf("[SKIP] CUDA not enabled.\n"); + return 0; + } + + { + // Do some nonsense to get symbols out of libcuda without + // needing the CUDA sdk. This would not be a concern in a real + // cuda-using application but is helpful for our + // build-and-test infrastructure. + + // We'll find the cuda module in the Halide runtime so + // that we can use it to resolve symbols into libcuda in a + // portable way. + + // Force-initialize the cuda runtime module by running something trivial. 
+ evaluate_may_gpu(Expr(0.f)); + + // Go get it, and dig out the method used to resolve symbols in libcuda. + auto runtime_modules = Internal::JITSharedRuntime::get(nullptr, target, false); + void *(*halide_cuda_get_symbol)(void *, const char *) = nullptr; + for (Internal::JITModule &m : runtime_modules) { + // Just rifle through all the runtime modules for this + // target until we find the method we want. + auto sym = m.find_symbol_by_name("halide_cuda_get_symbol"); + if (sym.address != nullptr) { + halide_cuda_get_symbol = (decltype(halide_cuda_get_symbol))sym.address; + break; + } + } + + if (halide_cuda_get_symbol == nullptr) { + printf("Failed to extract halide_cuda_get_symbol from Halide cuda runtime\n"); + return -1; + } + + // Go get the CUDA API functions we actually intend to use. + cuStreamCreate = (decltype(cuStreamCreate))halide_cuda_get_symbol(nullptr, "cuStreamCreate"); + cuCtxCreate = (decltype(cuCtxCreate))halide_cuda_get_symbol(nullptr, "cuCtxCreate_v2"); + cuCtxDestroy = (decltype(cuCtxDestroy))halide_cuda_get_symbol(nullptr, "cuCtxDestroy_v2"); + cuCtxSetCurrent = (decltype(cuCtxSetCurrent))halide_cuda_get_symbol(nullptr, "cuCtxSetCurrent"); + cuMemAlloc = (decltype(cuMemAlloc))halide_cuda_get_symbol(nullptr, "cuMemAlloc_v2"); + cuMemFree = (decltype(cuMemFree))halide_cuda_get_symbol(nullptr, "cuMemFree_v2"); + + if (cuStreamCreate == nullptr || + cuCtxCreate == nullptr || + cuCtxDestroy == nullptr || + cuCtxSetCurrent == nullptr || + cuMemAlloc == nullptr || + cuMemFree == nullptr) { + printf("Failed to find cuda API\n"); + return -1; + } + } + + // Make a cuda context and stream. 
+ CudaState state; + int err = cuCtxCreate(&state.cuda_context, 0, 0); + if (state.cuda_context == nullptr) { + printf("Failed to initialize context: %d\n", err); + return -1; + } + + err = cuCtxSetCurrent(state.cuda_context); + if (err) { + printf("Failed to set context: %d\n", err); + return -1; + } + + err = cuStreamCreate(&state.cuda_stream, 1 /* non-blocking */); + if (state.cuda_stream == nullptr) { + printf("Failed to initialize stream: %d\n", err); + return -1; + } + + // Allocate some GPU memory on this context + const int width = 32, height = 1024; + + void *ptr = nullptr; + err = cuMemAlloc(&ptr, width * height * sizeof(float)); + + if (ptr == nullptr) { + printf("cuMemAlloc failed: %d\n", err); + return -1; + } + + // Wrap a Halide buffer around it, with some host memory too. + Buffer in(width, height); + in.fill(4.0f); + auto device_interface = get_device_interface_for_device_api(DeviceAPI::CUDA); + in.device_wrap_native(device_interface, + (uintptr_t)ptr, &state); + in.copy_to_device(device_interface, &state); + + // Run a kernel on multiple threads that copies slices of it into + // a Halide-allocated temporary buffer. This would likely crash + // if we don't allocate the outputs on the right context. If the + // copies don't happen on the same stream as the compute, we'll + // get incorrect outputs due to race conditions. 
+ Func f, g; + Var x, xi, y; + f(x, y) = sqrt(in(x, y)); + g(x, y) = f(x, y); + f.gpu_tile(x, x, xi, 32).compute_at(g, y); + g.parallel(y); + + for (int i = 0; i < 10; i++) { + Buffer out = g.realize(&state, {width, height}); + out.copy_to_host(&state); + for (int y = 0; y < height; y++) { + for (int x = 0; x < width; x++) { + float correct = 2.0f; + if (out(x, y) != 2.0f) { + printf("out(%d, %d) = %f instead of %f\n", x, y, out(x, y), correct); + return -1; + } + } + } + } + + // Clean up + in.device_detach_native(&state); + cuMemFree(ptr); + cuCtxDestroy(state.cuda_context); + + if (state.acquires.load() != state.releases.load() || + state.acquires.load() < height) { + printf("Context acquires: %d releases: %d\n", state.acquires.load(), state.releases.load()); + printf("Expected these to match and be at least %d (the number of parallel tasks)\n", height); + return -1; + } + + printf("Success!\n"); + return 0; +} From e10f104c37d5cf2186f2285b14b7aeb9358f50fb Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Thu, 28 Oct 2021 10:34:27 -0700 Subject: [PATCH 3/4] Update Emscripten settings (#6362) The settings we use to build C++ in wasm were slightly out of date now that we've updated our runtime to Node instead of d8. Also drive-by gitignore fix. 
--- apps/hannk/.gitignore | 2 +- dependencies/wasm/CMakeLists.txt | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/apps/hannk/.gitignore b/apps/hannk/.gitignore index d5697ecf6f18..46164faef484 100644 --- a/apps/hannk/.gitignore +++ b/apps/hannk/.gitignore @@ -1,2 +1,2 @@ -build/ +build*/ bin/ diff --git a/dependencies/wasm/CMakeLists.txt b/dependencies/wasm/CMakeLists.txt index d6800f83c9b0..838ba0274a76 100644 --- a/dependencies/wasm/CMakeLists.txt +++ b/dependencies/wasm/CMakeLists.txt @@ -91,9 +91,8 @@ function(add_wasm_executable TARGET) -Wsuggest-override -s ASSERTIONS=1 -s ALLOW_MEMORY_GROWTH=1 - -s WASM_BIGINT=1 - -s STANDALONE_WASM=1 - -s ENVIRONMENT=node) + -s ENVIRONMENT=node + ) set(SRCS) foreach (S IN LISTS args_SRCS) From 541bc37e7f420ff616affe38dac6e51337cab333 Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Thu, 28 Oct 2021 14:14:42 -0700 Subject: [PATCH 4/4] [hannk] Allow disabling TFLite+Delegate build in CMake (#6360) * [hannk] Allow disabling TFLite+Delegate build in CMake Preparatory work for allowing building of hannk with Emscripten; TFLite (and its dependees) problematic to build in that environment, but this will allow us to build a tflite-parser-only environment. 
(Note that more work is needed to get this working for wasm, as crosscompiling in CMake is still pretty painful; this work was split out to make subsequent reviews simpler) * Update hannk_delegate.h * HANNK_BUILD_TFLITE_DELEGATE -> HANNK_BUILD_TFLITE --- apps/hannk/CMakeLists.txt | 14 ++++++++- apps/hannk/Makefile | 3 ++ apps/hannk/configure_cmake.sh | 8 ++++++ apps/hannk/delegate/hannk_delegate.h | 4 +++ apps/hannk/tflite/CMakeLists.txt | 9 ++++-- apps/hannk/util/CMakeLists.txt | 12 +++++--- apps/hannk/util/model_runner.cpp | 43 ++++++++++++++++++++++++++-- apps/hannk/util/model_runner.h | 6 ++++ 8 files changed, 89 insertions(+), 10 deletions(-) diff --git a/apps/hannk/CMakeLists.txt b/apps/hannk/CMakeLists.txt index b01c1a6875a4..3e663100a6a2 100644 --- a/apps/hannk/CMakeLists.txt +++ b/apps/hannk/CMakeLists.txt @@ -6,6 +6,9 @@ set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) enable_testing() +option(HANNK_BUILD_TFLITE "Build TFLite+Delegate for HANNK" ON) +message(STATUS "HANNK_BUILD_TFLITE is ${HANNK_BUILD_TFLITE}") + # -fPIC is necessary for .so builds (at least on Linux); not necessary for the non-delegate # builds but easier to enable it for everything. 
set(CMAKE_POSITION_INDEPENDENT_CODE ON) @@ -19,6 +22,8 @@ set(CMAKE_CXX_EXTENSIONS NO) find_package(Halide REQUIRED) # Set up the version of TFLite we expect +# (We need to do this even if HANNK_BUILD_TFLITE is off, +# so that the .tflite file parser can get the right schema) set(TFLITE_VERSION_MAJOR "2" CACHE STRING "Major version of TFLite to assume") set(TFLITE_VERSION_MINOR "6" CACHE STRING "Minor version of TFLite to assume") set(TFLITE_VERSION_PATCH "0" CACHE STRING "Patch version of TFLite to assume") @@ -26,14 +31,21 @@ set(TFLITE_VERSION_PATCH "0" CACHE STRING "Patch version of TFLite to assume") add_compile_definitions(TFLITE_VERSION_MAJOR=${TFLITE_VERSION_MAJOR}) add_compile_definitions(TFLITE_VERSION_MINOR=${TFLITE_VERSION_MINOR}) add_compile_definitions(TFLITE_VERSION_PATCH=${TFLITE_VERSION_PATCH}) +if (HANNK_BUILD_TFLITE) + add_compile_definitions(HANNK_BUILD_TFLITE=1) +else () + add_compile_definitions(HANNK_BUILD_TFLITE=0) +endif () set(TFLITE_VERSION "${TFLITE_VERSION_MAJOR}.${TFLITE_VERSION_MINOR}.${TFLITE_VERSION_PATCH}") -add_subdirectory(delegate) add_subdirectory(halide) add_subdirectory(interpreter) add_subdirectory(tflite) add_subdirectory(util) +if (HANNK_BUILD_TFLITE) + add_subdirectory(delegate) +endif () # Benchmarking executable add_executable(benchmark benchmark.cpp) diff --git a/apps/hannk/Makefile b/apps/hannk/Makefile index b6407d87beb8..77c4a41488f3 100644 --- a/apps/hannk/Makefile +++ b/apps/hannk/Makefile @@ -6,6 +6,9 @@ include ../support/Makefile.inc # builds but easier to enable it for everything. 
CXXFLAGS += -Wno-unused-private-field -fno-exceptions -fPIC -fvisibility=hidden -fvisibility-inlines-hidden -Wunused-variable -Wsuggest-override -Woverloaded-virtual -I$(MAKEFILE_DIR) +# No option to build without TFLite/Delegate in Make (use CMake for that) +CXXFLAGS += -DHANNK_BUILD_TFLITE=1 + BENCHMARK_OUT = benchmark ifeq (hexagon-32-qurt,$(findstring hexagon-32-qurt,$(HL_TARGET))) # Building benchmark application as shared object instead of elf for diff --git a/apps/hannk/configure_cmake.sh b/apps/hannk/configure_cmake.sh index 55448cbe1908..99440afaff1c 100755 --- a/apps/hannk/configure_cmake.sh +++ b/apps/hannk/configure_cmake.sh @@ -28,6 +28,13 @@ CMAKE_BUILD_TYPE=Release fi echo Using CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} +if [ -z "${HANNK_BUILD_TFLITE}" ]; then +HANNK_BUILD_TFLITE=ON +else +HANNK_BUILD_TFLITE=OFF +fi +echo Using HANNK_BUILD_TFLITE=${HANNK_BUILD_TFLITE} + EXTRAS= # TODO: this doesn't work (yet); crosscompiling in CMake is painful. if [[ "${HL_TARGET}" =~ ^arm-64-android.* ]]; then @@ -45,6 +52,7 @@ cmake \ -DHalide_DIR="${HALIDE_INSTALL_PATH}" \ -DCMAKE_PREFIX_PATH="${HALIDE_INSTALL_PATH}" \ -DHalide_TARGET=${HL_TARGET} \ + -DHANNK_BUILD_TFLITE=${HANNK_BUILD_TFLITE} \ -S "${HANNK_DIR}" \ -B "${BUILD_DIR}" diff --git a/apps/hannk/delegate/hannk_delegate.h b/apps/hannk/delegate/hannk_delegate.h index 6c19d12a1975..2971846720c1 100644 --- a/apps/hannk/delegate/hannk_delegate.h +++ b/apps/hannk/delegate/hannk_delegate.h @@ -1,6 +1,10 @@ #ifndef HANNK_DELEGATE_H #define HANNK_DELEGATE_H +#if !HANNK_BUILD_TFLITE +#error "This file should not be included when HANNK_BUILD_TFLITE=0" +#endif + #include "tensorflow/lite/c/c_api.h" #ifdef __cplusplus diff --git a/apps/hannk/tflite/CMakeLists.txt b/apps/hannk/tflite/CMakeLists.txt index d56d5db37f34..6d7627506d56 100644 --- a/apps/hannk/tflite/CMakeLists.txt +++ b/apps/hannk/tflite/CMakeLists.txt @@ -17,6 +17,9 @@ set(FLATBUFFERS_BUILD_TESTS OFF) set(FLATBUFFERS_INSTALL OFF) set(FLATBUFFERS_BUILD_FLATC 
OFF) +# Enable this to see details about downloading -- useful for debugging +# set(FETCHCONTENT_QUIET NO) + FetchContent_Declare(tflite GIT_REPOSITORY https://github.com/tensorflow/tensorflow GIT_TAG ${TFLITE_TAG} @@ -33,8 +36,10 @@ if (NOT tflite_POPULATED) set(CMAKE_MESSAGE_LOG_LEVEL ${OLD_CMAKE_MESSAGE_LOG_LEVEL}) endif () - -# tensorflowlite_c is implicitly declared by this FetchContent +# tensorflowlite_c is implicitly declared by this FetchContent. +# Mark it as EXCLUDE_FROM_ALL so that it won't be built unless we actually +# depend on it (which we might not, depending on HANNK_BUILD_TFLITE) +set_property(TARGET tensorflowlite_c PROPERTY EXCLUDE_FROM_ALL TRUE) # Disable some noisy warnings in abseil foreach (LIB IN ITEMS diff --git a/apps/hannk/util/CMakeLists.txt b/apps/hannk/util/CMakeLists.txt index b3030f30e9a1..5cfa4c184b37 100644 --- a/apps/hannk/util/CMakeLists.txt +++ b/apps/hannk/util/CMakeLists.txt @@ -4,12 +4,12 @@ target_include_directories(error_util PUBLIC $) target_link_libraries(error_util PRIVATE Halide::Runtime) -add_library(hannk_log_stderr STATIC +add_library(hannk_log_stderr STATIC EXCLUDE_FROM_ALL hannk_log_stderr.cpp) target_include_directories(hannk_log_stderr PUBLIC $) -add_library(hannk_log_tflite STATIC +add_library(hannk_log_tflite STATIC EXCLUDE_FROM_ALL hannk_log_tflite.cpp) target_link_libraries(hannk_log_tflite PRIVATE tensorflowlite_headers) target_include_directories(hannk_log_tflite PUBLIC @@ -41,12 +41,16 @@ target_compile_definitions(model_runner PRIVATE -DTFLITE_VERSION_MINOR=${TFLITE_VERSION_MINOR} -DTFLITE_VERSION_PATCH=${TFLITE_VERSION_PATCH}) target_link_libraries(model_runner PRIVATE - tensorflowlite_c interpreter error_util file_util - hannk_delegate tflite_parser Halide::Tools # for halide_benchmark.h Halide::Runtime tensorflowlite_headers) + +if (HANNK_BUILD_TFLITE) + target_link_libraries(model_runner PRIVATE + tensorflowlite_c + hannk_delegate) +endif () diff --git a/apps/hannk/util/model_runner.cpp 
b/apps/hannk/util/model_runner.cpp index c40607224d1e..209a67625a74 100644 --- a/apps/hannk/util/model_runner.cpp +++ b/apps/hannk/util/model_runner.cpp @@ -9,7 +9,9 @@ #include "util/model_runner.h" +#if HANNK_BUILD_TFLITE #include "delegate/hannk_delegate.h" +#endif #include "halide_benchmark.h" #include "interpreter/interpreter.h" #include "tflite/tflite_parser.h" @@ -17,9 +19,11 @@ #include "util/error_util.h" #include "util/file_util.h" +#if HANNK_BUILD_TFLITE // IMPORTANT: use only the TFLite C API here. #include "tensorflow/lite/c/c_api.h" #include "tensorflow/lite/c/common.h" +#endif namespace hannk { namespace { @@ -29,6 +33,7 @@ std::chrono::duration bench(std::function f) { return std::chrono::duration(result.wall_time); } +#if HANNK_BUILD_TFLITE halide_type_t tf_lite_type_to_halide_type(TfLiteType t) { switch (t) { case kTfLiteBool: @@ -141,6 +146,7 @@ class DelegatePtr { } } }; +#endif static const char *const RunNames[ModelRunner::kNumRuns] = { "TfLite", @@ -234,6 +240,7 @@ int SeedTracker::seed_for_name(const std::string &name) { return seed_here; } +#if HANNK_BUILD_TFLITE /*static*/ void TfLiteModelRunner::ErrorReporter(void *user_data, const char *format, va_list args) { TfLiteModelRunner *self = (TfLiteModelRunner *)user_data; if (self->verbose_output_) { @@ -340,10 +347,15 @@ TfLiteModelRunner::~TfLiteModelRunner() { TfLiteModelDelete(tf_model_); } } +#endif ModelRunner::ModelRunner() { for (int i = 0; i < kNumRuns; i++) { +#if HANNK_BUILD_TFLITE do_run[i] = true; +#else + do_run[i] = (i == kHannk); +#endif } #if defined(__arm__) || defined(__aarch64__) // TFLite on Arm devices generally uses the rounding-shift instructions, @@ -369,6 +381,7 @@ void ModelRunner::status() { std::cout << "Using random seed: " << seed_tracker_.next_seed() << "\n"; std::cout << "Using threads: " << threads << "\n"; +#if HANNK_BUILD_TFLITE std::string tf_ver = TfLiteVersion(); std::cout << "Using TFLite version: " << tf_ver << "\n"; std::string expected = 
std::to_string(TFLITE_VERSION_MAJOR) + "." + std::to_string(TFLITE_VERSION_MINOR) + "."; @@ -376,6 +389,9 @@ void ModelRunner::status() { std::cerr << "*** WARNING: compare_vs_tflite has been tested against TFLite v" << expected << "x, " << "but is using " << tf_ver << "; results may be inaccurate or wrong.\n"; } +#else + std::cout << "Built without TFLite support.\n"; +#endif } } @@ -437,6 +453,7 @@ ModelRunner::RunResult ModelRunner::run_in_hannk(const std::vector &buffer return result; } +#if HANNK_BUILD_TFLITE ModelRunner::RunResult ModelRunner::run_in_tflite(const std::vector &buffer, TfLiteDelegate *delegate) { RunResult result; @@ -457,6 +474,7 @@ ModelRunner::RunResult ModelRunner::run_in_tflite(const std::vector &buffe return result; } +#endif bool ModelRunner::compare_results(const std::string &name_a, const std::string &name_b, const RunResult &a, const RunResult &b) { bool all_matched = true; @@ -518,18 +536,27 @@ int ModelRunner::parse_flags(int argc, char **argv, std::vector &fi } for (char c : value) { switch (c) { - case 't': - this->do_run[ModelRunner::kTfLite] = true; - break; case 'h': this->do_run[ModelRunner::kHannk] = true; break; +#if HANNK_BUILD_TFLITE + case 't': + this->do_run[ModelRunner::kTfLite] = true; + break; case 'x': this->do_run[ModelRunner::kExternalDelegate] = true; break; case 'i': this->do_run[ModelRunner::kInternalDelegate] = true; break; +#else + case 't': + case 'x': + case 'i': + std::cerr << "Unsupported option to --enable (TFLite is not enabled in this build): " << c << "\n"; + return -1; + break; +#endif default: std::cerr << "Unknown option to --enable: " << c << "\n"; return -1; @@ -623,6 +650,7 @@ void ModelRunner::run(const std::string &filename) { const std::vector buffer = read_entire_file(filename); +#if HANNK_BUILD_TFLITE const auto exec_tflite = [this, &buffer]() { return run_in_tflite(buffer); }; @@ -648,9 +676,18 @@ void ModelRunner::run(const std::string &filename) { {kExternalDelegate, 
exec_hannk_external_delegate}, {kInternalDelegate, exec_hannk_internal_delegate}, }; +#endif for (WhichRun i : active_runs) { +#if HANNK_BUILD_TFLITE results[i] = execs.at(i)(); +#else + if (i != kHannk) { + std::cerr << "Only kHannk is available in this build.\n"; + exit(1); + } + results[i] = run_in_hannk(buffer); +#endif } // ----- Log benchmark times diff --git a/apps/hannk/util/model_runner.h b/apps/hannk/util/model_runner.h index e3ef8e77367b..b98fd40bd890 100644 --- a/apps/hannk/util/model_runner.h +++ b/apps/hannk/util/model_runner.h @@ -10,10 +10,12 @@ #include "util/buffer_util.h" +#if HANNK_BUILD_TFLITE struct TfLiteDelegate; struct TfLiteInterpreter; struct TfLiteInterpreterOptions; struct TfLiteModel; +#endif namespace hannk { @@ -60,6 +62,7 @@ struct SeedTracker { SeedTracker &operator=(SeedTracker &&) = delete; }; +#if HANNK_BUILD_TFLITE class TfLiteModelRunner { TfLiteModel *tf_model_ = nullptr; TfLiteInterpreterOptions *tf_options_ = nullptr; @@ -84,6 +87,7 @@ class TfLiteModelRunner { TfLiteModelRunner(TfLiteModelRunner &&) = delete; TfLiteModelRunner &operator=(TfLiteModelRunner &&) = delete; }; +#endif // TODO: add a way to bottleneck stdout/stdout, or just errors/warnings in general struct ModelRunner { @@ -130,7 +134,9 @@ struct ModelRunner { std::chrono::duration time{0}; }; RunResult run_in_hannk(const std::vector &buffer); +#if HANNK_BUILD_TFLITE RunResult run_in_tflite(const std::vector &buffer, TfLiteDelegate *delegate = nullptr); +#endif bool compare_results(const std::string &name_a, const std::string &name_b, const RunResult &a, const RunResult &b); };