Skip to content

Commit

Permalink
clean the conflicts, make sure the pipeline functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Mar 4, 2024
1 parent d0bee2f commit 4624c3f
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 88 deletions.
89 changes: 19 additions & 70 deletions src/common/quantile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -361,21 +361,23 @@ void SketchContainerImpl<WQSketch>::AllReduce(
}

template <typename SketchType>
void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_bin,
double AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_bin,
HistogramCuts *cuts, bool secure) {
size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
// make a copy of required_cuts for mode selection
size_t required_cuts_original = required_cuts;
if (secure) {
// Sync the required_cuts across all workers
// sync the required_cuts across all workers
collective::Allreduce<collective::Operation::kMax>(&required_cuts, 1);
}
// add the cut points
auto &cut_values = cuts->cut_values_.HostVector();
// if empty column, fill the cut values with 0
// if secure and empty column, fill the cut values with NaN
if (secure && (required_cuts_original == 0)) {
for (size_t i = 1; i < required_cuts; ++i) {
cut_values.push_back(0.0);
cut_values.push_back(std::numeric_limits<double>::quiet_NaN());
}
return std::numeric_limits<double>::quiet_NaN();
} else {
// we use the min_value as the first (0th) element, hence starting from 1.
for (size_t i = 1; i < required_cuts; ++i) {
Expand All @@ -384,43 +386,10 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b
cut_values.push_back(cpt);
}
}
return cut_values.back();
}
}

template <typename SketchType>
double AddCutPointSecure(typename SketchType::SummaryContainer const &summary, int max_bin,
                         HistogramCuts *cuts) {
  // For the secure vertical pipeline: append this feature's cut values to
  // cuts->cut_values_. Columns this worker does not own (empty summaries) are
  // padded with NaN so every worker emits the same number of entries.
  // Returns the last value appended (NaN for an empty column), which the
  // caller uses to decide whether to append a real sentinel or another NaN.
  size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
  // Keep the local count so we can tell an empty column from a populated one
  // after the Allreduce overwrites required_cuts below.
  size_t required_cuts_original = required_cuts;
  // Sync the required_cuts across all workers so every worker appends the
  // same number of entries for this feature.
  collective::Allreduce<collective::Operation::kMax>(&required_cuts, 1);

  // add the cut points
  auto &cut_values = cuts->cut_values_.HostVector();
  if (required_cuts_original > 0) {
    // Non-empty column: fill with the actual sketch values.
    // We use the min_value as the first (0th) element, hence starting from 1.
    for (size_t i = 1; i < required_cuts; ++i) {
      bst_float cpt = summary.data[i].value;
      if (i == 1 || cpt > cut_values.back()) {
        cut_values.push_back(cpt);
      }
    }
    // NOTE(review): if required_cuts <= 1 the loop above pushes nothing, so
    // back() relies on cut_values being non-empty here — verify upstream
    // guarantees (e.g. a preceding feature's cuts) before depending on it.
    return cut_values.back();
  }
  // Empty column: pad with NaN placeholders so the vector length matches the
  // other workers' output for this feature.
  for (size_t i = 1; i < required_cuts; ++i) {
    cut_values.push_back(std::numeric_limits<double>::quiet_NaN());
  }
  return std::numeric_limits<double>::quiet_NaN();
}

auto AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
if (std::any_of(categories.cbegin(), categories.cend(), InvalidCat)) {
InvalidCategory();
Expand Down Expand Up @@ -480,47 +449,27 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const
max_cat = std::max(max_cat, AddCategories(categories_.at(fid), p_cuts));
} else {
// use special AddCutPoint scheme for secure vertical federated learning
if (info.IsVerticalFederated() && info.IsSecure()) {
double last_value = AddCutPointSecure<WQSketch>(a, max_num_bins, p_cuts);
// push a value that is greater than anything if the feature is not empty
// i.e. if the last value is not NaN
if (!std::isnan(last_value)) {
const bst_float cpt =
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
p_cuts->cut_values_.HostVector().push_back(last);
}
else {
// if the feature is empty, push a NaN value
p_cuts->cut_values_.HostVector().push_back(std::numeric_limits<double>::quiet_NaN());
}
}
else {
AddCutPoint<WQSketch>(a, max_num_bins, p_cuts);
// push a value that is greater than anything
const bst_float cpt =
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
p_cuts->cut_values_.HostVector().push_back(last);
double last_value = AddCutPoint<WQSketch>(a, max_num_bins, p_cuts, info.IsSecure());
// push a value that is greater than anything if the feature is not empty
// i.e. if the last value is not NaN
if (!std::isnan(last_value)) {
const bst_float cpt =
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
p_cuts->cut_values_.HostVector().push_back(last);
} else {
// if the feature is empty, push a NaN value
p_cuts->cut_values_.HostVector().push_back(std::numeric_limits<double>::quiet_NaN());
}
}

// Ensure that every feature gets at least one quantile point
CHECK_LE(p_cuts->cut_values_.HostVector().size(), std::numeric_limits<uint32_t>::max());
auto cut_size = static_cast<uint32_t>(p_cuts->cut_values_.HostVector().size());
CHECK_GT(cut_size, p_cuts->cut_ptrs_.HostVector().back());
p_cuts->cut_ptrs_.HostVector().push_back(cut_size);
}

if (info.IsVerticalFederated() && info.IsSecure()) {
// cut values need to be synced across all workers via Allreduce
auto cut_val = p_cuts->cut_values_.HostVector().data();
std::size_t n = p_cuts->cut_values_.HostVector().size();
collective::Allreduce<collective::Operation::kSum>(cut_val, n);
}

p_cuts->SetCategorical(this->has_categorical_, max_cat);
monitor_.Stop(__func__);
}
Expand Down
11 changes: 5 additions & 6 deletions src/tree/common_row_partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ class CommonRowPartitioner {

template <typename ExpandEntry>
void FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree,
const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions, bool is_index) {
const GHistIndexMatrix& gmat,
std::vector<int32_t>* split_conditions, bool is_index) {
auto const& ptrs = gmat.cut.Ptrs();
auto const& vals = gmat.cut.Values();

Expand All @@ -121,10 +122,9 @@ class CommonRowPartitioner {
// therefore can recover the split_pt from bin_id, update tree info
auto split_pt_local = vals[split_pt];
// make updates to the tree, replacing the existing index
// with cut value, note that we modified const here, carefully
// with cut value, so as to be consistent with the tree model format
const_cast<RegTree::Node&>(tree.GetNodes()[nidx]).SetSplit(fidx, split_pt_local);
}
else {
} else {
// otherwise find the bin_id that corresponds to split_pt
std::uint32_t const lower_bound = ptrs[fidx];
std::uint32_t const upper_bound = ptrs[fidx + 1];
Expand Down Expand Up @@ -210,8 +210,7 @@ class CommonRowPartitioner {
if (is_secure_) {
// in secure mode, the split index is kept instead of the split value
FindSplitConditions(nodes, *p_tree, gmat, &split_conditions, true);
}
else {
} else {
FindSplitConditions(nodes, *p_tree, gmat, &split_conditions, false);
}
}
Expand Down
19 changes: 8 additions & 11 deletions src/tree/hist/evaluate_splits.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,33 +303,30 @@ class HistEvaluator {
// forward enumeration: split at right bound of each bin
loss_chg =
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
GradStats{right_sum}) -
parent.root_gain);
GradStats{right_sum}) - parent.root_gain);
if (!is_secure_) {
split_pt = cut_val[i]; // not used for partition based
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
}
else {
// secure mode: record the best split point, rather than the actual value since it is not accessible
// at this point (active party finding best-split)
} else {
// secure mode: record the best split point, rather than the actual value
// since it is not accessible at this point (active party finding best-split)
best.Update(loss_chg, fidx, i, d_step == -1, false, left_sum, right_sum);
}
} else {
// backward enumeration: split at left bound of each bin
loss_chg =
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{right_sum},
GradStats{left_sum}) -
parent.root_gain);
GradStats{left_sum}) - parent.root_gain);
if (!is_secure_) {
if (i == imin) {
split_pt = cut.MinValues()[fidx];
} else {
split_pt = cut_val[i - 1];
}
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
}
else {
// secure mode: record the best split point, rather than the actual value since it is not accessible
} else {
// secure mode: record the best split point, rather than the actual value
// since it is not accessible at this point (active party finding best-split)
if (i != imin) {
i = i - 1;
}
Expand Down
3 changes: 2 additions & 1 deletion src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ class MultiTargetHistBuilder {
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->Info().IsColumnSplit(), p_fmat->Info().IsSecure());
partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid,
p_fmat->Info().IsColumnSplit(), p_fmat->Info().IsSecure());
}

bst_target_t n_targets = p_tree->NumTargets();
Expand Down

0 comments on commit 4624c3f

Please sign in to comment.