diff --git a/src/CodeGen_Hexagon.cpp b/src/CodeGen_Hexagon.cpp
index 847fe588d758..a32bca98ff7d 100644
--- a/src/CodeGen_Hexagon.cpp
+++ b/src/CodeGen_Hexagon.cpp
@@ -431,6 +431,33 @@ class InjectHVXLocks : public IRMutator {
     }
     Expr visit(const Call *op) override {
         uses_hvx = uses_hvx || op->type.is_vector();
+
+        if (op->name == "halide_do_par_for") {
+            // If we see a call to halide_do_par_for() at this point, it should mean that
+            // this statement was produced via HexagonOffload calling lower_parallel_tasks()
+            // explicitly; in this case, we won't see any parallel For statements, since they've
+            // all been transformed into closures already. To mirror the pattern above,
+            // we need to wrap the halide_do_par_for() call with an unlock/lock pair, but
+            // that's hard to do in Halide IR (we'd need to produce a Stmt to enforce the ordering,
+            // and the resulting Stmt can't easily be substituted for the Expr here). Rather than
+            // make fragile assumptions about the structure of the IR produced by lower_parallel_tasks(),
+            // we'll use a trick: we'll define a WEAK_INLINE function, _halide_hexagon_do_par_for,
+            // which simply encapsulates the unlock()/do_par_for()/lock() sequence, and swap out
+            // the call here. Since it is inlined, and since uses_hvx_var gets substituted at the end,
+            // we end up with LLVM IR that properly includes (or omits) the unlock/lock pair depending
+            // on the final value of uses_hvx_var in this scope.
+
+            internal_assert(op->call_type == Call::Extern);
+            // Four explicit args expected (presumably f, min, size, closure, with
+            // user_context supplied separately by codegen -- TODO confirm).
+            internal_assert(op->args.size() == 4);
+
+            // The runtime helper takes the HVX flag as a trailing int parameter.
+            std::vector<Expr> args = op->args;
+            args.push_back(cast<int>(uses_hvx_var));
+
+            return Call::make(Int(32), "_halide_hexagon_do_par_for", args, Call::Extern);
+        }
         return op;
     }
 
diff --git a/src/CodeGen_Internal.cpp b/src/CodeGen_Internal.cpp
index 7fe5900588f7..d8388882329a 100644
--- a/src/CodeGen_Internal.cpp
+++ b/src/CodeGen_Internal.cpp
@@ -137,6 +137,7 @@ bool function_takes_user_context(const std::string &name) {
         "_halide_buffer_crop",
         "_halide_buffer_retire_crop_after_extern_stage",
         "_halide_buffer_retire_crops_after_extern_stage",
+        "_halide_hexagon_do_par_for",
     };
     for (const char *user_context_runtime_func : user_context_runtime_funcs) {
         if (name == user_context_runtime_func) {
diff --git a/src/runtime/qurt_hvx.cpp b/src/runtime/qurt_hvx.cpp
index 9748cc662142..afc607bfce47 100644
--- a/src/runtime/qurt_hvx.cpp
+++ b/src/runtime/qurt_hvx.cpp
@@ -69,4 +69,36 @@ WEAK_INLINE uint8_t *_halide_hexagon_buffer_get_host(const hexagon_buffer_t_arg
 WEAK_INLINE uint64_t _halide_hexagon_buffer_get_device(const hexagon_buffer_t_arg *buf) {
     return buf->device;
 }
+
+// Wraps halide_do_par_for() for Hexagon. When use_hvx is nonzero, calls
+// halide_qurt_hvx_unlock() before the parallel loop and
+// halide_qurt_hvx_lock() after it completes successfully.
+//
+// Returns 0 on success, otherwise the first nonzero result from the unlock,
+// the parallel loop, or the re-lock. If the unlock or the loop fails, the
+// function returns immediately and the re-lock is not attempted.
+WEAK_INLINE int _halide_hexagon_do_par_for(void *user_context, halide_task_t f,
+                                           int min, int size, uint8_t *closure,
+                                           int use_hvx) {
+    if (use_hvx) {
+        const int unlock_result = halide_qurt_hvx_unlock(user_context);
+        if (unlock_result != 0) {
+            return unlock_result;
+        }
+    }
+
+    const int par_for_result = halide_do_par_for(user_context, f, min, size, closure);
+    if (par_for_result != 0) {
+        return par_for_result;
+    }
+
+    if (use_hvx) {
+        const int lock_result = halide_qurt_hvx_lock(user_context);
+        if (lock_result != 0) {
+            return lock_result;
+        }
+    }
+
+    return 0;
+}
 }