Skip to content

Commit

Permalink
[RUNTIME] Fast path for single thread run to allow app level threading (
Browse files Browse the repository at this point in the history
#7454)

* Fast path for single thread run to allow app level threading

* add sync counter to avoid error in one of tests
  • Loading branch information
masahi authored Feb 18, 2021
1 parent 84c4b15 commit 944d8d1
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions src/runtime/thread_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -363,21 +363,30 @@ TVM_REGISTER_GLOBAL("runtime.config_threadpool").set_body([](TVMArgs args, TVMRe
} // namespace tvm

int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task) {
int num_workers = tvm::runtime::threading::MaxConcurrency();
if (num_workers == 1) {
std::atomic<int32_t> sync_counter{0};
TVMParallelGroupEnv env;
env.num_task = 1;
env.sync_handle = &sync_counter;
(*flambda)(0, &env, cdata);
return 0;
} else {
#if !TVM_THREADPOOL_USE_OPENMP
int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch(flambda, cdata, num_task, 1);
return res;
int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch(flambda, cdata, num_task, 1);
return res;
#else
int num_workers = tvm::runtime::threading::MaxConcurrency();
if (num_task == 0) num_task = num_workers;
omp_set_num_threads(num_task);
if (num_task == 0) num_task = num_workers;
omp_set_num_threads(num_task);
#pragma omp parallel num_threads(num_task)
{
TVMParallelGroupEnv env;
env.num_task = num_task;
(*flambda)(omp_get_thread_num(), &env, cdata);
}
return 0;
{
TVMParallelGroupEnv env;
env.num_task = num_task;
(*flambda)(omp_get_thread_num(), &env, cdata);
}
return 0;
#endif
}
}

int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
Expand Down

0 comments on commit 944d8d1

Please sign in to comment.