Skip to content

Commit

Permalink
updated based on comments
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed May 23, 2020
1 parent 64abf8c commit 7e34e36
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
14 changes: 8 additions & 6 deletions src/treelearner/data_parallel_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@ void DataParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, boo
// allocate buffer for communication
size_t buffer_size = this->train_data_->NumTotalBin() * kHistEntrySize;

input_buffer_.resize(buffer_size);
auto max_cat_threshold = this->config_->max_cat_threshold;
// input_buffer_ needs to be able to hold smaller and larger best splits in SyncUpGlobalBestSplit
int splitInfoSize = SplitInfo::Size(max_cat_threshold) * 2;
if (buffer_size < splitInfoSize) {
input_buffer_.resize(splitInfoSize);
} else {
input_buffer_.resize(buffer_size);
}
output_buffer_.resize(buffer_size);

is_feature_aggregated_.resize(this->num_features_);
Expand Down Expand Up @@ -231,11 +238,6 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
}

// sync global best info
auto max_cat_threshold = this->config_->max_cat_threshold;
int splitInfoSize = SplitInfo::Size(max_cat_threshold);
if (input_buffer_.size() < splitInfoSize * 2) {
input_buffer_.resize(splitInfoSize * 2);
}
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold);

// set best split
Expand Down
2 changes: 1 addition & 1 deletion src/treelearner/split_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ struct SplitInfo {
bool default_left = true;
int8_t monotone_type = 0;
inline static int Size(int max_cat_threshold) {
return 2 * sizeof(int) + sizeof(uint32_t) + sizeof(bool) + sizeof(double) * 9 + sizeof(data_size_t) * 2 + max_cat_threshold * sizeof(uint32_t) + sizeof(int8_t);
return 2 * sizeof(int) + sizeof(uint32_t) + sizeof(bool) + sizeof(double) * 7 + sizeof(data_size_t) * 2 + max_cat_threshold * sizeof(uint32_t) + sizeof(int8_t);
}

inline void CopyTo(char* buffer) const {
Expand Down

0 comments on commit 7e34e36

Please sign in to comment.