Skip to content

Commit

Permalink
Merge branch 'master' into srj/hannk-wasm-build
Browse files Browse the repository at this point in the history
  • Loading branch information
steven-johnson committed Oct 28, 2021
2 parents 6e92475 + 541bc37 commit 0829d1d
Show file tree
Hide file tree
Showing 12 changed files with 396 additions and 40 deletions.
2 changes: 1 addition & 1 deletion apps/hannk/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)

enable_testing()

option(HANNK_BUILD_TFLITE "Build TFLite Delegate for HANNK" ON)
option(HANNK_BUILD_TFLITE "Build TFLite+Delegate for HANNK" ON)
if (HANNK_BUILD_TFLITE AND (Halide_TARGET MATCHES "wasm"))
message(FATAL_ERROR "HANNK_BUILD_TFLITE must be OFF when targeting wasm")
endif()
Expand Down
2 changes: 1 addition & 1 deletion apps/hannk/delegate/hannk_delegate.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#define HANNK_DELEGATE_H

#if !HANNK_BUILD_TFLITE
#error "This file should not be included when HANNK_BUILD_TFLITE=0"
#error "This file should not be included when HANNK_BUILD_TFLITE=0"
#endif

#include "tensorflow/lite/c/c_api.h"
Expand Down
5 changes: 2 additions & 3 deletions dependencies/wasm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,8 @@ function(add_wasm_executable TARGET)
-Wsuggest-override
-s ASSERTIONS=1
-s ALLOW_MEMORY_GROWTH=1
-s WASM_BIGINT=1
-s STANDALONE_WASM=1
-s ENVIRONMENT=node)
-s ENVIRONMENT=node
)

set(SRCS)
foreach (S IN LISTS args_SRCS)
Expand Down
6 changes: 5 additions & 1 deletion src/CodeGen_ARM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,15 +580,19 @@ const ArmIntrinsic intrinsic_defs[] = {
{nullptr, "udot.v4i32.v16i8", UInt(32, 4), "dot_product", {UInt(32, 4), UInt(8, 16), UInt(8, 16)}, ArmIntrinsic::NoMangle},

// ABDL - Widening absolute difference
// Need to be able to handle both signed and unsigned outputs for signed inputs.
// The ARM backend folds both signed and unsigned widening casts of absd to a widening_absd, so we need to handle both signed and
// unsigned input and return types.
{"vabdl_i8x8", "vabdl_i8x8", Int(16, 8), "widening_absd", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
{"vabdl_i8x8", "vabdl_i8x8", UInt(16, 8), "widening_absd", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
{"vabdl_u8x8", "vabdl_u8x8", Int(16, 8), "widening_absd", {UInt(8, 8), UInt(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
{"vabdl_u8x8", "vabdl_u8x8", UInt(16, 8), "widening_absd", {UInt(8, 8), UInt(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
{"vabdl_i16x4", "vabdl_i16x4", Int(32, 4), "widening_absd", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
{"vabdl_i16x4", "vabdl_i16x4", UInt(32, 4), "widening_absd", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
{"vabdl_u16x4", "vabdl_u16x4", Int(32, 4), "widening_absd", {UInt(16, 4), UInt(16, 4)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
{"vabdl_u16x4", "vabdl_u16x4", UInt(32, 4), "widening_absd", {UInt(16, 4), UInt(16, 4)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
{"vabdl_i32x2", "vabdl_i32x2", Int(64, 2), "widening_absd", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
{"vabdl_i32x2", "vabdl_i32x2", UInt(64, 2), "widening_absd", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
{"vabdl_u32x2", "vabdl_u32x2", Int(64, 2), "widening_absd", {UInt(32, 2), UInt(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
{"vabdl_u32x2", "vabdl_u32x2", UInt(64, 2), "widening_absd", {UInt(32, 2), UInt(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix},
};

Expand Down
116 changes: 92 additions & 24 deletions src/JITModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,63 +506,72 @@ void merge_handlers(JITHandlers &base, const JITHandlers &addins) {
if (addins.custom_get_library_symbol) {
base.custom_get_library_symbol = addins.custom_get_library_symbol;
}
if (addins.custom_cuda_acquire_context) {
base.custom_cuda_acquire_context = addins.custom_cuda_acquire_context;
}
if (addins.custom_cuda_release_context) {
base.custom_cuda_release_context = addins.custom_cuda_release_context;
}
if (addins.custom_cuda_get_stream) {
base.custom_cuda_get_stream = addins.custom_cuda_get_stream;
}
}

void print_handler(JITUserContext *context, const char *msg) {
if (context) {
(*context->handlers.custom_print)(context, msg);
if (context && context->handlers.custom_print) {
context->handlers.custom_print(context, msg);
} else {
return (*active_handlers.custom_print)(context, msg);
return active_handlers.custom_print(context, msg);
}
}

void *malloc_handler(JITUserContext *context, size_t x) {
if (context) {
return (*context->handlers.custom_malloc)(context, x);
if (context && context->handlers.custom_malloc) {
return context->handlers.custom_malloc(context, x);
} else {
return (*active_handlers.custom_malloc)(context, x);
return active_handlers.custom_malloc(context, x);
}
}

void free_handler(JITUserContext *context, void *ptr) {
if (context) {
(*context->handlers.custom_free)(context, ptr);
if (context && context->handlers.custom_free) {
context->handlers.custom_free(context, ptr);
} else {
(*active_handlers.custom_free)(context, ptr);
active_handlers.custom_free(context, ptr);
}
}

int do_task_handler(JITUserContext *context, int (*f)(JITUserContext *, int, uint8_t *), int idx,
uint8_t *closure) {
if (context) {
return (*context->handlers.custom_do_task)(context, f, idx, closure);
if (context && context->handlers.custom_do_task) {
return context->handlers.custom_do_task(context, f, idx, closure);
} else {
return (*active_handlers.custom_do_task)(context, f, idx, closure);
return active_handlers.custom_do_task(context, f, idx, closure);
}
}

int do_par_for_handler(JITUserContext *context, int (*f)(JITUserContext *, int, uint8_t *),
int min, int size, uint8_t *closure) {
if (context) {
return (*context->handlers.custom_do_par_for)(context, f, min, size, closure);
if (context && context->handlers.custom_do_par_for) {
return context->handlers.custom_do_par_for(context, f, min, size, closure);
} else {
return (*active_handlers.custom_do_par_for)(context, f, min, size, closure);
return active_handlers.custom_do_par_for(context, f, min, size, closure);
}
}

void error_handler_handler(JITUserContext *context, const char *msg) {
if (context) {
(*context->handlers.custom_error)(context, msg);
if (context && context->handlers.custom_error) {
context->handlers.custom_error(context, msg);
} else {
(*active_handlers.custom_error)(context, msg);
active_handlers.custom_error(context, msg);
}
}

int32_t trace_handler(JITUserContext *context, const halide_trace_event_t *e) {
if (context) {
return (*context->handlers.custom_trace)(context, e);
if (context && context->handlers.custom_trace) {
return context->handlers.custom_trace(context, e);
} else {
return (*active_handlers.custom_trace)(context, e);
return active_handlers.custom_trace(context, e);
}
}

Expand All @@ -578,6 +587,30 @@ void *get_library_symbol_handler(void *lib, const char *name) {
return (*active_handlers.custom_get_library_symbol)(lib, name);
}

// Route a cuda-context acquisition request. A handler installed on the
// supplied JITUserContext takes precedence; when there is no context, or
// the context has no such handler, the entry in active_handlers is used.
int cuda_acquire_context_handler(JITUserContext *context, void **cuda_context_ptr, bool create) {
    const bool use_context_handler =
        context != nullptr && context->handlers.custom_cuda_acquire_context != nullptr;
    if (!use_context_handler) {
        return active_handlers.custom_cuda_acquire_context(context, cuda_context_ptr, create);
    }
    return context->handlers.custom_cuda_acquire_context(context, cuda_context_ptr, create);
}

// Route a cuda-context release request. A handler installed on the
// supplied JITUserContext takes precedence; when there is no context, or
// the context has no such handler, the entry in active_handlers is used.
int cuda_release_context_handler(JITUserContext *context) {
    const bool use_context_handler =
        context != nullptr && context->handlers.custom_cuda_release_context != nullptr;
    if (!use_context_handler) {
        return active_handlers.custom_cuda_release_context(context);
    }
    return context->handlers.custom_cuda_release_context(context);
}

// Route a cuda-stream lookup request. A handler installed on the supplied
// JITUserContext takes precedence; when there is no context, or the
// context has no such handler, the entry in active_handlers is used.
int cuda_get_stream_handler(JITUserContext *context, void *cuda_context, void **cuda_stream_ptr) {
    const bool use_context_handler =
        context != nullptr && context->handlers.custom_cuda_get_stream != nullptr;
    if (!use_context_handler) {
        return active_handlers.custom_cuda_get_stream(context, cuda_context, cuda_stream_ptr);
    }
    return context->handlers.custom_cuda_get_stream(context, cuda_context, cuda_stream_ptr);
}

template<typename function_t>
function_t hook_function(const std::map<std::string, JITModule::Symbol> &exports, const char *hook_name, function_t hook) {
auto iter = exports.find(hook_name);
Expand Down Expand Up @@ -776,13 +809,13 @@ JITModule &make_module(llvm::Module *for_module, Target target,
hook_function(runtime.exports(), "halide_set_custom_trace", trace_handler);

runtime_internal_handlers.custom_get_symbol =
hook_function(shared_runtimes(MainShared).exports(), "halide_set_custom_get_symbol", get_symbol_handler);
hook_function(runtime.exports(), "halide_set_custom_get_symbol", get_symbol_handler);

runtime_internal_handlers.custom_load_library =
hook_function(shared_runtimes(MainShared).exports(), "halide_set_custom_load_library", load_library_handler);
hook_function(runtime.exports(), "halide_set_custom_load_library", load_library_handler);

runtime_internal_handlers.custom_get_library_symbol =
hook_function(shared_runtimes(MainShared).exports(), "halide_set_custom_get_library_symbol", get_library_symbol_handler);
hook_function(runtime.exports(), "halide_set_custom_get_library_symbol", get_library_symbol_handler);

active_handlers = runtime_internal_handlers;
merge_handlers(active_handlers, default_handlers);
Expand All @@ -794,6 +827,41 @@ JITModule &make_module(llvm::Module *for_module, Target target,
runtime.jit_module->name = "MainShared";
} else {
runtime.jit_module->name = "GPU";

// There are two versions of these cuda context
// management handlers we could use - one in the cuda
// module, and one in the cuda-debug module. If both
// modules are in use, we'll just want to use one of
// them, so that we don't needlessly create two cuda
// contexts. We'll use whichever was first
// created. The second one will then declare a
// dependency on the first one, to make sure things
// are destroyed in the correct order.

if (runtime_kind == CUDA || runtime_kind == CUDADebug) {
if (!runtime_internal_handlers.custom_cuda_acquire_context) {
// Neither module has been created.
runtime_internal_handlers.custom_cuda_acquire_context =
hook_function(runtime.exports(), "halide_set_cuda_acquire_context", cuda_acquire_context_handler);

runtime_internal_handlers.custom_cuda_release_context =
hook_function(runtime.exports(), "halide_set_cuda_release_context", cuda_release_context_handler);

runtime_internal_handlers.custom_cuda_get_stream =
hook_function(runtime.exports(), "halide_set_cuda_get_stream", cuda_get_stream_handler);

active_handlers = runtime_internal_handlers;
merge_handlers(active_handlers, default_handlers);
} else if (runtime_kind == CUDA) {
// The CUDADebug module has already been created.
// Use the context in the CUDADebug module and add
// a dependence edge from the CUDA module to it.
runtime.add_dependency(shared_runtimes(CUDADebug));
} else {
// The CUDA module has already been created.
runtime.add_dependency(shared_runtimes(CUDA));
}
}
}

uint64_t arg_addr =
Expand Down
16 changes: 16 additions & 0 deletions src/JITModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,22 @@ struct JITHandlers {
* an opened library. Equivalent to dlsym. Takes a handle
* returned by custom_load_library as the first argument. */
void *(*custom_get_library_symbol)(void *lib, const char *name){nullptr};

/** A custom method for the Halide runtime to acquire a cuda
* context. The cuda context is treated as a void * to avoid a
* dependence on the cuda headers. If the create argument is set
* to true, a context should be created if one does not already
* exist. */
int32_t (*custom_cuda_acquire_context)(JITUserContext *user_context, void **cuda_context_ptr, bool create){nullptr};

/** The Halide runtime calls this when it is done with a cuda
* context. The default implementation does nothing. */
int32_t (*custom_cuda_release_context)(JITUserContext *user_context){nullptr};

/** A custom method for the Halide runtime to acquire a cuda
* stream to use. The cuda context and stream are both modelled
* as a void *, to avoid a dependence on the cuda headers. */
int32_t (*custom_cuda_get_stream)(JITUserContext *user_context, void *cuda_context, void **stream_ptr){nullptr};
};

namespace Internal {
Expand Down
2 changes: 1 addition & 1 deletion src/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ Realization Pipeline::realize(JITUserContext *context,
if (needs_crop) {
r[i].crop(crop);
}
r[i].copy_to_host();
r[i].copy_to_host(context);
}
return r;
}
Expand Down
17 changes: 17 additions & 0 deletions src/runtime/HalideRuntimeCuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,23 @@ extern uintptr_t halide_cuda_get_device_ptr(void *user_context, struct halide_bu
* driver. See halide_reuse_device_allocations. */
extern int halide_cuda_release_unused_device_allocations(void *user_context);

// These typedefs treat both a CUcontext and a CUstream as a void *,
// to avoid dependencies on cuda headers.
typedef int (*halide_cuda_acquire_context_t)(void *, // user_context
void **, // cuda context out parameter
bool); // should create a context if none exists
typedef int (*halide_cuda_release_context_t)(void * /* user_context */);
typedef int (*halide_cuda_get_stream_t)(void *, // user_context
void *, // context
void **); // stream out parameter

/** Set custom methods to acquire and release cuda contexts and streams.
 * Each setter takes a handler and returns a handler of the same type
 * (presumably the previously installed one -- confirm against the
 * runtime implementation). */
// @{
extern halide_cuda_acquire_context_t halide_set_cuda_acquire_context(halide_cuda_acquire_context_t handler);
extern halide_cuda_release_context_t halide_set_cuda_release_context(halide_cuda_release_context_t handler);
extern halide_cuda_get_stream_t halide_set_cuda_get_stream(halide_cuda_get_stream_t handler);
// @}

#ifdef __cplusplus
} // End extern "C"
#endif
Expand Down
Loading

0 comments on commit 0829d1d

Please sign in to comment.