diff --git a/apps/hannk/CMakeLists.txt b/apps/hannk/CMakeLists.txt
index c93a1058e090..0bce2f9a81e7 100644
--- a/apps/hannk/CMakeLists.txt
+++ b/apps/hannk/CMakeLists.txt
@@ -6,7 +6,7 @@ set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
 
 enable_testing()
 
-option(HANNK_BUILD_TFLITE "Build TFLite Delegate for HANNK" ON)
+option(HANNK_BUILD_TFLITE "Build TFLite+Delegate for HANNK" ON)
 if (HANNK_BUILD_TFLITE AND (Halide_TARGET MATCHES "wasm"))
     message(FATAL_ERROR "HANNK_BUILD_TFLITE must be OFF when targeting wasm")
 endif()
diff --git a/apps/hannk/delegate/hannk_delegate.h b/apps/hannk/delegate/hannk_delegate.h
index 877ec8053b33..2971846720c1 100644
--- a/apps/hannk/delegate/hannk_delegate.h
+++ b/apps/hannk/delegate/hannk_delegate.h
@@ -2,7 +2,7 @@
 #define HANNK_DELEGATE_H
 
 #if !HANNK_BUILD_TFLITE
-    #error "This file should not be included when HANNK_BUILD_TFLITE=0"
+#error "This file should not be included when HANNK_BUILD_TFLITE=0"
 #endif
 
 #include "tensorflow/lite/c/c_api.h"
diff --git a/dependencies/wasm/CMakeLists.txt b/dependencies/wasm/CMakeLists.txt
index d6800f83c9b0..838ba0274a76 100644
--- a/dependencies/wasm/CMakeLists.txt
+++ b/dependencies/wasm/CMakeLists.txt
@@ -91,9 +91,8 @@ function(add_wasm_executable TARGET)
         -Wsuggest-override
         -s ASSERTIONS=1
         -s ALLOW_MEMORY_GROWTH=1
-        -s WASM_BIGINT=1
-        -s STANDALONE_WASM=1
-        -s ENVIRONMENT=node)
+        -s ENVIRONMENT=node
+        )
 
     set(SRCS)
     foreach (S IN LISTS args_SRCS)
diff --git a/src/CodeGen_ARM.cpp b/src/CodeGen_ARM.cpp
index 0c6208cbb7dd..02b2ebac4bf2 100644
--- a/src/CodeGen_ARM.cpp
+++ b/src/CodeGen_ARM.cpp
@@ -580,15 +580,19 @@ const ArmIntrinsic intrinsic_defs[] = {
     {nullptr, "udot.v4i32.v16i8", UInt(32, 4), "dot_product", {UInt(32, 4), UInt(8, 16), UInt(8, 16)}, ArmIntrinsic::NoMangle},
 
     // ABDL - Widening absolute difference
-    // Need to be able to handle both signed and unsigned outputs for signed inputs.
+    // The ARM backend folds both signed and unsigned widening casts of absd to a widening_absd, so we need to handle both signed and
+    // unsigned input and return types.
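(A minimal sketch, not part of this patch, of a pipeline that exercises the new unsigned-input entries: a *signed* widening cast of an unsigned `absd` is folded to a `widening_absd` with `UInt(8)` arguments and an `Int(16)` result, which now matches the `vabdl_u8x8` entry below. The output filename and target string are illustrative.)

#include "Halide.h"
using namespace Halide;

int main() {
    ImageParam a(UInt(8), 1), b(UInt(8), 1);
    Func f;
    Var x;
    // Signed widening of an unsigned absolute difference.
    f(x) = cast<int16_t>(absd(a(x), b(x)));
    f.vectorize(x, 8);
    // Cross-compile so this sketch runs on any host; inspect absd.s for vabdl.
    f.compile_to_assembly("absd.s", {a, b}, Target("arm-64-linux"));
    return 0;
}
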
     {"vabdl_i8x8", "vabdl_i8x8", Int(16, 8), "widening_absd", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
     {"vabdl_i8x8", "vabdl_i8x8", UInt(16, 8), "widening_absd", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
+    {"vabdl_u8x8", "vabdl_u8x8", Int(16, 8), "widening_absd", {UInt(8, 8), UInt(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
     {"vabdl_u8x8", "vabdl_u8x8", UInt(16, 8), "widening_absd", {UInt(8, 8), UInt(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
     {"vabdl_i16x4", "vabdl_i16x4", Int(32, 4), "widening_absd", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
     {"vabdl_i16x4", "vabdl_i16x4", UInt(32, 4), "widening_absd", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
+    {"vabdl_u16x4", "vabdl_u16x4", Int(32, 4), "widening_absd", {UInt(16, 4), UInt(16, 4)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
     {"vabdl_u16x4", "vabdl_u16x4", UInt(32, 4), "widening_absd", {UInt(16, 4), UInt(16, 4)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
     {"vabdl_i32x2", "vabdl_i32x2", Int(64, 2), "widening_absd", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
     {"vabdl_i32x2", "vabdl_i32x2", UInt(64, 2), "widening_absd", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
+    {"vabdl_u32x2", "vabdl_u32x2", Int(64, 2), "widening_absd", {UInt(32, 2), UInt(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
     {"vabdl_u32x2", "vabdl_u32x2", UInt(64, 2), "widening_absd", {UInt(32, 2), UInt(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
 };
diff --git a/src/JITModule.cpp b/src/JITModule.cpp
index 0033cb212c16..acb8be5da8c7 100644
--- a/src/JITModule.cpp
+++ b/src/JITModule.cpp
@@ -506,63 +506,72 @@ void merge_handlers(JITHandlers &base, const JITHandlers &addins) {
     if (addins.custom_get_library_symbol) {
         base.custom_get_library_symbol = addins.custom_get_library_symbol;
     }
+    if (addins.custom_cuda_acquire_context) {
+        base.custom_cuda_acquire_context = addins.custom_cuda_acquire_context;
+    }
+    if (addins.custom_cuda_release_context) {
+        base.custom_cuda_release_context = addins.custom_cuda_release_context;
+    }
+    if (addins.custom_cuda_get_stream) {
+        base.custom_cuda_get_stream = addins.custom_cuda_get_stream;
+    }
 }
 
 void print_handler(JITUserContext *context, const char *msg) {
-    if (context) {
-        (*context->handlers.custom_print)(context, msg);
+    if (context && context->handlers.custom_print) {
+        context->handlers.custom_print(context, msg);
     } else {
-        return (*active_handlers.custom_print)(context, msg);
+        return active_handlers.custom_print(context, msg);
     }
 }
 
 void *malloc_handler(JITUserContext *context, size_t x) {
-    if (context) {
-        return (*context->handlers.custom_malloc)(context, x);
+    if (context && context->handlers.custom_malloc) {
+        return context->handlers.custom_malloc(context, x);
     } else {
-        return (*active_handlers.custom_malloc)(context, x);
+        return active_handlers.custom_malloc(context, x);
     }
 }
 
 void free_handler(JITUserContext *context, void *ptr) {
-    if (context) {
-        (*context->handlers.custom_free)(context, ptr);
+    if (context && context->handlers.custom_free) {
+        context->handlers.custom_free(context, ptr);
     } else {
-        (*active_handlers.custom_free)(context, ptr);
+        active_handlers.custom_free(context, ptr);
     }
 }
 int do_task_handler(JITUserContext *context, int (*f)(JITUserContext *, int, uint8_t *),
                     int idx, uint8_t *closure) {
-    if (context) {
-        return (*context->handlers.custom_do_task)(context, f, idx, closure);
+    if (context && context->handlers.custom_do_task) {
+        return context->handlers.custom_do_task(context, f, idx, closure);
     } else {
-        return (*active_handlers.custom_do_task)(context, f, idx, closure);
+        return active_handlers.custom_do_task(context, f, idx, closure);
     }
 }
 
 int do_par_for_handler(JITUserContext *context, int (*f)(JITUserContext *, int, uint8_t *),
                        int min, int size, uint8_t *closure) {
-    if (context) {
-        return (*context->handlers.custom_do_par_for)(context, f, min, size, closure);
+    if (context && context->handlers.custom_do_par_for) {
+        return context->handlers.custom_do_par_for(context, f, min, size, closure);
     } else {
-        return (*active_handlers.custom_do_par_for)(context, f, min, size, closure);
+        return active_handlers.custom_do_par_for(context, f, min, size, closure);
     }
 }
 
 void error_handler_handler(JITUserContext *context, const char *msg) {
-    if (context) {
-        (*context->handlers.custom_error)(context, msg);
+    if (context && context->handlers.custom_error) {
+        context->handlers.custom_error(context, msg);
     } else {
-        (*active_handlers.custom_error)(context, msg);
+        active_handlers.custom_error(context, msg);
     }
 }
 
 int32_t trace_handler(JITUserContext *context, const halide_trace_event_t *e) {
-    if (context) {
-        return (*context->handlers.custom_trace)(context, e);
+    if (context && context->handlers.custom_trace) {
+        return context->handlers.custom_trace(context, e);
     } else {
-        return (*active_handlers.custom_trace)(context, e);
+        return active_handlers.custom_trace(context, e);
     }
 }
 
@@ -578,6 +587,30 @@ void *get_library_symbol_handler(void *lib, const char *name) {
     return (*active_handlers.custom_get_library_symbol)(lib, name);
 }
 
+int cuda_acquire_context_handler(JITUserContext *context, void **cuda_context_ptr, bool create) {
+    if (context && context->handlers.custom_cuda_acquire_context) {
+        return context->handlers.custom_cuda_acquire_context(context, cuda_context_ptr, create);
+    } else {
+        return active_handlers.custom_cuda_acquire_context(context, cuda_context_ptr, create);
+    }
+}
+
+int cuda_release_context_handler(JITUserContext *context) {
+    if (context && context->handlers.custom_cuda_release_context) {
+        return context->handlers.custom_cuda_release_context(context);
+    } else {
+        return active_handlers.custom_cuda_release_context(context);
+    }
+}
+
+int cuda_get_stream_handler(JITUserContext *context, void *cuda_context, void **cuda_stream_ptr) {
+    if (context && context->handlers.custom_cuda_get_stream) {
+        return context->handlers.custom_cuda_get_stream(context, cuda_context, cuda_stream_ptr);
+    } else {
+        return active_handlers.custom_cuda_get_stream(context, cuda_context, cuda_stream_ptr);
+    }
+}
+
 template<typename function_t>
 function_t hook_function(const std::map<std::string, JITModule::Symbol> &exports, const char *hook_name, function_t hook) {
     auto iter = exports.find(hook_name);
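(All of these trampolines follow one dispatch pattern: prefer a handler installed on the caller's JITUserContext, else fall back to the process-wide active_handlers entry. A minimal sketch, not part of this patch, of what that enables from user code; `f` is a hypothetical Func and `custom_error` is an existing JITHandlers field:)

#include "Halide.h"
#include <cstdio>

// Sketch: a per-call override. The handler installed on this context wins
// over the global default, but only for calls that are passed this context.
void realize_with_error_capture(Halide::Func f) {
    Halide::JITUserContext ctx;
    ctx.handlers.custom_error = [](Halide::JITUserContext *, const char *msg) {
        fprintf(stderr, "pipeline error: %s\n", msg);
    };
    f.realize(&ctx, {64});  // errors from this call go to the lambda above
}
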
@@ -776,13 +809,13 @@ JITModule &make_module(llvm::Module *for_module, Target target,
             hook_function(runtime.exports(), "halide_set_custom_trace", trace_handler);
 
         runtime_internal_handlers.custom_get_symbol =
-            hook_function(shared_runtimes(MainShared).exports(), "halide_set_custom_get_symbol", get_symbol_handler);
+            hook_function(runtime.exports(), "halide_set_custom_get_symbol", get_symbol_handler);
 
         runtime_internal_handlers.custom_load_library =
-            hook_function(shared_runtimes(MainShared).exports(), "halide_set_custom_load_library", load_library_handler);
+            hook_function(runtime.exports(), "halide_set_custom_load_library", load_library_handler);
 
         runtime_internal_handlers.custom_get_library_symbol =
-            hook_function(shared_runtimes(MainShared).exports(), "halide_set_custom_get_library_symbol", get_library_symbol_handler);
+            hook_function(runtime.exports(), "halide_set_custom_get_library_symbol", get_library_symbol_handler);
 
         active_handlers = runtime_internal_handlers;
         merge_handlers(active_handlers, default_handlers);
@@ -794,6 +827,41 @@ JITModule &make_module(llvm::Module *for_module, Target target,
             runtime.jit_module->name = "MainShared";
         } else {
             runtime.jit_module->name = "GPU";
+
+            // There are two versions of these cuda context
+            // management handlers we could use - one in the cuda
+            // module, and one in the cuda-debug module. If both
+            // modules are in use, we'll just want to use one of
+            // them, so that we don't needlessly create two cuda
+            // contexts. We'll use whichever was created
+            // first. The second one will then declare a
+            // dependency on the first one, to make sure things
+            // are destroyed in the correct order.
+
+            if (runtime_kind == CUDA || runtime_kind == CUDADebug) {
+                if (!runtime_internal_handlers.custom_cuda_acquire_context) {
+                    // Neither module has been created.
+                    runtime_internal_handlers.custom_cuda_acquire_context =
+                        hook_function(runtime.exports(), "halide_set_cuda_acquire_context", cuda_acquire_context_handler);
+
+                    runtime_internal_handlers.custom_cuda_release_context =
+                        hook_function(runtime.exports(), "halide_set_cuda_release_context", cuda_release_context_handler);
+
+                    runtime_internal_handlers.custom_cuda_get_stream =
+                        hook_function(runtime.exports(), "halide_set_cuda_get_stream", cuda_get_stream_handler);
+
+                    active_handlers = runtime_internal_handlers;
+                    merge_handlers(active_handlers, default_handlers);
+                } else if (runtime_kind == CUDA) {
+                    // The CUDADebug module has already been created.
+                    // Use the context in the CUDADebug module and add
+                    // a dependence edge from the CUDA module to it.
+                    runtime.add_dependency(shared_runtimes(CUDADebug));
+                } else {
+                    // The CUDA module has already been created.
+                    runtime.add_dependency(shared_runtimes(CUDA));
+                }
+            }
         }
 
         uint64_t arg_addr =
diff --git a/src/JITModule.h b/src/JITModule.h
index 5208d31698ac..ee27fe99216b 100644
--- a/src/JITModule.h
+++ b/src/JITModule.h
@@ -106,6 +106,22 @@ struct JITHandlers {
      * an opened library. Equivalent to dlsym. Takes a handle
      * returned by custom_load_library as the first argument. */
     void *(*custom_get_library_symbol)(void *lib, const char *name){nullptr};
+
+    /** A custom method for the Halide runtime to acquire a cuda
+     * context. The cuda context is treated as a void * to avoid a
+     * dependence on the cuda headers. If the create argument is set
+     * to true, a context should be created if one does not already
+     * exist. */
+    int32_t (*custom_cuda_acquire_context)(JITUserContext *user_context, void **cuda_context_ptr, bool create){nullptr};
+
+    /** The Halide runtime calls this when it is done with a cuda
+     * context. The default implementation does nothing. */
+    int32_t (*custom_cuda_release_context)(JITUserContext *user_context){nullptr};
+
+    /** A custom method for the Halide runtime to acquire a cuda
+     * stream to use. The cuda context and stream are both modelled
+     * as a void *, to avoid a dependence on the cuda headers. */
+    int32_t (*custom_cuda_get_stream)(JITUserContext *user_context, void *cuda_context, void **stream_ptr){nullptr};
 };
 
 namespace Internal {
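(A condensed sketch of wiring these three new fields through a JITUserContext; the new custom_cuda_context test at the bottom of this diff does the same with a real context and stream. `my_ctx` and `my_stream` are assumed to be a valid CUcontext/CUstream created by the application, stored as void * to match the handler signatures.)

#include "Halide.h"

struct MyCudaContext : Halide::JITUserContext {
    void *my_ctx = nullptr, *my_stream = nullptr;

    MyCudaContext() {
        // Hand Halide the application's context instead of letting the
        // runtime create its own.
        handlers.custom_cuda_acquire_context =
            [](Halide::JITUserContext *ctx, void **out, bool /*create*/) -> int32_t {
                *out = ((MyCudaContext *)ctx)->my_ctx;
                return 0;
            };
        // The application owns the context, so release is a no-op.
        handlers.custom_cuda_release_context =
            [](Halide::JITUserContext *) -> int32_t { return 0; };
        // Run kernels and copies on the application's stream.
        handlers.custom_cuda_get_stream =
            [](Halide::JITUserContext *ctx, void * /*cuda_ctx*/, void **out) -> int32_t {
                *out = ((MyCudaContext *)ctx)->my_stream;
                return 0;
            };
    }
};
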
diff --git a/src/Pipeline.cpp b/src/Pipeline.cpp
index 06fa9eb34003..eebeb86dea80 100644
--- a/src/Pipeline.cpp
+++ b/src/Pipeline.cpp
@@ -757,7 +757,7 @@ Realization Pipeline::realize(JITUserContext *context,
         if (needs_crop) {
             r[i].crop(crop);
         }
-        r[i].copy_to_host();
+        r[i].copy_to_host(context);
     }
     return r;
 }
diff --git a/src/runtime/HalideRuntimeCuda.h b/src/runtime/HalideRuntimeCuda.h
index 3eff7d834543..239f6e698561 100644
--- a/src/runtime/HalideRuntimeCuda.h
+++ b/src/runtime/HalideRuntimeCuda.h
@@ -65,6 +65,23 @@ extern uintptr_t halide_cuda_get_device_ptr(void *user_context, struct halide_bu
  * driver. See halide_reuse_device_allocations. */
 extern int halide_cuda_release_unused_device_allocations(void *user_context);
 
+// These typedefs treat both a CUcontext and a CUstream as a void *,
+// to avoid dependencies on cuda headers.
+typedef int (*halide_cuda_acquire_context_t)(void *,   // user_context
+                                             void **,  // cuda context out parameter
+                                             bool);    // should create a context if none exists
+typedef int (*halide_cuda_release_context_t)(void * /* user_context */);
+typedef int (*halide_cuda_get_stream_t)(void *,    // user_context
+                                        void *,    // context
+                                        void **);  // stream out parameter
+
+/** Set custom methods to acquire and release cuda contexts and streams */
+// @{
+extern halide_cuda_acquire_context_t halide_set_cuda_acquire_context(halide_cuda_acquire_context_t handler);
+extern halide_cuda_release_context_t halide_set_cuda_release_context(halide_cuda_release_context_t handler);
+extern halide_cuda_get_stream_t halide_set_cuda_get_stream(halide_cuda_get_stream_t handler);
+// @}
+
 #ifdef __cplusplus
 }  // End extern "C"
 #endif
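(For AOT-compiled pipelines there is no JITUserContext, so the override is installed process-wide through these new setters. A sketch, assuming the application owns a context and stream stored in the hypothetical globals `app_cuda_context` and `app_cuda_stream`:)

#include "HalideRuntimeCuda.h"

extern void *app_cuda_context, *app_cuda_stream;  // hypothetical, owned by the app

static int my_acquire(void * /*user_context*/, void **ctx, bool /*create*/) {
    *ctx = app_cuda_context;  // never create; always hand back the app's context
    return 0;
}
static int my_release(void * /*user_context*/) {
    return 0;  // the application owns the context; nothing to do
}
static int my_get_stream(void * /*user_context*/, void * /*ctx*/, void **stream) {
    *stream = app_cuda_stream;
    return 0;
}

void install_cuda_overrides() {
    // Each setter returns the previously installed handler (the default),
    // which could be saved to chain or restore later.
    halide_set_cuda_acquire_context(my_acquire);
    halide_set_cuda_release_context(my_release);
    halide_set_cuda_get_stream(my_get_stream);
}
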
diff --git a/src/runtime/cuda.cpp b/src/runtime/cuda.cpp
index 11cc906bf316..eee52c1e4804 100644
--- a/src/runtime/cuda.cpp
+++ b/src/runtime/cuda.cpp
@@ -136,7 +136,7 @@ extern "C" {
 // - A call to halide_cuda_acquire_context is followed by a matching call to
 //   halide_cuda_release_context. halide_cuda_acquire_context should block while a
 //   previous call (if any) has not yet been released via halide_cuda_release_context.
-WEAK int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create = true) {
+WEAK int halide_default_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create = true) {
     // TODO: Should we use a more "assertive" assert? these asserts do
     // not block execution on failure.
     halide_assert(user_context, ctx != nullptr);
@@ -179,7 +179,7 @@ WEAK int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool cr
     return 0;
 }
 
-WEAK int halide_cuda_release_context(void *user_context) {
+WEAK int halide_default_cuda_release_context(void *user_context) {
     return 0;
 }
 
@@ -188,7 +188,7 @@ WEAK int halide_cuda_release_context(void *user_context) {
 // for the context (nullptr stream). The context is passed in for convenience, but
 // any sort of scoping must be handled by that of the
 // halide_cuda_acquire_context/halide_cuda_release_context pair, not this call.
-WEAK int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream) {
+WEAK int halide_default_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream) {
     // There are two default streams we could use. stream 0 is fully
     // synchronous. stream 2 gives a separate non-blocking stream per
     // thread.
@@ -198,6 +198,53 @@ WEAK int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *str
 
 }  // extern "C"
 
+namespace Halide {
+namespace Runtime {
+namespace Internal {
+namespace CUDA {
+
+WEAK halide_cuda_acquire_context_t acquire_context = (halide_cuda_acquire_context_t)halide_default_cuda_acquire_context;
+WEAK halide_cuda_release_context_t release_context = (halide_cuda_release_context_t)halide_default_cuda_release_context;
+WEAK halide_cuda_get_stream_t get_stream = (halide_cuda_get_stream_t)halide_default_cuda_get_stream;
+
+}  // namespace CUDA
+}  // namespace Internal
+}  // namespace Runtime
+}  // namespace Halide
+
+extern "C" {
+
+WEAK int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create = true) {
+    return CUDA::acquire_context(user_context, (void **)ctx, create);
+}
+
+WEAK halide_cuda_acquire_context_t halide_set_cuda_acquire_context(halide_cuda_acquire_context_t handler) {
+    halide_cuda_acquire_context_t result = CUDA::acquire_context;
+    CUDA::acquire_context = handler;
+    return result;
+}
+
+WEAK int halide_cuda_release_context(void *user_context) {
+    return CUDA::release_context(user_context);
+}
+
+WEAK halide_cuda_release_context_t halide_set_cuda_release_context(halide_cuda_release_context_t handler) {
+    halide_cuda_release_context_t result = CUDA::release_context;
+    CUDA::release_context = handler;
+    return result;
+}
+
+WEAK int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream) {
+    return CUDA::get_stream(user_context, (void *)ctx, (void **)stream);
+}
+
+WEAK halide_cuda_get_stream_t halide_set_cuda_get_stream(halide_cuda_get_stream_t handler) {
+    halide_cuda_get_stream_t result = CUDA::get_stream;
+    CUDA::get_stream = handler;
+    return result;
+}
+}
+
 namespace Halide {
 namespace Runtime {
 namespace Internal {
@@ -845,7 +892,8 @@ WEAK int halide_cuda_device_malloc(void *user_context, halide_buffer_t *buf) {
 namespace {
 WEAK int cuda_do_multidimensional_copy(void *user_context, const device_copy &c,
-                                       uint64_t src, uint64_t dst, int d, bool from_host, bool to_host) {
+                                       uint64_t src, uint64_t dst, int d, bool from_host, bool to_host,
+                                       CUstream stream) {
     if (d > MAX_COPY_DIMS) {
         error(user_context) << "Buffer has too many dimensions to copy to/from GPU\n";
         return -1;
@@ -858,15 +906,27 @@ WEAK int cuda_do_multidimensional_copy(void *user_context, const device_copy &c,
         if (!from_host && to_host) {
             debug(user_context) << "cuMemcpyDtoH(" << (void *)dst << ", " << (void *)src << ", " << c.chunk_size << ")\n";
             copy_name = "cuMemcpyDtoH";
-            err = cuMemcpyDtoH((void *)dst, (CUdeviceptr)src, c.chunk_size);
+            if (stream) {
+                err = cuMemcpyDtoHAsync((void *)dst, (CUdeviceptr)src, c.chunk_size, stream);
+            } else {
+                err = cuMemcpyDtoH((void *)dst, (CUdeviceptr)src, c.chunk_size);
+            }
         } else if (from_host && !to_host) {
             debug(user_context) << "cuMemcpyHtoD(" << (void *)dst << ", " << (void *)src << ", " << c.chunk_size << ")\n";
             copy_name = "cuMemcpyHtoD";
-            err = cuMemcpyHtoD((CUdeviceptr)dst, (void *)src, c.chunk_size);
+            if (stream) {
+                err = cuMemcpyHtoDAsync((CUdeviceptr)dst, (void *)src, c.chunk_size, stream);
+            } else {
+                err = cuMemcpyHtoD((CUdeviceptr)dst, (void *)src, c.chunk_size);
+            }
         } else if (!from_host && !to_host) {
             debug(user_context) << "cuMemcpyDtoD(" << (void *)dst << ", " << (void *)src << ", " << c.chunk_size << ")\n";
             copy_name = "cuMemcpyDtoD";
-            err = cuMemcpyDtoD((CUdeviceptr)dst, (CUdeviceptr)src, c.chunk_size);
+            if (stream) {
+                err = cuMemcpyDtoDAsync((CUdeviceptr)dst, (CUdeviceptr)src, c.chunk_size, stream);
+            } else {
+                err = cuMemcpyDtoD((CUdeviceptr)dst, (CUdeviceptr)src, c.chunk_size);
+            }
         } else if (dst != src) {
             debug(user_context) << "memcpy(" << (void *)dst << ", " << (void *)src << ", " << c.chunk_size << ")\n";
             // Could reach here if a user called directly into the
@@ -881,7 +941,7 @@ WEAK int cuda_do_multidimensional_copy(void *user_context, const device_copy &c,
     } else {
         ssize_t src_off = 0, dst_off = 0;
         for (int i = 0; i < (int)c.extent[d - 1]; i++) {
-            int err = cuda_do_multidimensional_copy(user_context, c, src + src_off, dst + dst_off, d - 1, from_host, to_host);
+            int err = cuda_do_multidimensional_copy(user_context, c, src + src_off, dst + dst_off, d - 1, from_host, to_host, stream);
             dst_off += c.dst_stride_bytes[d - 1];
             src_off += c.src_stride_bytes[d - 1];
             if (err) {
@@ -938,7 +998,15 @@ WEAK int halide_cuda_buffer_copy(void *user_context, struct halide_buffer_t *src
     }
 #endif
 
-    err = cuda_do_multidimensional_copy(user_context, c, c.src + c.src_begin, c.dst, dst->dimensions, from_host, to_host);
+    CUstream stream = nullptr;
+    if (cuStreamSynchronize != nullptr) {
+        int result = halide_cuda_get_stream(user_context, ctx.context, &stream);
+        if (result != 0) {
+            error(user_context) << "CUDA: In cuda_do_multidimensional_copy, halide_cuda_get_stream returned " << result << "\n";
+        }
+    }
+
+    err = cuda_do_multidimensional_copy(user_context, c, c.src + c.src_begin, c.dst, dst->dimensions, from_host, to_host, stream);
 
 #ifdef DEBUG_RUNTIME
     uint64_t t_after = halide_current_time_ns(user_context);
diff --git a/src/runtime/cuda_functions.h b/src/runtime/cuda_functions.h
index 61614ea89805..5242e931686a 100644
--- a/src/runtime/cuda_functions.h
+++ b/src/runtime/cuda_functions.h
@@ -36,6 +36,11 @@ CUDA_FN_3020(CUresult, cuMemFree, cuMemFree_v2, (CUdeviceptr dptr));
 CUDA_FN_3020(CUresult, cuMemcpyHtoD, cuMemcpyHtoD_v2, (CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount));
 CUDA_FN_3020(CUresult, cuMemcpyDtoH, cuMemcpyDtoH_v2, (void *dstHost, CUdeviceptr srcDevice, size_t ByteCount));
 CUDA_FN_3020(CUresult, cuMemcpyDtoD, cuMemcpyDtoD_v2, (CUdeviceptr dstHost, CUdeviceptr srcDevice, size_t ByteCount));
+
+CUDA_FN_3020(CUresult, cuMemcpyHtoDAsync, cuMemcpyHtoDAsync_v2, (CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream stream));
+CUDA_FN_3020(CUresult, cuMemcpyDtoHAsync, cuMemcpyDtoHAsync_v2, (void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream stream));
+CUDA_FN_3020(CUresult, cuMemcpyDtoDAsync, cuMemcpyDtoDAsync_v2, (CUdeviceptr dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream stream));
+
 CUDA_FN_3020(CUresult, cuMemcpy3D, cuMemcpy3D_v2, (const CUDA_MEMCPY3D *pCopy));
 CUDA_FN(CUresult, cuLaunchKernel, (CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra));
 CUDA_FN(CUresult, cuCtxSynchronize, ());
diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt
index bb4bb8908f8e..0cf9332376ad 100644
--- a/test/correctness/CMakeLists.txt
+++ b/test/correctness/CMakeLists.txt
@@ -66,6 +66,7 @@ tests(GROUPS correctness
       cuda_8_bit_dot_product.cpp
       custom_allocator.cpp
       custom_auto_scheduler.cpp
+      custom_cuda_context.cpp
       custom_error_reporter.cpp
       custom_jit_context.cpp
       custom_lowering_pass.cpp
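(Because buffer copies now run on the stream returned by halide_cuda_get_stream, host readback must be handed the same context so the copy is ordered after the kernels that wrote the data; this is why Pipeline.cpp above now forwards `context` to copy_to_host. A sketch reusing names from the test below:)

#include "Halide.h"

// Sketch: realize and read back on the caller's custom stream. Passing no
// context would issue the device->host copy on the default stream, which
// could race the compute on a non-blocking custom stream.
Halide::Buffer<float> realize_and_read_back(Halide::Func g,
                                            Halide::JITUserContext *state,
                                            int width, int height) {
    Halide::Buffer<float> out = g.realize(state, {width, height});
    out.copy_to_host(state);  // ordered after the kernels on the same stream
    return out;
}
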
diff --git a/test/correctness/custom_cuda_context.cpp b/test/correctness/custom_cuda_context.cpp
new file mode 100644
index 000000000000..d9aaa76c96cd
--- /dev/null
+++ b/test/correctness/custom_cuda_context.cpp
@@ -0,0 +1,178 @@
+#include "Halide.h"
+
+using namespace Halide;
+
+int (*cuStreamCreate)(void **, uint32_t) = nullptr;
+int (*cuCtxCreate)(void **, uint32_t, int) = nullptr;
+int (*cuCtxDestroy)(void *) = nullptr;
+int (*cuMemAlloc)(void **, size_t) = nullptr;
+int (*cuMemFree)(void *) = nullptr;
+int (*cuCtxSetCurrent)(void *) = nullptr;
+
+struct CudaState : public Halide::JITUserContext {
+    void *cuda_context = nullptr, *cuda_stream = nullptr;
+    std::atomic<int> acquires{0}, releases{0};
+
+    static int my_cuda_acquire_context(JITUserContext *ctx, void **cuda_ctx, bool create) {
+        CudaState *state = (CudaState *)ctx;
+        *cuda_ctx = state->cuda_context;
+        state->acquires++;
+        return 0;
+    }
+
+    static int my_cuda_release_context(JITUserContext *ctx) {
+        CudaState *state = (CudaState *)ctx;
+        state->releases++;
+        return 0;
+    }
+
+    static int my_cuda_get_stream(JITUserContext *ctx, void *cuda_ctx, void **stream) {
+        CudaState *state = (CudaState *)ctx;
+        *stream = state->cuda_stream;
+        return 0;
+    }
+
+    CudaState() {
+        handlers.custom_cuda_acquire_context = my_cuda_acquire_context;
+        handlers.custom_cuda_release_context = my_cuda_release_context;
+        handlers.custom_cuda_get_stream = my_cuda_get_stream;
+    }
+};
+
+int main(int argc, char **argv) {
+    Target target = get_jit_target_from_environment();
+    if (!target.has_feature(Target::CUDA)) {
+        printf("[SKIP] CUDA not enabled.\n");
+        return 0;
+    }
+
+    {
+        // Do some nonsense to get symbols out of libcuda without
+        // needing the CUDA sdk. This would not be a concern in a real
+        // cuda-using application, but is helpful for our
+        // build-and-test infrastructure.
+
+        // We'll find the cuda module in the Halide runtime so
+        // that we can use it to resolve symbols into libcuda in a
+        // portable way.
+
+        // Force-initialize the cuda runtime module by running something trivial.
+        evaluate_may_gpu<float>(Expr(0.f));
+
+        // Go get it, and dig out the method used to resolve symbols in libcuda.
+        auto runtime_modules = Internal::JITSharedRuntime::get(nullptr, target, false);
+        void *(*halide_cuda_get_symbol)(void *, const char *) = nullptr;
+        for (Internal::JITModule &m : runtime_modules) {
+            // Just rifle through all the runtime modules for this
+            // target until we find the method we want.
+            auto sym = m.find_symbol_by_name("halide_cuda_get_symbol");
+            if (sym.address != nullptr) {
+                halide_cuda_get_symbol = (decltype(halide_cuda_get_symbol))sym.address;
+                break;
+            }
+        }
+
+        if (halide_cuda_get_symbol == nullptr) {
+            printf("Failed to extract halide_cuda_get_symbol from Halide cuda runtime\n");
+            return -1;
+        }
+
+        // Go get the CUDA API functions we actually intend to use.
+        cuStreamCreate = (decltype(cuStreamCreate))halide_cuda_get_symbol(nullptr, "cuStreamCreate");
+        cuCtxCreate = (decltype(cuCtxCreate))halide_cuda_get_symbol(nullptr, "cuCtxCreate_v2");
+        cuCtxDestroy = (decltype(cuCtxDestroy))halide_cuda_get_symbol(nullptr, "cuCtxDestroy_v2");
+        cuCtxSetCurrent = (decltype(cuCtxSetCurrent))halide_cuda_get_symbol(nullptr, "cuCtxSetCurrent");
+        cuMemAlloc = (decltype(cuMemAlloc))halide_cuda_get_symbol(nullptr, "cuMemAlloc_v2");
+        cuMemFree = (decltype(cuMemFree))halide_cuda_get_symbol(nullptr, "cuMemFree_v2");
+
+        if (cuStreamCreate == nullptr ||
+            cuCtxCreate == nullptr ||
+            cuCtxDestroy == nullptr ||
+            cuCtxSetCurrent == nullptr ||
+            cuMemAlloc == nullptr ||
+            cuMemFree == nullptr) {
+            printf("Failed to find cuda API\n");
+            return -1;
+        }
+    }
+
+    // Make a cuda context and stream.
+    CudaState state;
+    int err = cuCtxCreate(&state.cuda_context, 0, 0);
+    if (state.cuda_context == nullptr) {
+        printf("Failed to initialize context: %d\n", err);
+        return -1;
+    }
+
+    err = cuCtxSetCurrent(state.cuda_context);
+    if (err) {
+        printf("Failed to set context: %d\n", err);
+        return -1;
+    }
+
+    err = cuStreamCreate(&state.cuda_stream, 1 /* non-blocking */);
+    if (state.cuda_stream == nullptr) {
+        printf("Failed to initialize stream: %d\n", err);
+        return -1;
+    }
+
+    // Allocate some GPU memory on this context
+    const int width = 32, height = 1024;
+
+    void *ptr = nullptr;
+    err = cuMemAlloc(&ptr, width * height * sizeof(float));
+
+    if (ptr == nullptr) {
+        printf("cuMemAlloc failed: %d\n", err);
+        return -1;
+    }
+
+    // Wrap a Halide buffer around it, with some host memory too.
+    Buffer<float> in(width, height);
+    in.fill(4.0f);
+    auto device_interface = get_device_interface_for_device_api(DeviceAPI::CUDA);
+    in.device_wrap_native(device_interface,
+                          (uintptr_t)ptr, &state);
+    in.copy_to_device(device_interface, &state);
+
+    // Run a kernel on multiple threads that copies slices of it into
+    // a Halide-allocated temporary buffer. This would likely crash
+    // if we don't allocate the outputs on the right context. If the
+    // copies don't happen on the same stream as the compute, we'll
+    // get incorrect outputs due to race conditions.
+    Func f, g;
+    Var x, xi, y;
+    f(x, y) = sqrt(in(x, y));
+    g(x, y) = f(x, y);
+    f.gpu_tile(x, x, xi, 32).compute_at(g, y);
+    g.parallel(y);
+
+    for (int i = 0; i < 10; i++) {
+        Buffer<float> out = g.realize(&state, {width, height});
+        out.copy_to_host(&state);
+        for (int y = 0; y < height; y++) {
+            for (int x = 0; x < width; x++) {
+                float correct = 2.0f;
+                if (out(x, y) != correct) {
+                    printf("out(%d, %d) = %f instead of %f\n", x, y, out(x, y), correct);
+                    return -1;
+                }
+            }
+        }
+    }
+
+    // Clean up. Destroying the context also destroys the stream created on it.
+    in.device_detach_native(&state);
+    cuMemFree(ptr);
+    cuCtxDestroy(state.cuda_context);
+
+    if (state.acquires.load() != state.releases.load() ||
+        state.acquires.load() < height) {
+        printf("Context acquires: %d releases: %d\n", state.acquires.load(), state.releases.load());
+        printf("Expected these to match and be at least %d (the number of parallel tasks)\n", height);
+        return -1;
+    }
+
+    printf("Success!\n");
+    return 0;
+}