Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reset OpenMP thread number if num_threads <= 0 #4704

Merged
merged 8 commits into from
Oct 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/LightGBM/utils/openmp_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ inline int OMP_NUM_THREADS() {
return ret;
}

inline void OMP_SET_NUM_THREADS(int num_threads) {
static const int default_omp_num_threads = OMP_NUM_THREADS();
if (num_threads > 0) {
omp_set_num_threads(num_threads);
} else {
omp_set_num_threads(default_omp_num_threads);
}
}

class ThreadExceptionHelper {
public:
ThreadExceptionHelper() {
Expand Down Expand Up @@ -94,6 +103,7 @@ class ThreadExceptionHelper {
simulate a single thread running.
All #pragma omp should be ignored by the compiler **/
inline void omp_set_num_threads(int) __GOMP_NOTHROW {} // NOLINT (no cast done here)
inline void OMP_SET_NUM_THREADS(int) __GOMP_NOTHROW {}
inline int omp_get_num_threads() __GOMP_NOTHROW {return 1;}
inline int omp_get_max_threads() __GOMP_NOTHROW {return 1;}
inline int omp_get_thread_num() __GOMP_NOTHROW {return 0;}
Expand Down
4 changes: 1 addition & 3 deletions src/application/application.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ namespace LightGBM {
Application::Application(int argc, char** argv) {
LoadParameters(argc, argv);
// set number of threads for openmp
if (config_.num_threads > 0) {
omp_set_num_threads(config_.num_threads);
}
OMP_SET_NUM_THREADS(config_.num_threads);
if (config_.data.size() == 0 && config_.task != TaskType::kConvertModel) {
Log::Fatal("No training/prediction data, application quit");
}
Expand Down
76 changes: 19 additions & 57 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,7 @@ class Booster {
const char* parameters) {
auto param = Config::Str2Map(parameters);
config_.Set(param);
if (config_.num_threads > 0) {
omp_set_num_threads(config_.num_threads);
}
OMP_SET_NUM_THREADS(config_.num_threads);
// create boosting
if (config_.input_model.size() > 0) {
Log::Warning("Continued train from model is not supported for c_api,\n"
Expand Down Expand Up @@ -314,9 +312,7 @@ class Booster {

config_.Set(param);

if (config_.num_threads > 0) {
omp_set_num_threads(config_.num_threads);
}
OMP_SET_NUM_THREADS(config_.num_threads);

if (param.count("objective")) {
// create objective function
Expand Down Expand Up @@ -951,9 +947,7 @@ int LGBM_DatasetCreateFromFile(const char* filename,
auto param = Config::Str2Map(parameters);
Config config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
OMP_SET_NUM_THREADS(config.num_threads);
DatasetLoader loader(config, nullptr, 1, filename);
if (reference == nullptr) {
if (Network::num_machines() == 1) {
Expand Down Expand Up @@ -981,9 +975,7 @@ int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
auto param = Config::Str2Map(parameters);
Config config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
OMP_SET_NUM_THREADS(config.num_threads);
DatasetLoader loader(config, nullptr, 1, nullptr);
*out = loader.ConstructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
num_sample_row,
Expand Down Expand Up @@ -1096,9 +1088,7 @@ int LGBM_DatasetCreateFromMats(int32_t nmat,
auto param = Config::Str2Map(parameters);
Config config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
OMP_SET_NUM_THREADS(config.num_threads);
std::unique_ptr<Dataset> ret;
int32_t total_nrow = 0;
for (int j = 0; j < nmat; ++j) {
Expand Down Expand Up @@ -1188,9 +1178,7 @@ int LGBM_DatasetCreateFromCSR(const void* indptr,
auto param = Config::Str2Map(parameters);
Config config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
OMP_SET_NUM_THREADS(config.num_threads);
std::unique_ptr<Dataset> ret;
auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int32_t nrow = static_cast<int32_t>(nindptr - 1);
Expand Down Expand Up @@ -1256,9 +1244,7 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
auto param = Config::Str2Map(parameters);
Config config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
OMP_SET_NUM_THREADS(config.num_threads);
std::unique_ptr<Dataset> ret;
int32_t nrow = num_rows;
if (reference == nullptr) {
Expand Down Expand Up @@ -1328,9 +1314,7 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr,
auto param = Config::Str2Map(parameters);
Config config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
OMP_SET_NUM_THREADS(config.num_threads);
std::unique_ptr<Dataset> ret;
int32_t nrow = static_cast<int32_t>(num_row);
if (reference == nullptr) {
Expand Down Expand Up @@ -1409,9 +1393,7 @@ int LGBM_DatasetGetSubset(
auto param = Config::Str2Map(parameters);
Config config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
OMP_SET_NUM_THREADS(config.num_threads);
auto full_dataset = reinterpret_cast<const Dataset*>(handle);
CHECK_GT(num_used_row_indices, 0);
const int32_t lower = 0;
Expand Down Expand Up @@ -1816,9 +1798,7 @@ int LGBM_BoosterPredictForFile(BoosterHandle handle,
auto param = Config::Str2Map(parameter);
Config config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
OMP_SET_NUM_THREADS(config.num_threads);
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->Predict(start_iteration, num_iteration, predict_type, data_filename, data_has_header,
config, result_filename);
Expand Down Expand Up @@ -1894,9 +1874,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle,
auto param = Config::Str2Map(parameter);
Config config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
OMP_SET_NUM_THREADS(config.num_threads);
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int nrow = static_cast<int>(nindptr - 1);
Expand Down Expand Up @@ -1928,9 +1906,7 @@ int LGBM_BoosterPredictSparseOutput(BoosterHandle handle,
auto param = Config::Str2Map(parameter);
Config config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
OMP_SET_NUM_THREADS(config.num_threads);
if (matrix_type == C_API_MATRIX_TYPE_CSR) {
if (num_col_or_row <= 0) {
Log::Fatal("The number of columns should be greater than zero.");
Expand Down Expand Up @@ -2015,9 +1991,7 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
auto param = Config::Str2Map(parameter);
Config config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
OMP_SET_NUM_THREADS(config.num_threads);
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
ref_booster->SetSingleRowPredictor(start_iteration, num_iteration, predict_type, config);
Expand Down Expand Up @@ -2047,9 +2021,7 @@ int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle,
data_type,
static_cast<int32_t>(num_col)));

if (fastConfig_ptr->config.num_threads > 0) {
omp_set_num_threads(fastConfig_ptr->config.num_threads);
}
OMP_SET_NUM_THREADS(fastConfig_ptr->config.num_threads);

fastConfig_ptr->booster->SetSingleRowPredictor(start_iteration, num_iteration, predict_type, fastConfig_ptr->config);

Expand Down Expand Up @@ -2095,9 +2067,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
auto param = Config::Str2Map(parameter);
Config config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
OMP_SET_NUM_THREADS(config.num_threads);
int num_threads = OMP_NUM_THREADS();
int ncol = static_cast<int>(ncol_ptr - 1);
std::vector<std::vector<CSC_RowIterator>> iterators(num_threads, std::vector<CSC_RowIterator>());
Expand Down Expand Up @@ -2140,9 +2110,7 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle,
auto param = Config::Str2Map(parameter);
Config config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
OMP_SET_NUM_THREADS(config.num_threads);
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
ref_booster->Predict(start_iteration, num_iteration, predict_type, nrow, ncol, get_row_fun,
Expand All @@ -2165,9 +2133,7 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
auto param = Config::Str2Map(parameter);
Config config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
OMP_SET_NUM_THREADS(config.num_threads);
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major);
ref_booster->SetSingleRowPredictor(start_iteration, num_iteration, predict_type, config);
Expand All @@ -2191,9 +2157,7 @@ int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle,
data_type,
ncol));

if (fastConfig_ptr->config.num_threads > 0) {
omp_set_num_threads(fastConfig_ptr->config.num_threads);
}
OMP_SET_NUM_THREADS(fastConfig_ptr->config.num_threads);

fastConfig_ptr->booster->SetSingleRowPredictor(start_iteration, num_iteration, predict_type, fastConfig_ptr->config);

Expand Down Expand Up @@ -2231,9 +2195,7 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle,
auto param = Config::Str2Map(parameter);
Config config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
OMP_SET_NUM_THREADS(config.num_threads);
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseRows(data, ncol, data_type);
ref_booster->Predict(start_iteration, num_iteration, predict_type, nrow, ncol, get_row_fun, config, out_result, out_len);
Expand Down