From 37c6cdf644f345223231be55d7ef242f1a39bce5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=B7=E8=A1=8C?= Date: Fri, 10 Jul 2020 18:02:23 +0800 Subject: [PATCH 1/6] Use Oblivious method to fix side channel privacy leak --- include/xgboost/common/quantile.h | 269 ++++++++++++++++++++---------- 1 file changed, 180 insertions(+), 89 deletions(-) diff --git a/include/xgboost/common/quantile.h b/include/xgboost/common/quantile.h index c4a4b501a..7bce9afea 100644 --- a/include/xgboost/common/quantile.h +++ b/include/xgboost/common/quantile.h @@ -1,6 +1,5 @@ /*! * Copyright 2014 by Contributors - * Modifications Copyright 2020 by Secure XGBoost Contributors * \file quantile.h * \brief util to compute quantiles * \author Tianqi Chen @@ -10,11 +9,13 @@ #include #include + #include #include #include #include #include + #include "obl_primitives.h" namespace xgboost { @@ -24,6 +25,7 @@ bool ObliviousSetCombineEnabled(); bool ObliviousSetPruneEnabled(); bool ObliviousDebugCheckEnabled(); bool ObliviousEnabled(); +void SetObliviousMode(bool); template struct WQSummaryEntry { @@ -154,6 +156,7 @@ struct EntryWithPartyInfo { using Entry = WQSummaryEntry; Entry entry; bool is_party_a; + bool is_dummy; inline bool operator<(const EntryWithPartyInfo &b) const { return entry < b.entry; @@ -228,9 +231,9 @@ template void CheckEqualSummary(const WQSummary &lhs, const WQSummary &rhs) { auto trace = [&]() { - LOG(INFO) << "---------- lhs: "; + LOG(CONSOLE) << "---------- lhs: "; lhs.Print(); - LOG(INFO) << "---------- rhs: "; + LOG(CONSOLE) << "---------- rhs: "; rhs.Print(); }; // DEBUG CHECK @@ -301,7 +304,11 @@ struct WQSummary { i = j; } } - + /* MakeSummaryOblivious protect the unique_count variable. + * in->size == qhelper.size + * out->size == qhelper.size + * out->data == || normal unique data | dummy data || + * */ inline void MakeSummaryOblivious(WQSummary *out) { ObliviousSort(queue.begin(), queue.begin() + qtail); @@ -330,34 +337,33 @@ struct WQSummary { } } - struct IsNewDescendingSorter { - bool operator()(const QEntryHelper &a, const QEntryHelper &b) { - return ObliviousGreater(a.is_new, b.is_new); - } - }; - struct ValueSorter { bool operator()(const QEntryHelper &a, const QEntryHelper &b) { return ObliviousLess(a.entry.value, b.entry.value); } }; - // Remove duplicates. - ObliviousSort(qhelper.begin(), qhelper.end(), IsNewDescendingSorter()); + for (size_t idx = 0; idx < qhelper.size(); ++idx) { + qhelper[idx].entry.value = + ObliviousChoose(qhelper[idx].is_new, qhelper[idx].entry.value, + std::numeric_limits::max()); + } // Resort by value. 
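      // Added note: the non-unique entries were rewritten above to the
      // numeric_limits max sentinel, so this value sort pushes them to the
      // tail of qhelper. The summary keeps a fixed length and the number of
      // unique entries is never used as a loop bound or copy size.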
- ObliviousSort(qhelper.begin(), qhelper.begin() + unique_count, - ValueSorter()); + ObliviousSort(qhelper.begin(), qhelper.end(), ValueSorter()); out->size = 0; RType wsum = 0; - for (size_t idx = 0; idx < unique_count; ++idx) { + // is_new represent first sight + for (size_t idx = 0; idx < qhelper.size(); ++idx) { const RType w = qhelper[idx].entry.weight; - out->data[out->size++] = - Entry(wsum, wsum + w, w, qhelper[idx].entry.value); + bool is_new = qhelper[idx].is_new; + ObliviousAssign(is_new, + Entry(wsum, wsum + w, w, qhelper[idx].entry.value), + Entry(-1, -1, 0, std::numeric_limits::max()), + &out->data[out->size++]); wsum += w; } - if (ObliviousDebugCheckEnabled()) { std::vector oblivious_results(out->data, out->data + out->size); this->MakeSummaryRaw(out); @@ -417,6 +423,10 @@ struct WQSummary { size = src.size; std::memcpy(data, src.data, sizeof(Entry) * size); } + inline void CopyFromSize(const WQSummary &src, const size_t insize) { + size = insize; + std::memcpy(data, src.data, sizeof(Entry) * size); + } inline void MakeFromSorted(const Entry *entries, size_t n) { size = 0; for (size_t i = 0; i < n;) { @@ -448,8 +458,9 @@ struct WQSummary { } /*! - * \brief set current summary to be pruned summary of src + * \brief set current summary to be obliviously pruned summary of src * assume data field is already allocated to be at least maxsize + * dummy item will rank last of return and will involved in following computation * \param src source summary * \param maxsize size we can afford in the pruned sketch */ @@ -458,14 +469,23 @@ struct WQSummary { this->CopyFrom(src); return; } - - // Make sure dx2 items are last one when `d == (rmax + rmin) / 2`. - const Entry kDummyEntryWithMaxValue{0, 0, 1, + const Entry kDummyEntryWithMaxValue{-1, -1, 0, std::numeric_limits::max()}; + // Make sure dx2 items are last one when `d == (rmax + rmin) / 2`. const RType begin = src.data[0].rmax; - const RType range = src.data[src.size - 1].rmin - src.data[0].rmax; - const size_t n = maxsize - 1; + const RType n = maxsize - 1; + // max_index is equal to previous src.size + size_t max_index = 0; + RType range = 0; + // find actually max item + for (size_t idx = 0; idx < src.size; idx++) { + max_index = ObliviousChoose( + src.data[idx].value != std::numeric_limits::max(), idx, + max_index); + range = src.data[max_index].rmin - src.data[0].rmax; + } + max_index += 1; // Construct sort vector. 
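      // Added note: the sort vector mixes the dx2 query ranks (placeholder
      // items with has_entry == false) with two rank keys for each interior
      // entry, so a single oblivious sort over a data-independent length can
      // pick the pruned summary without branching on which entries survive.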
using Item = PruneItem; @@ -475,16 +495,27 @@ struct WQSummary { RType dx2 = 2 * ((k * range) / n + begin); items.push_back(Item{kDummyEntryWithMaxValue, dx2, false}); } - std::transform(src.data + 1, src.data + src.size, std::back_inserter(items), - [](const Entry &entry) { - return Item{entry, entry.rmax + entry.rmin, true}; - }); + // ObliviousPrune contains Dummy item,So here we doing this on 2 cases + // CASE i < max_index: handle normal data + // CASE other: handle dummy data + // + for (size_t i = 1; i < src.size; ++i) { + Item obliviousItem = ObliviousChoose( + i < max_index - 1, + Item{src.data[i], src.data[i].rmax + src.data[i].rmin, true}, + Item{kDummyEntryWithMaxValue, std::numeric_limits::max(), + true}); + items.push_back(obliviousItem); + } for (size_t i = 1; i < src.size - 1; ++i) { - items.push_back(Item{src.data[i], - src.data[i].RMinNext() + src.data[i + 1].RMaxPrev(), - true}); + Item obliviousItem = ObliviousChoose( + i < max_index - 1, + Item{src.data[i], src.data[i].RMinNext() + src.data[i + 1].RMaxPrev(), + true}, + Item{kDummyEntryWithMaxValue, std::numeric_limits::max(), + true}); + items.push_back(obliviousItem); } - // Bitonic Sort. LOG(DEBUG) << __func__ << " BEGIN 1" << std::endl; ObliviousSort(items.begin(), items.end()); @@ -493,34 +524,42 @@ struct WQSummary { // Choose entrys. RType last_selected_entry_value = std::numeric_limits::min(); size_t select_count = 0; + Entry lastEntry = items[0].entry; for (size_t i = 1; i < items.size(); ++i) { - bool do_select = !items[i - 1].has_entry && items[i].has_entry && - items[i].entry.value != last_selected_entry_value; + // CASE max_index<=maxsize:All unique item will be select + // CASE other : select unique after dx2 index + bool do_select = ObliviousChoose( + max_index <= maxsize, + items[i].entry.value != last_selected_entry_value && + items[i].entry.value != std::numeric_limits::max(), + !items[i - 1].has_entry && items[i].has_entry && + items[i].entry.value != last_selected_entry_value); ObliviousAssign(do_select, items[i].entry.value, last_selected_entry_value, &last_selected_entry_value); ObliviousAssign(do_select, std::numeric_limits::min(), items[i].rank, &items[i].rank); + ObliviousAssign(i == max_index - 1, src.data[i], lastEntry, &lastEntry); + ObliviousAssign(do_select, items[i].entry, kDummyEntryWithMaxValue, + &items[i].entry); select_count += ObliviousChoose(do_select, 1, 0); } + // Bitonic Sort. LOG(DEBUG) << __func__ << " BEGIN 2" << std::endl; ObliviousSort(items.begin(), items.end()); LOG(DEBUG) << __func__ << " PASSED 2" << std::endl; + // Append actual last item to items vector + for (size_t i = 0; i < src.size; i++) { + ObliviousAssign(i == select_count, lastEntry, items[i].entry, + &items[i].entry); + } + this->data[0] = src.data[0]; - this->size = 1 + select_count; - std::transform(items.begin(), items.begin() + select_count, this->data + 1, - [](const Item &item) { - CHECK(item.has_entry && - item.rank == std::numeric_limits::min()); - return item.entry; - }); + this->size = maxsize; - // First and last ones are always kept in prune. 
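    // Added note: in the oblivious version the copy below always writes a
    // fixed maxsize entries; the real last entry was already spliced in at
    // position select_count above, so no value-dependent append is needed.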
- if (data[size - 1].value != src.data[src.size - 1].value) { - CHECK(size < maxsize); - data[size++] = src.data[src.size - 1]; - } + std::transform(items.begin(), items.begin() + maxsize - 1, this->data + 1, + [](const Item &item) { return item.entry; }); if (ObliviousDebugCheckEnabled()) { std::vector oblivious_results(data, data + size); @@ -605,28 +644,33 @@ struct WQSummary { this->CopyFrom(sa); return; } - using EntryWithPartyInfo = EntryWithPartyInfo; std::vector merged_party_entrys(this->size); // Fill party info and build bitonic sequence. + // std::transform(sa.data, sa.data + sa.size, merged_party_entrys.begin(), [](const Entry &entry) { - return EntryWithPartyInfo{entry, true}; - }); - std::transform(sb.data, sb.data + sb.size, - merged_party_entrys.begin() + sa.size, - [](const Entry &entry) { - return EntryWithPartyInfo{entry, false}; + bool is_dummy = ObliviousChoose( + entry.value == std::numeric_limits::max(), true, + false); + return EntryWithPartyInfo{entry, true, is_dummy}; }); + std::transform( + sb.data, sb.data + sb.size, merged_party_entrys.begin() + sa.size, + [](const Entry &entry) { + bool is_dummy = ObliviousChoose( + entry.value == std::numeric_limits::max(), true, false); + return EntryWithPartyInfo{entry, false, is_dummy}; + }); // Build bitonic sequence. std::reverse(merged_party_entrys.begin(), merged_party_entrys.begin() + sa.size); // Bitonic merge. // ObliviousSort(merged_party_entrys.begin(), merged_party_entrys.end()); ObliviousMerge(merged_party_entrys.begin(), merged_party_entrys.end()); - // Forward pass to compute rmin. + // Forward pass don`t need Oblivious RType a_prev_rmin = 0; RType b_prev_rmin = 0; for (size_t idx = 0; idx < merged_party_entrys.size(); ++idx) { @@ -642,10 +686,12 @@ struct WQSummary { // Save first. RType next_aprev_rmin = ObliviousChoose( - merged_party_entrys[idx].is_party_a, + merged_party_entrys[idx].is_party_a && + !merged_party_entrys[idx].is_dummy, merged_party_entrys[idx].entry.RMinNext(), a_prev_rmin); RType next_bprev_rmin = ObliviousChoose( - !merged_party_entrys[idx].is_party_a, + !merged_party_entrys[idx].is_party_a && + !merged_party_entrys[idx].is_dummy, merged_party_entrys[idx].entry.RMinNext(), b_prev_rmin); // This is a. Need to add previous b->RMinNext(). @@ -663,9 +709,22 @@ struct WQSummary { } // Backward pass to compute rmax. + // Backward Algo: + // 1、 find really data[last].rmax for sa and sb assign to prev_rmax + // 2、 use is_dummy to contral backward computation dataflow RType a_prev_rmax = sa.data[sa.size - 1].rmax; RType b_prev_rmax = sb.data[sb.size - 1].rmax; + + for (int idx = 0; idx < sa.size; idx++) { + a_prev_rmax = ObliviousChoose(sa.data[idx].rmax > a_prev_rmax, + sa.data[idx].rmax, a_prev_rmax); + } + for (int idx = 0; idx < sb.size; idx++) { + b_prev_rmax = ObliviousChoose(sb.data[idx].rmax > b_prev_rmax, + sb.data[idx].rmax, b_prev_rmax); + } size_t duplicate_count = 0; + size_t dummy_count = 0; for (ssize_t idx = merged_party_entrys.size() - 1; idx >= 0; --idx) { bool equal_prev = idx == 0 ? false @@ -676,28 +735,33 @@ struct WQSummary { ? false : ObliviousEqual(merged_party_entrys[idx].entry.value, merged_party_entrys[idx + 1].entry.value); + bool dummy_item = merged_party_entrys[idx].is_dummy; duplicate_count += ObliviousChoose(equal_next, 1, 0); + dummy_count += ObliviousChoose(merged_party_entrys[idx].is_dummy, 1, 0); // Need to save first since the rmax will be overwritten. 
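      // Added note: the per-party running rmax values below are advanced with
      // branch-free selects, and the is_dummy flag keeps padding entries from
      // ever contributing to the ranks of real entries.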
RType next_aprev_rmax = ObliviousChoose( - merged_party_entrys[idx].is_party_a, + merged_party_entrys[idx].is_party_a && + !merged_party_entrys[idx].is_dummy, merged_party_entrys[idx].entry.RMaxPrev(), a_prev_rmax); RType next_bprev_rmax = ObliviousChoose( - !merged_party_entrys[idx].is_party_a, + !merged_party_entrys[idx].is_party_a && + !merged_party_entrys[idx].is_dummy, merged_party_entrys[idx].entry.RMaxPrev(), b_prev_rmax); - // Add peer RMaxPrev. RType rmax_to_add = ObliviousChoose(merged_party_entrys[idx].is_party_a, b_prev_rmax, a_prev_rmax); // Handle equals. - RType rmin_to_add = - ObliviousChoose(equal_prev, merged_party_entrys[idx - 1].entry.rmin, - static_cast(0)); - RType wmin_to_add = - ObliviousChoose(equal_prev, merged_party_entrys[idx - 1].entry.wmin, - static_cast(0)); - rmax_to_add = ObliviousChoose( - equal_prev, merged_party_entrys[idx - 1].entry.rmax, rmax_to_add); + // Handle dummys + RType rmin_to_add = ObliviousChoose( + equal_prev && !dummy_item, merged_party_entrys[idx - 1].entry.rmin, + static_cast(0)); + RType wmin_to_add = ObliviousChoose( + equal_prev && !dummy_item, merged_party_entrys[idx - 1].entry.wmin, + static_cast(0)); + rmax_to_add = + ObliviousChoose(equal_prev && !dummy_item, + merged_party_entrys[idx - 1].entry.rmax, rmax_to_add); // Update. merged_party_entrys[idx].entry.rmax += rmax_to_add; merged_party_entrys[idx].entry.rmin += rmin_to_add; @@ -706,17 +770,17 @@ struct WQSummary { // Copy rmin, rmax, wmin from previous if values are equal. // Value is ok to be infinite now since this is two party merge, at most // two items are the same given a specific value. - ObliviousAssign(equal_next, merged_party_entrys[idx + 1].entry, - merged_party_entrys[idx].entry, - &merged_party_entrys[idx].entry); - ObliviousAssign(equal_next, std::numeric_limits::max(), + ObliviousAssign( + equal_next && !dummy_item, merged_party_entrys[idx + 1].entry, + merged_party_entrys[idx].entry, &merged_party_entrys[idx].entry); + ObliviousAssign(equal_next && !dummy_item, + std::numeric_limits::max(), merged_party_entrys[idx].entry.value, &merged_party_entrys[idx].entry.value); a_prev_rmax = next_aprev_rmax; b_prev_rmax = next_bprev_rmax; } - // Bitonic sort to push duplicates to end of list. std::transform(merged_party_entrys.begin(), merged_party_entrys.end(), this->data, [](const EntryWithPartyInfo &party_entry) { @@ -726,10 +790,8 @@ struct WQSummary { ObliviousSort(this->data, this->data + this->size); // std::sort(this->data, this->data + this->size); LOG(DEBUG) << __func__ << " PASSED 3" << std::endl; - + // exit(1); // Need to confirm shrink. 
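    // Added note: the raw code shrank the summary by duplicate_count; the
    // oblivious version keeps the combined size fixed so the duplicate count
    // is not revealed. Duplicate and dummy entries have already been pushed
    // to the tail by the oblivious sort above.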
- this->size -= duplicate_count; - if (ObliviousDebugCheckEnabled()) { std::vector oblivious_results(this->data, this->data + this->size); RawSetCombine(sa, sb); @@ -822,7 +884,7 @@ struct WQSummary { // helper function to print the current content of sketch inline void Print() const { for (size_t i = 0; i < this->size; ++i) { - LOG(INFO) << "[" << i << "] rmin=" << data[i].rmin + LOG(CONSOLE) << "[" << i << "] rmin=" << data[i].rmin << ", rmax=" << data[i].rmax << ", wmin=" << data[i].wmin << ", v=" << data[i].value; } @@ -891,20 +953,31 @@ struct WXQSummary : public WQSummary { if (ObliviousSetPruneEnabled()) { return WQSummary::ObliviousSetPrune(src, maxsize); } - if (src.size <= maxsize) { - this->CopyFrom(src); + + size_t max_index = 0; + // find actually max item + for (size_t idx = 0; idx < src.size; idx++) { + max_index = ObliviousChoose( + src.data[idx].value != std::numeric_limits::max(), idx, + max_index); + } + max_index += 1; + + if (max_index <= maxsize) { + this->CopyFromSize(src, max_index); return; } RType begin = src.data[0].rmax; // n is number of points exclude the min/max points size_t n = maxsize - 2, nbig = 0; // these is the range of data exclude the min/max point - RType range = src.data[src.size - 1].rmin - begin; + RType range = src.data[max_index - 1].rmin - begin; + // RType range = src.data[src.size - 1].rmin - begin; // prune off zero weights if (range == 0.0f || maxsize <= 2) { // special case, contain only two effective data pts this->data[0] = src.data[0]; - this->data[1] = src.data[src.size - 1]; + this->data[1] = src.data[max_index - 1]; this->size = 2; return; } else { @@ -919,7 +992,7 @@ struct WXQSummary : public WQSummary { // first scan, grab all the big chunk // moving block index, exclude the two ends. size_t bid = 0; - for (size_t i = 1; i < src.size - 1; ++i) { + for (size_t i = 1; i < max_index - 1; ++i) { // detect big chunk data point in the middle // always save these data points. if (CheckLarge(src.data[i], chunk)) { @@ -931,8 +1004,8 @@ struct WXQSummary : public WQSummary { ++nbig; } } - if (bid != src.size - 2) { - mrange += src.data[src.size - 1].RMaxPrev() - src.data[bid].RMinNext(); + if (bid != max_index - 2) { + mrange += src.data[max_index - 1].RMaxPrev() - src.data[bid].RMinNext(); } } // assert: there cannot be more than n big data points @@ -951,8 +1024,8 @@ struct WXQSummary : public WQSummary { n = n - nbig; // find the rest of point size_t bid = 0, k = 1, lastidx = 0; - for (size_t end = 1; end < src.size; ++end) { - if (end == src.size - 1 || CheckLarge(src.data[end], chunk)) { + for (size_t end = 1; end < max_index; ++end) { + if (end == max_index - 1 || CheckLarge(src.data[end], chunk)) { if (bid != end - 1) { size_t i = bid; RType maxdx2 = src.data[end].RMaxPrev() * 2; @@ -1045,13 +1118,17 @@ struct GKSummary { size = src.size; std::memcpy(data, src.data, sizeof(Entry) * size); } + inline void CopyFromSize(const GKSummary &src, const size_t insize) { + size = insize; + std::memcpy(data, src.data, sizeof(Entry) * size); + } inline void CheckValid(RType eps) const { // assume always valid } /*! 
\brief used for debug purpose, print the summary */ inline void Print() const { for (size_t i = 0; i < size; ++i) { - LOG(INFO) << "x=" << data[i].value << "\t" + LOG(CONSOLE) << "x=" << data[i].value << "\t" << "[" << data[i].rmin << "," << data[i].rmax << "]"; } } @@ -1324,12 +1401,26 @@ class QuantileSketchTemplate { level[0].SetPrune(*out, limit_size); } } - out->CopyFrom(level[0]); + // filter out all the dummy item + size_t final_size = 0; + for (size_t idx = 0; idx < level[0].size; idx++) { + bool is_valid = !ObliviousEqual(out->data[idx].value, + std::numeric_limits::max()); + final_size += is_valid; + } + out->CopyFromSize(level[0], final_size); } else { if (out->size > limit_size) { temp.Reserve(limit_size); temp.SetPrune(*out, limit_size); - out->CopyFrom(temp); + // filter out all the dummy item + size_t final_size = 0; + for (size_t idx = 0; idx < out->size; idx++) { + bool is_valid = !ObliviousEqual(out->data[idx].value, + std::numeric_limits::max()); + final_size += is_valid; + } + out->CopyFromSize(temp, final_size); } } } @@ -1389,4 +1480,4 @@ class GKQuantileSketch : public QuantileSketchTemplate > {}; } // namespace common } // namespace xgboost -#endif // XGBOOST_COMMON_QUANTILE_H_ +#endif // XGBOOST_COMMON_QUANTILE_H_ \ No newline at end of file From a65d6bf23d8050f2fa0c7edd58054fdbe25c8cf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=B7=E8=A1=8C?= Date: Tue, 14 Jul 2020 00:28:35 +0800 Subject: [PATCH 2/6] occlum g++ env bug fix --- include/xgboost/common/quantile.h | 41 +++++++++++++++++++------------ 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/include/xgboost/common/quantile.h b/include/xgboost/common/quantile.h index 7bce9afea..f3970a377 100644 --- a/include/xgboost/common/quantile.h +++ b/include/xgboost/common/quantile.h @@ -460,9 +460,9 @@ struct WQSummary { /*! * \brief set current summary to be obliviously pruned summary of src * assume data field is already allocated to be at least maxsize - * dummy item will rank last of return and will involved in following computation - * \param src source summary - * \param maxsize size we can afford in the pruned sketch + * dummy item will rank last of return and will involved in following + * computation \param src source summary \param maxsize size we can afford in + * the pruned sketch */ inline void ObliviousSetPrune(const WQSummary &src, size_t maxsize) { if (src.size <= maxsize) { @@ -538,7 +538,7 @@ struct WQSummary { last_selected_entry_value, &last_selected_entry_value); ObliviousAssign(do_select, std::numeric_limits::min(), items[i].rank, &items[i].rank); - ObliviousAssign(i == max_index - 1, src.data[i], lastEntry, &lastEntry); + ObliviousAssign(do_select, items[i].entry, kDummyEntryWithMaxValue, &items[i].entry); select_count += ObliviousChoose(do_select, 1, 0); @@ -549,6 +549,11 @@ struct WQSummary { ObliviousSort(items.begin(), items.end()); LOG(DEBUG) << __func__ << " PASSED 2" << std::endl; + // Assign actual last entry to lastEntry + for (size_t idx = 0; idx < src.size; ++idx) { + ObliviousAssign(idx == max_index - 1, src.data[idx], lastEntry, + &lastEntry); + } // Append actual last item to items vector for (size_t i = 0; i < src.size; i++) { ObliviousAssign(i == select_count, lastEntry, items[i].entry, @@ -645,7 +650,8 @@ struct WQSummary { return; } using EntryWithPartyInfo = EntryWithPartyInfo; - + const Entry kDummyEntryWithMaxValue{-1, -1, 0, + std::numeric_limits::max()}; std::vector merged_party_entrys(this->size); // Fill party info and build bitonic sequence. 
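    // Added note: each entry is tagged with its party and with an is_dummy
    // flag (value equal to the max sentinel). Party A's half is then reversed
    // so the concatenation forms a bitonic sequence, and a single
    // ObliviousMerge yields a sorted run with a data-independent access
    // pattern.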
// @@ -714,7 +720,6 @@ struct WQSummary { // 2、 use is_dummy to contral backward computation dataflow RType a_prev_rmax = sa.data[sa.size - 1].rmax; RType b_prev_rmax = sb.data[sb.size - 1].rmax; - for (int idx = 0; idx < sa.size; idx++) { a_prev_rmax = ObliviousChoose(sa.data[idx].rmax > a_prev_rmax, sa.data[idx].rmax, a_prev_rmax); @@ -725,6 +730,8 @@ struct WQSummary { } size_t duplicate_count = 0; size_t dummy_count = 0; + Entry prevEntry = merged_party_entrys[merged_party_entrys.size() - 1].entry; + Entry nextEntry = kDummyEntryWithMaxValue; for (ssize_t idx = merged_party_entrys.size() - 1; idx >= 0; --idx) { bool equal_prev = idx == 0 ? false @@ -736,6 +743,11 @@ struct WQSummary { : ObliviousEqual(merged_party_entrys[idx].entry.value, merged_party_entrys[idx + 1].entry.value); bool dummy_item = merged_party_entrys[idx].is_dummy; + prevEntry = idx == 0 ? kDummyEntryWithMaxValue + : merged_party_entrys[idx - 1].entry; + nextEntry = idx == merged_party_entrys.size() - 1 + ? kDummyEntryWithMaxValue + : merged_party_entrys[idx + 1].entry; duplicate_count += ObliviousChoose(equal_next, 1, 0); dummy_count += ObliviousChoose(merged_party_entrys[idx].is_dummy, 1, 0); @@ -754,14 +766,11 @@ struct WQSummary { // Handle equals. // Handle dummys RType rmin_to_add = ObliviousChoose( - equal_prev && !dummy_item, merged_party_entrys[idx - 1].entry.rmin, - static_cast(0)); + equal_prev && !dummy_item, prevEntry.rmin, static_cast(0)); RType wmin_to_add = ObliviousChoose( - equal_prev && !dummy_item, merged_party_entrys[idx - 1].entry.wmin, - static_cast(0)); - rmax_to_add = - ObliviousChoose(equal_prev && !dummy_item, - merged_party_entrys[idx - 1].entry.rmax, rmax_to_add); + equal_prev && !dummy_item, prevEntry.wmin, static_cast(0)); + rmax_to_add = ObliviousChoose(equal_prev && !dummy_item, prevEntry.rmax, + rmax_to_add); // Update. merged_party_entrys[idx].entry.rmax += rmax_to_add; merged_party_entrys[idx].entry.rmin += rmin_to_add; @@ -770,9 +779,9 @@ struct WQSummary { // Copy rmin, rmax, wmin from previous if values are equal. // Value is ok to be infinite now since this is two party merge, at most // two items are the same given a specific value. - ObliviousAssign( - equal_next && !dummy_item, merged_party_entrys[idx + 1].entry, - merged_party_entrys[idx].entry, &merged_party_entrys[idx].entry); + ObliviousAssign(equal_next && !dummy_item, nextEntry, + merged_party_entrys[idx].entry, + &merged_party_entrys[idx].entry); ObliviousAssign(equal_next && !dummy_item, std::numeric_limits::max(), merged_party_entrys[idx].entry.value, From 68b3d29340ffc6d5c6f81643e4df3b5d4a3f4e4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=B7=E8=A1=8C?= Date: Tue, 14 Jul 2020 11:45:23 +0800 Subject: [PATCH 3/6] resolve cr comments --- include/xgboost/common/quantile.h | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/include/xgboost/common/quantile.h b/include/xgboost/common/quantile.h index f3970a377..2b244afbf 100644 --- a/include/xgboost/common/quantile.h +++ b/include/xgboost/common/quantile.h @@ -1,5 +1,6 @@ /*! 
* Copyright 2014 by Contributors + * Modifications Copyright 2020 by Secure XGBoost Contributors * \file quantile.h * \brief util to compute quantiles * \author Tianqi Chen @@ -231,9 +232,9 @@ template void CheckEqualSummary(const WQSummary &lhs, const WQSummary &rhs) { auto trace = [&]() { - LOG(CONSOLE) << "---------- lhs: "; + LOG(INFO) << "---------- lhs: "; lhs.Print(); - LOG(CONSOLE) << "---------- rhs: "; + LOG(INFO) << "---------- rhs: "; rhs.Print(); }; // DEBUG CHECK @@ -319,7 +320,7 @@ struct WQSummary { helper_entry.entry.weight = 0; } - size_t unique_count = 0; + for (size_t idx = 0; idx < qhelper.size(); ++idx) { // sum weight for same value qhelper[idx].entry.weight += queue[idx].weight; @@ -329,7 +330,6 @@ struct WQSummary { : !ObliviousEqual(qhelper[idx + 1].entry.value, qhelper[idx].entry.value); qhelper[idx].is_new = is_new; - unique_count += is_new; if (idx != qhelper.size() - 1) { // Accumulate when next is same with me, otherwise reset to zero. qhelper[idx + 1].entry.weight = @@ -753,11 +753,11 @@ struct WQSummary { // Need to save first since the rmax will be overwritten. RType next_aprev_rmax = ObliviousChoose( - merged_party_entrys[idx].is_party_a && + merged_party_entrys[idx].is_party_a & !merged_party_entrys[idx].is_dummy, merged_party_entrys[idx].entry.RMaxPrev(), a_prev_rmax); RType next_bprev_rmax = ObliviousChoose( - !merged_party_entrys[idx].is_party_a && + !merged_party_entrys[idx].is_party_a & !merged_party_entrys[idx].is_dummy, merged_party_entrys[idx].entry.RMaxPrev(), b_prev_rmax); // Add peer RMaxPrev. @@ -799,7 +799,6 @@ struct WQSummary { ObliviousSort(this->data, this->data + this->size); // std::sort(this->data, this->data + this->size); LOG(DEBUG) << __func__ << " PASSED 3" << std::endl; - // exit(1); // Need to confirm shrink. if (ObliviousDebugCheckEnabled()) { std::vector oblivious_results(this->data, this->data + this->size); @@ -893,7 +892,7 @@ struct WQSummary { // helper function to print the current content of sketch inline void Print() const { for (size_t i = 0; i < this->size; ++i) { - LOG(CONSOLE) << "[" << i << "] rmin=" << data[i].rmin + LOG(INFO) << "[" << i << "] rmin=" << data[i].rmin << ", rmax=" << data[i].rmax << ", wmin=" << data[i].wmin << ", v=" << data[i].value; } @@ -966,8 +965,10 @@ struct WXQSummary : public WQSummary { size_t max_index = 0; // find actually max item for (size_t idx = 0; idx < src.size; idx++) { + bool is_valid = !ObliviousEqual(src.data[idx].value, + std::numeric_limits::max()); max_index = ObliviousChoose( - src.data[idx].value != std::numeric_limits::max(), idx, + is_valid, idx, max_index); } max_index += 1; @@ -1137,7 +1138,7 @@ struct GKSummary { /*! 
\brief used for debug purpose, print the summary */ inline void Print() const { for (size_t i = 0; i < size; ++i) { - LOG(CONSOLE) << "x=" << data[i].value << "\t" + LOG(INFO) << "x=" << data[i].value << "\t" << "[" << data[i].rmin << "," << data[i].rmax << "]"; } } From cb2ebefbeec44729c7ff90e1e6db48515e2444b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=B7=E8=A1=8C?= Date: Tue, 14 Jul 2020 12:31:44 +0800 Subject: [PATCH 4/6] resolve cr comments --- include/xgboost/common/quantile.h | 55 ++++++++++++++----------------- 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/include/xgboost/common/quantile.h b/include/xgboost/common/quantile.h index 2b244afbf..8573ebbe1 100644 --- a/include/xgboost/common/quantile.h +++ b/include/xgboost/common/quantile.h @@ -10,13 +10,11 @@ #include #include - #include #include #include #include #include - #include "obl_primitives.h" namespace xgboost { @@ -320,7 +318,6 @@ struct WQSummary { helper_entry.entry.weight = 0; } - for (size_t idx = 0; idx < qhelper.size(); ++idx) { // sum weight for same value qhelper[idx].entry.weight += queue[idx].weight; @@ -441,7 +438,7 @@ struct WQSummary { } /*! * \brief debug function, validate whether the summary - * run consistency check to check if it is a valid summary + * run consistency check to check if it is a valid summary * \param eps the tolerate error level, used when RType is floating point and * some inconsistency could occur due to rounding error */ @@ -461,8 +458,9 @@ struct WQSummary { * \brief set current summary to be obliviously pruned summary of src * assume data field is already allocated to be at least maxsize * dummy item will rank last of return and will involved in following - * computation \param src source summary \param maxsize size we can afford in - * the pruned sketch + * computation + * \param src source summary \param maxsize size we can afford in + * the pruned sketch */ inline void ObliviousSetPrune(const WQSummary &src, size_t maxsize) { if (src.size <= maxsize) { @@ -529,7 +527,7 @@ struct WQSummary { // CASE max_index<=maxsize:All unique item will be select // CASE other : select unique after dx2 index bool do_select = ObliviousChoose( - max_index <= maxsize, + ObliviousLess(max_index , maxsize), items[i].entry.value != last_selected_entry_value && items[i].entry.value != std::numeric_limits::max(), !items[i - 1].has_entry && items[i].has_entry && @@ -658,17 +656,16 @@ struct WQSummary { std::transform(sa.data, sa.data + sa.size, merged_party_entrys.begin(), [](const Entry &entry) { bool is_dummy = ObliviousChoose( - entry.value == std::numeric_limits::max(), true, + ObliviousEqual(entry.value , std::numeric_limits::max()), true, false); return EntryWithPartyInfo{entry, true, is_dummy}; }); - std::transform( - sb.data, sb.data + sb.size, merged_party_entrys.begin() + sa.size, - [](const Entry &entry) { - bool is_dummy = ObliviousChoose( - entry.value == std::numeric_limits::max(), true, false); - return EntryWithPartyInfo{entry, false, is_dummy}; - }); + std::transform(sb.data, sb.data + sb.size, merged_party_entrys.begin() + sa.size, + [](const Entry &entry) { + bool is_dummy = ObliviousChoose( + ObliviousEqual(entry.value , std::numeric_limits::max()), true, false); + return EntryWithPartyInfo{entry, false, is_dummy}; + }); // Build bitonic sequence. std::reverse(merged_party_entrys.begin(), merged_party_entrys.begin() + sa.size); @@ -692,11 +689,11 @@ struct WQSummary { // Save first. 
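      // Added note: the bitwise & on the bool flags below, unlike the
      // short-circuiting &&, evaluates both operands unconditionally, so the
      // predicate itself does not introduce a data-dependent branch.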
RType next_aprev_rmin = ObliviousChoose( - merged_party_entrys[idx].is_party_a && + merged_party_entrys[idx].is_party_a & !merged_party_entrys[idx].is_dummy, merged_party_entrys[idx].entry.RMinNext(), a_prev_rmin); RType next_bprev_rmin = ObliviousChoose( - !merged_party_entrys[idx].is_party_a && + !merged_party_entrys[idx].is_party_a & !merged_party_entrys[idx].is_dummy, merged_party_entrys[idx].entry.RMinNext(), b_prev_rmin); @@ -721,11 +718,11 @@ struct WQSummary { RType a_prev_rmax = sa.data[sa.size - 1].rmax; RType b_prev_rmax = sb.data[sb.size - 1].rmax; for (int idx = 0; idx < sa.size; idx++) { - a_prev_rmax = ObliviousChoose(sa.data[idx].rmax > a_prev_rmax, + a_prev_rmax = ObliviousChoose(ObliviousGreater(sa.data[idx].rmax , a_prev_rmax), sa.data[idx].rmax, a_prev_rmax); } for (int idx = 0; idx < sb.size; idx++) { - b_prev_rmax = ObliviousChoose(sb.data[idx].rmax > b_prev_rmax, + b_prev_rmax = ObliviousChoose(ObliviousGreater(sb.data[idx].rmax , b_prev_rmax), sb.data[idx].rmax, b_prev_rmax); } size_t duplicate_count = 0; @@ -766,10 +763,10 @@ struct WQSummary { // Handle equals. // Handle dummys RType rmin_to_add = ObliviousChoose( - equal_prev && !dummy_item, prevEntry.rmin, static_cast(0)); + equal_prev & !dummy_item, prevEntry.rmin, static_cast(0)); RType wmin_to_add = ObliviousChoose( - equal_prev && !dummy_item, prevEntry.wmin, static_cast(0)); - rmax_to_add = ObliviousChoose(equal_prev && !dummy_item, prevEntry.rmax, + equal_prev & !dummy_item, prevEntry.wmin, static_cast(0)); + rmax_to_add = ObliviousChoose(equal_prev & !dummy_item, prevEntry.rmax, rmax_to_add); // Update. merged_party_entrys[idx].entry.rmax += rmax_to_add; @@ -779,10 +776,10 @@ struct WQSummary { // Copy rmin, rmax, wmin from previous if values are equal. // Value is ok to be infinite now since this is two party merge, at most // two items are the same given a specific value. - ObliviousAssign(equal_next && !dummy_item, nextEntry, + ObliviousAssign(equal_next & !dummy_item, nextEntry, merged_party_entrys[idx].entry, &merged_party_entrys[idx].entry); - ObliviousAssign(equal_next && !dummy_item, + ObliviousAssign(equal_next & !dummy_item, std::numeric_limits::max(), merged_party_entrys[idx].entry.value, &merged_party_entrys[idx].entry.value); @@ -893,8 +890,8 @@ struct WQSummary { inline void Print() const { for (size_t i = 0; i < this->size; ++i) { LOG(INFO) << "[" << i << "] rmin=" << data[i].rmin - << ", rmax=" << data[i].rmax << ", wmin=" << data[i].wmin - << ", v=" << data[i].value; + << ", rmax=" << data[i].rmax << ", wmin=" << data[i].wmin + << ", v=" << data[i].value; } } @@ -967,9 +964,7 @@ struct WXQSummary : public WQSummary { for (size_t idx = 0; idx < src.size; idx++) { bool is_valid = !ObliviousEqual(src.data[idx].value, std::numeric_limits::max()); - max_index = ObliviousChoose( - is_valid, idx, - max_index); + max_index = ObliviousChoose(is_valid, idx, max_index); } max_index += 1; @@ -1139,7 +1134,7 @@ struct GKSummary { inline void Print() const { for (size_t i = 0; i < size; ++i) { LOG(INFO) << "x=" << data[i].value << "\t" - << "[" << data[i].rmin << "," << data[i].rmax << "]"; + << "[" << data[i].rmin << "," << data[i].rmax << "]"; } } /*! 
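Editor's note between patches: the changes above lean on the constant-time primitives from obl_primitives.h (ObliviousChoose, ObliviousAssign, ObliviousSort, ObliviousMerge). The snippet below is only a minimal, standalone sketch of the underlying select idea, not the obl_primitives.h implementation: the result is assembled from a mask derived from the predicate, so both inputs are always read and no data-dependent branch is taken. The names oblivious_choose_u32 and oblivious_choose_f32 are illustrative; real implementations typically also rely on cmov or vector instructions to keep the compiler from reintroducing branches.

    #include <cstdint>
    #include <cstring>
    #include <iostream>

    // Branch-free select: returns t when pred is true, f otherwise.
    uint32_t oblivious_choose_u32(bool pred, uint32_t t, uint32_t f) {
      uint32_t mask = -static_cast<uint32_t>(pred);  // all ones or all zeros
      return (t & mask) | (f & ~mask);
    }

    // Same idea for float: route the bits through the integer select.
    float oblivious_choose_f32(bool pred, float t, float f) {
      uint32_t ti, fi;
      std::memcpy(&ti, &t, sizeof ti);
      std::memcpy(&fi, &f, sizeof fi);
      uint32_t ri = oblivious_choose_u32(pred, ti, fi);
      float r;
      std::memcpy(&r, &ri, sizeof r);
      return r;
    }

    int main() {
      std::cout << oblivious_choose_u32(true, 1u, 2u) << "\n";       // prints 1
      std::cout << oblivious_choose_f32(false, 3.5f, 7.25f) << "\n"; // prints 7.25
      return 0;
    }

This is the same pattern the sketch code follows wherever it writes ObliviousChoose(cond, a, b) instead of an if/else: both values are computed and the condition only decides which one is kept.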
From 7d4535d7d7fbab7768c43c86184dc2a9e1b09b08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=B7=E8=A1=8C?= Date: Tue, 14 Jul 2020 11:45:23 +0800 Subject: [PATCH 5/6] resolve cr comments --- include/xgboost/common/quantile.h | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/include/xgboost/common/quantile.h b/include/xgboost/common/quantile.h index 8573ebbe1..5f7c35dcb 100644 --- a/include/xgboost/common/quantile.h +++ b/include/xgboost/common/quantile.h @@ -24,7 +24,6 @@ bool ObliviousSetCombineEnabled(); bool ObliviousSetPruneEnabled(); bool ObliviousDebugCheckEnabled(); bool ObliviousEnabled(); -void SetObliviousMode(bool); template struct WQSummaryEntry { @@ -459,7 +458,8 @@ struct WQSummary { * assume data field is already allocated to be at least maxsize * dummy item will rank last of return and will involved in following * computation - * \param src source summary \param maxsize size we can afford in + * \param src source summary + * \param maxsize size we can afford in * the pruned sketch */ inline void ObliviousSetPrune(const WQSummary &src, size_t maxsize) { @@ -479,7 +479,7 @@ struct WQSummary { // find actually max item for (size_t idx = 0; idx < src.size; idx++) { max_index = ObliviousChoose( - src.data[idx].value != std::numeric_limits::max(), idx, + !ObliviousEqual(src.data[idx].value , std::numeric_limits::max()), idx, max_index); range = src.data[max_index].rmin - src.data[0].rmax; } @@ -499,7 +499,7 @@ struct WQSummary { // for (size_t i = 1; i < src.size; ++i) { Item obliviousItem = ObliviousChoose( - i < max_index - 1, + ObliviousLess(i , max_index - 1), Item{src.data[i], src.data[i].rmax + src.data[i].rmin, true}, Item{kDummyEntryWithMaxValue, std::numeric_limits::max(), true}); @@ -507,7 +507,7 @@ struct WQSummary { } for (size_t i = 1; i < src.size - 1; ++i) { Item obliviousItem = ObliviousChoose( - i < max_index - 1, + ObliviousLess(i , max_index - 1), Item{src.data[i], src.data[i].RMinNext() + src.data[i + 1].RMaxPrev(), true}, Item{kDummyEntryWithMaxValue, std::numeric_limits::max(), @@ -527,11 +527,11 @@ struct WQSummary { // CASE max_index<=maxsize:All unique item will be select // CASE other : select unique after dx2 index bool do_select = ObliviousChoose( - ObliviousLess(max_index , maxsize), - items[i].entry.value != last_selected_entry_value && - items[i].entry.value != std::numeric_limits::max(), - !items[i - 1].has_entry && items[i].has_entry && - items[i].entry.value != last_selected_entry_value); + ObliviousLessOrEqual(max_index , maxsize), + !ObliviousEqual(items[i].entry.value , last_selected_entry_value) & + !ObliviousEqual(items[i].entry.value , std::numeric_limits::max()), + !items[i - 1].has_entry & items[i].has_entry & + !ObliviousEqual(items[i].entry.value , last_selected_entry_value)); ObliviousAssign(do_select, items[i].entry.value, last_selected_entry_value, &last_selected_entry_value); ObliviousAssign(do_select, std::numeric_limits::min(), @@ -549,12 +549,12 @@ struct WQSummary { // Assign actual last entry to lastEntry for (size_t idx = 0; idx < src.size; ++idx) { - ObliviousAssign(idx == max_index - 1, src.data[idx], lastEntry, + ObliviousAssign(ObliviousEqual(idx , max_index - 1), src.data[idx], lastEntry, &lastEntry); } // Append actual last item to items vector for (size_t i = 0; i < src.size; i++) { - ObliviousAssign(i == select_count, lastEntry, items[i].entry, + ObliviousAssign(ObliviousEqual(i , select_count), lastEntry, items[i].entry, &items[i].entry); } From 
148a516d4e4ae64aabb7c7ce2b72e0408563b73f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=B7=E8=A1=8C?= Date: Tue, 21 Jul 2020 15:27:38 +0800 Subject: [PATCH 6/6] pin tool script use for side channel leakage test --- demo/quantile-mem-trace/.gitignore | 4 + demo/quantile-mem-trace/README.md | 45 + demo/quantile-mem-trace/gen_arr.sh | 10 + demo/quantile-mem-trace/make.sh | 19 + demo/quantile-mem-trace/profile_script.sh | 7 + demo/quantile-mem-trace/test_A.cc | 36 + demo/quantile-mem-trace/test_B.cc | 36 + include/xgboost/common/pin_quantile.h | 1493 +++++++++++++++++++++ 8 files changed, 1650 insertions(+) create mode 100644 demo/quantile-mem-trace/.gitignore create mode 100644 demo/quantile-mem-trace/README.md create mode 100755 demo/quantile-mem-trace/gen_arr.sh create mode 100755 demo/quantile-mem-trace/make.sh create mode 100755 demo/quantile-mem-trace/profile_script.sh create mode 100644 demo/quantile-mem-trace/test_A.cc create mode 100644 demo/quantile-mem-trace/test_B.cc create mode 100644 include/xgboost/common/pin_quantile.h diff --git a/demo/quantile-mem-trace/.gitignore b/demo/quantile-mem-trace/.gitignore new file mode 100644 index 000000000..7317e2550 --- /dev/null +++ b/demo/quantile-mem-trace/.gitignore @@ -0,0 +1,4 @@ +arr_*.h +test_A +test_B +*.trace \ No newline at end of file diff --git a/demo/quantile-mem-trace/README.md b/demo/quantile-mem-trace/README.md new file mode 100644 index 000000000..af68db1d9 --- /dev/null +++ b/demo/quantile-mem-trace/README.md @@ -0,0 +1,45 @@ + +## Setup + +#### Install Intel Pin +Download the software + + wget https://software.intel.com/sites/landingpage/pintool/downloads/pin-3.11-97998-g7ecce2dac-gcc-linux.tar.gz + cp pin-3.11-97998-g7ecce2dac-gcc-linux.tar.gz ~ + cd ~ + tar -zxvf pin-3.11-97998-g7ecce2dac-gcc-linux.tar.gz + ln -s ~/pin-3.11-97998-g7ecce2dac-gcc-linux ~/pin-dir + +Set environment variable + + export PIN_ROOT=~/pin-dir + +Build Intel Pin + + cd $PIN_ROOT/source/tools + make all + +#### Disable ASLR + + echo 0 | sudo tee /proc/sys/kernel/randomize_va_space + +## Run the tests + +#### Build the source files +This will auto generate files `arr_A.h` and `arr_B.h` containing random arrays, and then build the programs `test_A.cc` and `test_B.cc`. + +**src/common/quantile.h should be replaced by src/common/pin_quantile.h** + +**because logging and Macro would effect compile and memtrace.** + + ./make.sh + +#### Execute the tests and capture memory trace + ./profile_script.sh + +#### Compare memory traces +Compare the memory traces captured during the runs. Trace files `test_A` and `test_B` traces should show no difference between them. + + diff test_A.trace test_B.trace + +Repeat the steps to run the tests with different random inputs. diff --git a/demo/quantile-mem-trace/gen_arr.sh b/demo/quantile-mem-trace/gen_arr.sh new file mode 100755 index 000000000..144999fb9 --- /dev/null +++ b/demo/quantile-mem-trace/gen_arr.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +FILE="arr_$1.h" +echo "std::vector V = {" > $FILE + +for i in {1..999} +do + echo -n $RANDOM"," >> $FILE +done +echo -n $RANDOM "};" >> $FILE diff --git a/demo/quantile-mem-trace/make.sh b/demo/quantile-mem-trace/make.sh new file mode 100755 index 000000000..b11d896e6 --- /dev/null +++ b/demo/quantile-mem-trace/make.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +set -e + +ret=$(cat /proc/sys/kernel/randomize_va_space) +if [ $ret -ne 0 ]; then + echo "ASLR is NOT disabled. Please disable ASLR." 
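+  # Abort here: with ASLR still enabled the two pin traces would differ
+  # because of address randomization, not because of the algorithm.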
+ exit 1 +fi + +#echo "Generating random arrays" +./gen_arr.sh A +./gen_arr.sh B + +echo "Building" +g++ -w -O2 -fno-strict-aliasing test_A.cc ../../src/common/quantile.cc ../../src/common/obl_primitives.cc -I../../ -I../../dmlc-core/include -I../../include -o test_A -std=c++11 -mavx2 +g++ -w -O2 -fno-strict-aliasing test_B.cc ../../src/common/quantile.cc ../../src/common/obl_primitives.cc -I../../ -I../../dmlc-core/include -I../../include -o test_B -std=c++11 -mavx2 + +echo "Done" diff --git a/demo/quantile-mem-trace/profile_script.sh b/demo/quantile-mem-trace/profile_script.sh new file mode 100755 index 000000000..a9def31bc --- /dev/null +++ b/demo/quantile-mem-trace/profile_script.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash +echo "start test" +$PIN_ROOT/pin -t $PIN_ROOT/source/tools/ManualExamples/obj-intel64/pinatrace.so -- ./test_A; +mv pinatrace.out test_A.trace; + +$PIN_ROOT/pin -t $PIN_ROOT/source/tools/ManualExamples/obj-intel64/pinatrace.so -- ./test_B; +mv pinatrace.out test_B.trace; diff --git a/demo/quantile-mem-trace/test_A.cc b/demo/quantile-mem-trace/test_A.cc new file mode 100644 index 000000000..2a98eda58 --- /dev/null +++ b/demo/quantile-mem-trace/test_A.cc @@ -0,0 +1,36 @@ +#include +#include +#include + +#include +#include +#include + +#include "arr_A.h" +using namespace xgboost::common; + +int main(int argc, char* argv[]) { + WXQuantileSketch::SummaryContainer out; + WXQuantileSketch sketchs; + sketchs.Init(64, 1.0); + sketchs.limit_size = 50; + sketchs.nlevel = 3; + sketchs.inqueue.queue.resize(sketchs.limit_size * 2); + for (size_t i = 0; i < 100; i++) { + sketchs.inqueue.Push(V[i], 1); + } + + WXQuantileSketch::SummaryContainer sa; + WXQuantileSketch::SummaryContainer sb; + sa.Reserve(sketchs.inqueue.queue.size()); + sb.Reserve(sketchs.inqueue.queue.size()); + out.Reserve(sketchs.inqueue.queue.size()); + // test for MakeSummaryOblivious + sketchs.inqueue.MakeSummaryOblivious(&out); + sb.CopyFromSize(out, 30); + sa.CopyFromSize(out, 30); + // test fo ObliviousSetPrune + out.ObliviousSetPrune(out,10); + // test for ObliviousSetCombine + out.ObliviousSetCombine(sa,sb); +} diff --git a/demo/quantile-mem-trace/test_B.cc b/demo/quantile-mem-trace/test_B.cc new file mode 100644 index 000000000..f02c53ab2 --- /dev/null +++ b/demo/quantile-mem-trace/test_B.cc @@ -0,0 +1,36 @@ +#include +#include +#include + +#include +#include +#include + +#include "arr_B.h" +using namespace xgboost::common; + +int main(int argc, char* argv[]) { + WXQuantileSketch::SummaryContainer out; + WXQuantileSketch sketchs; + sketchs.Init(64, 1.0); + sketchs.limit_size = 50; + sketchs.nlevel = 3; + sketchs.inqueue.queue.resize(sketchs.limit_size * 2); + for (size_t i = 0; i < 100; i++) { + sketchs.inqueue.Push(V[i], 1); + } + + WXQuantileSketch::SummaryContainer sa; + WXQuantileSketch::SummaryContainer sb; + sa.Reserve(sketchs.inqueue.queue.size()); + sb.Reserve(sketchs.inqueue.queue.size()); + out.Reserve(sketchs.inqueue.queue.size()); + // test for MakeSummaryOblivious + sketchs.inqueue.MakeSummaryOblivious(&out); + sb.CopyFromSize(out, 30); + sa.CopyFromSize(out, 30); + // test fo ObliviousSetPrune + out.ObliviousSetPrune(out,10); + // test for ObliviousSetCombine + out.ObliviousSetCombine(sa,sb); +} diff --git a/include/xgboost/common/pin_quantile.h b/include/xgboost/common/pin_quantile.h new file mode 100644 index 000000000..57e03022f --- /dev/null +++ b/include/xgboost/common/pin_quantile.h @@ -0,0 +1,1493 @@ +/*! 
+ * Copyright 2014 by Contributors + * Modifications Copyright 2020 by Secure XGBoost Contributors + * \file quantile.h + * \brief util to compute quantiles + * \author Tianqi Chen + */ +#ifndef XGBOOST_COMMON_QUANTILE_H_ +#define XGBOOST_COMMON_QUANTILE_H_ +#define XGBOOST_DEVICE + +#include +//#include +#include +#include +#include +#include +#include +#include "obl_primitives.h" + +namespace xgboost { +namespace common { + +bool ObliviousSetCombineEnabled(); +bool ObliviousSetPruneEnabled(); +bool ObliviousDebugCheckEnabled(); +bool ObliviousEnabled(); +void SetObliviousMode(bool); + +template +struct WQSummaryEntry { + /*! \brief minimum rank */ + RType rmin; + /*! \brief maximum rank */ + RType rmax; + /*! \brief maximum weight */ + RType wmin; + /*! \brief the value of data */ + DType value; + // constructor + XGBOOST_DEVICE WQSummaryEntry() {} // NOLINT + // constructor + XGBOOST_DEVICE WQSummaryEntry(RType rmin, RType rmax, RType wmin, DType value) + : rmin(rmin), rmax(rmax), wmin(wmin), value(value) {} + /*! + * \brief debug function, check Valid + * \param eps the tolerate level for violating the relation + */ + inline void CheckValid(RType eps = 0) const { +// CHECK(rmin >= 0 && rmax >= 0 && wmin >= 0) << "nonneg constraint"; +// CHECK(rmax - rmin - wmin > -eps) << "relation constraint: min/max"; + } + + // For bitonic sort/merge. + inline bool operator<(const WQSummaryEntry &b) const { + return value < b.value; + } + + inline bool operator==(const WQSummaryEntry &b) const { + return value == b.value && rmin == b.rmin && rmax == b.rmax && + wmin == b.wmin; + } + + inline bool operator!=(const WQSummaryEntry &b) const { + return !(*this == b); + } + + /*! \return rmin estimation for v strictly bigger than value */ + XGBOOST_DEVICE inline RType RMinNext() const { return rmin + wmin; } + /*! 
\return rmax estimation for v strictly smaller than value */ + XGBOOST_DEVICE inline RType RMaxPrev() const { return rmax - wmin; } +}; + +template +std::ostream &operator<<(std::ostream &out, + const WQSummaryEntry &entry) { + out << "[ v=" << entry.value << ", w=" << entry.wmin + << ", rmin,rmax=" << entry.rmin << "," << entry.rmax << " ]"; + return out; +} + +template +struct WQSummaryQEntry { + // value of the instance + DType value; + // weight of instance + RType weight; + // default constructor + WQSummaryQEntry() = default; + // constructor + WQSummaryQEntry(DType value, RType weight) : value(value), weight(weight) {} + // comparator on value + inline bool operator<(const WQSummaryQEntry &b) const { + return value < b.value; + } +}; + +template +std::ostream &operator<<(std::ostream &out, + const WQSummaryQEntry &entry) { + out << "[ v=" << entry.value << ", w=" << entry.weight << " ]"; + return out; +} + +template +struct WQSummaryQEntryHelper { + using QEntry = WQSummaryQEntry; + // Entry + QEntry entry; + // New + bool is_new; + // default constructor + WQSummaryQEntryHelper() = default; + // constructor + WQSummaryQEntryHelper(DType value, RType weight) + : entry(value, weight), is_new(false) {} + // ctor from entry + WQSummaryQEntryHelper(const QEntry &entry) : entry(entry), is_new(false) {} + // comparator + inline bool operator<(const WQSummaryQEntryHelper &b) const { + return entry < b.entry; + } +}; + +template +std::ostream &operator<<(std::ostream &out, + const WQSummaryQEntryHelper &entry) { + out << "[ entry=" << entry << ", is_new=" << entry.is_new + << ", wsum=" << entry.wsum << " ]"; + return out; +} + +template +struct PruneItem { + using Entry = WQSummaryEntry; + Entry entry; + RType rank; + bool has_entry; + + inline bool operator<(const PruneItem &rhs) const { + return rank < rhs.rank || + (rank == rhs.rank && entry.value < rhs.entry.value); + } +}; + +template +std::ostream &operator<<(std::ostream &out, + const PruneItem &item) { + out << item.entry << ", rank=" << item.rank + << ", has_entry=" << item.has_entry; + return out; +} + +template +struct EntryWithPartyInfo { + using Entry = WQSummaryEntry; + Entry entry; + bool is_party_a; + bool is_dummy; + + inline bool operator<(const EntryWithPartyInfo &b) const { + return entry < b.entry; + } +}; + +template +std::ostream &operator<<(std::ostream &out, + const EntryWithPartyInfo &item) { + out << item.entry << ", is_party_a=" << item.is_party_a; + return out; +} + +} // namespace common +} // namespace xgboost + +namespace obl { + +// Implement oblivious less. 
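// Added note: these specializations tell ObliviousSort/ObliviousMerge how to
// compare the summary types. Every comparison goes through ObliviousLess,
// ObliviousEqual, or ObliviousChoose, so ordering decisions are made with the
// constant-time primitives rather than short-circuiting boolean logic.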
+ +using SummaryEntry = ::xgboost::common::WQSummaryEntry; +using SummaryQEntry = ::xgboost::common::WQSummaryQEntry; +using SummaryPruneItem = ::xgboost::common::PruneItem; +using SummaryEntryWithPartyInfo = +::xgboost::common::EntryWithPartyInfo; + +template <> +struct less { + bool operator()(const SummaryEntry &x, const SummaryEntry &y) { + return ObliviousLess(x.value, y.value); + } +}; + +template <> +struct less { + bool operator()(const SummaryQEntry &x, const SummaryQEntry &y) { + return ObliviousLess(x.value, y.value); + } +}; + +template <> +struct less { + bool operator()(const SummaryPruneItem &a, const SummaryPruneItem &b) { + bool b0 = ObliviousLess(a.rank, b.rank); + bool same_rank = ObliviousEqual(a.rank, b.rank); + bool b1 = ObliviousChoose( + same_rank, ::obl::less()(a.entry, b.entry), + false); + bool ret = ObliviousChoose(b0, true, b1); + //CHECK_EQ(ret, a < b) << "a=" << a << ", b=" << b; + return ret; + } +}; + +template <> +struct less { + bool operator()(const SummaryEntryWithPartyInfo &a, + const SummaryEntryWithPartyInfo &b) { + return ObliviousLess(a.entry, b.entry); + } +}; + +} // namespace obl + +namespace xgboost { +namespace common { + +template +struct WQSummary; + +template +void CheckEqualSummary(const WQSummary &lhs, + const WQSummary &rhs) { +// auto trace = [&]() { +// LOG(INFO) << "---------- lhs: "; +// lhs.Print(); +// LOG(INFO) << "---------- rhs: "; +// rhs.Print(); +// }; +// // DEBUG CHECK +// if (lhs.size != rhs.size) { +// trace(); +// //LOG(FATAL) << "lhs.size=" << lhs.size << ", rhs.size=" << rhs.size; +// } +// for (size_t i = 0; i < lhs.size; ++i) { +// if (lhs.data[i] != rhs.data[i]) { +// trace(); +// LOG(FATAL) << "Results mismatch in index " << i; +// } +// } +} + +/*! + * \brief experimental wsummary + * \tparam DType type of data content + * \tparam RType type of rank + */ +template +struct WQSummary { + /*! \brief an entry in the sketch summary */ + using Entry = WQSummaryEntry; + + /*! \brief input data queue before entering the summary */ + struct Queue { + // entry in the queue + using QEntry = WQSummaryQEntry; + using QEntryHelper = WQSummaryQEntryHelper; + // the input queue + std::vector queue; + // end of the queue + size_t qtail; + // push data to the queue + inline void Push(DType x, RType w) { + if (qtail == 0 || queue[qtail - 1].value != x) { + queue[qtail++] = QEntry(x, w); + } else { + queue[qtail - 1].weight += w; + } + } + + inline void MakeSummary(WQSummary *out) { + if (ObliviousEnabled()) { + return MakeSummaryOblivious(out); + } else { + return MakeSummaryRaw(out); + } + } + + inline void MakeSummaryRaw(WQSummary *out) { + std::sort(queue.begin(), queue.begin() + qtail); + + out->size = 0; + // start update sketch + RType wsum = 0; + // construct data with unique weights + for (size_t i = 0; i < qtail;) { + size_t j = i + 1; + RType w = queue[i].weight; + while (j < qtail && queue[j].value == queue[i].value) { + w += queue[j].weight; + ++j; + } + out->data[out->size++] = Entry(wsum, wsum + w, w, queue[i].value); + wsum += w; + i = j; + } + } + /* MakeSummaryOblivious protect the unique_count variable. 
+ * in->size == qhelper.size + * out->size == qhelper.size + * out->data == || normal unique data | dummy data || + * */ + inline void MakeSummaryOblivious(WQSummary *out) { + ObliviousSort(queue.begin(), queue.begin() + qtail); + + std::vector qhelper(queue.begin(), queue.begin() + qtail); + + for (auto &helper_entry : qhelper) { + // zero weights + helper_entry.entry.weight = 0; + } + + for (size_t idx = 0; idx < qhelper.size(); ++idx) { + // sum weight for same value + qhelper[idx].entry.weight += queue[idx].weight; + // next is not same as me + bool is_new = idx == qhelper.size() - 1 + ? true + : !ObliviousEqual(qhelper[idx + 1].entry.value, + qhelper[idx].entry.value); + qhelper[idx].is_new = is_new; + if (idx != qhelper.size() - 1) { + // Accumulate when next is same with me, otherwise reset to zero. + qhelper[idx + 1].entry.weight = + ObliviousChoose(is_new, 0.f, qhelper[idx].entry.weight); + } + } + + struct ValueSorter { + bool operator()(const QEntryHelper &a, const QEntryHelper &b) { + return ObliviousLess(a.entry.value, b.entry.value); + } + }; + + for (size_t idx = 0; idx < qhelper.size(); ++idx) { + qhelper[idx].entry.value = + ObliviousChoose(qhelper[idx].is_new, qhelper[idx].entry.value, + std::numeric_limits::max()); + } + + // Resort by value. + ObliviousSort(qhelper.begin(), qhelper.end(), ValueSorter()); + + out->size = 0; + RType wsum = 0; + // is_new represent first sight + for (size_t idx = 0; idx < qhelper.size(); ++idx) { + const RType w = qhelper[idx].entry.weight; + bool is_new = qhelper[idx].is_new; + ObliviousAssign(is_new, + Entry(wsum, wsum + w, w, qhelper[idx].entry.value), + Entry(-1, -1, 0, std::numeric_limits::max()), + &out->data[out->size++]); + wsum += w; + } + if (ObliviousDebugCheckEnabled()) { + std::vector oblivious_results(out->data, out->data + out->size); + this->MakeSummaryRaw(out); + CheckEqualSummary(*out, WQSummary(oblivious_results.data(), + oblivious_results.size())); + } + } + }; + /*! \brief data field */ + Entry *data; + /*! \brief number of elements in the summary */ + size_t size; + // constructor + WQSummary(Entry *data, size_t size) : data(data), size(size) {} + /*! + * \return the maximum error of the Summary + */ + inline RType MaxError() const { + RType res = data[0].rmax - data[0].rmin - data[0].wmin; + for (size_t i = 1; i < size; ++i) { + res = std::max(data[i].RMaxPrev() - data[i - 1].RMinNext(), res); + res = std::max(data[i].rmax - data[i].rmin - data[i].wmin, res); + } + return res; + } + /*! + * \brief query qvalue, start from istart + * \param qvalue the value we query for + * \param istart starting position + */ + inline Entry Query(DType qvalue, size_t &istart) const { // NOLINT(*) + while (istart < size && qvalue > data[istart].value) { + ++istart; + } + if (istart == size) { + RType rmax = data[size - 1].rmax; + return Entry(rmax, rmax, 0.0f, qvalue); + } + if (qvalue == data[istart].value) { + return data[istart]; + } else { + if (istart == 0) { + return Entry(0.0f, 0.0f, 0.0f, qvalue); + } else { + return Entry(data[istart - 1].RMinNext(), data[istart].RMaxPrev(), 0.0f, + qvalue); + } + } + } + /*! \return maximum rank in the summary */ + inline RType MaxRank() const { return data[size - 1].rmax; } + /*! 
+ * \brief copy content from src + * \param src source sketch + */ + inline void CopyFrom(const WQSummary &src) { + size = src.size; + std::memcpy(data, src.data, sizeof(Entry) * size); + } + inline void CopyFromSize(const WQSummary &src, const size_t insize) { + size = insize; + std::memcpy(data, src.data, sizeof(Entry) * size); + } + inline void MakeFromSorted(const Entry *entries, size_t n) { + size = 0; + for (size_t i = 0; i < n;) { + size_t j = i + 1; + // ignore repeated values + for (; j < n && entries[j].value == entries[i].value; ++j) { + } + data[size++] = Entry(entries[i].rmin, entries[i].rmax, entries[i].wmin, + entries[i].value); + i = j; + } + } + /*! + * \brief debug function, validate whether the summary + * run consistency check to check if it is a valid summary + * \param eps the tolerate error level, used when RType is floating point and + * some inconsistency could occur due to rounding error + */ + inline void CheckValid(RType eps) const { +// for (size_t i = 0; i < size; ++i) { +// data[i].CheckValid(eps); +// if (i != 0) { +// CHECK(data[i].rmin >= data[i - 1].rmin + data[i - 1].wmin) +// << "rmin range constraint"; +// CHECK(data[i].rmax >= data[i - 1].rmax + data[i].wmin) +// << "rmax range constraint"; +// } +// } + } + + /*! + * \brief set current summary to be obliviously pruned summary of src + * assume data field is already allocated to be at least maxsize + * dummy item will rank last of return and will involved in following + * computation + * \param src source summary \param maxsize size we can afford in + * the pruned sketch + */ + inline void ObliviousSetPrune(const WQSummary &src, size_t maxsize) { + if (src.size <= maxsize) { + this->CopyFrom(src); + return; + } + const Entry kDummyEntryWithMaxValue{-1, -1, 0, + std::numeric_limits::max()}; + + // Make sure dx2 items are last one when `d == (rmax + rmin) / 2`. + const RType begin = src.data[0].rmax; + const RType n = maxsize - 1; + // max_index is equal to previous src.size + size_t max_index = 0; + RType range = 0; + // find actually max item + for (size_t idx = 0; idx < src.size; idx++) { + max_index = ObliviousChoose( + src.data[idx].value != std::numeric_limits::max(), idx, + max_index); + range = src.data[max_index].rmin - src.data[0].rmax; + } + max_index += 1; + + // Construct sort vector. + using Item = PruneItem; + std::vector items; + items.reserve(2 * src.size + n); + for (size_t k = 1; k < n; ++k) { + RType dx2 = 2 * ((k * range) / n + begin); + items.push_back(Item{kDummyEntryWithMaxValue, dx2, false}); + } + // ObliviousPrune contains Dummy item,So here we doing this on 2 cases + // CASE i < max_index: handle normal data + // CASE other: handle dummy data + // + for (size_t i = 1; i < src.size; ++i) { + Item obliviousItem = ObliviousChoose( + ObliviousLess(i , max_index - 1), + Item{src.data[i], src.data[i].rmax + src.data[i].rmin, true}, + Item{kDummyEntryWithMaxValue, std::numeric_limits::max(), + true}); + items.push_back(obliviousItem); + } + for (size_t i = 1; i < src.size - 1; ++i) { + Item obliviousItem = ObliviousChoose( + ObliviousLess(i , max_index - 1), + Item{src.data[i], src.data[i].RMinNext() + src.data[i + 1].RMaxPrev(), + true}, + Item{kDummyEntryWithMaxValue, std::numeric_limits::max(), + true}); + items.push_back(obliviousItem); + } + // Bitonic Sort. + //LOG(DEBUG) << __func__ << " BEGIN 1" << std::endl; + ObliviousSort(items.begin(), items.end()); + //LOG(DEBUG) << __func__ << " PASSED 1" << std::endl; + + // Choose entrys. 
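+    // Added note: in the selection pass below, an item that should survive
+    // the prune gets its rank set to the minimum sentinel, so the second
+    // oblivious sort gathers the selected entries at the front; everything
+    // not selected is overwritten with the dummy entry, and select_count is
+    // accumulated with an oblivious select as well.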
+ RType last_selected_entry_value = std::numeric_limits::min(); + size_t select_count = 0; + Entry lastEntry = items[0].entry; + for (size_t i = 1; i < items.size(); ++i) { + // CASE max_index<=maxsize:All unique item will be select + // CASE other : select unique after dx2 index + bool do_select = ObliviousChoose( + ObliviousLess(max_index , maxsize), + !ObliviousEqual(items[i].entry.value , last_selected_entry_value) & + !ObliviousEqual(items[i].entry.value , std::numeric_limits::max()), + !items[i - 1].has_entry & items[i].has_entry & + !ObliviousEqual(items[i].entry.value , last_selected_entry_value)); + ObliviousAssign(do_select, items[i].entry.value, + last_selected_entry_value, &last_selected_entry_value); + ObliviousAssign(do_select, std::numeric_limits::min(), + items[i].rank, &items[i].rank); + + ObliviousAssign(do_select, items[i].entry, kDummyEntryWithMaxValue, + &items[i].entry); + select_count += ObliviousChoose(do_select, 1, 0); + } + + // Bitonic Sort. + //LOG(DEBUG) << __func__ << " BEGIN 2" << std::endl; + ObliviousSort(items.begin(), items.end()); + //LOG(DEBUG) << __func__ << " PASSED 2" << std::endl; + + // Assign actual last entry to lastEntry + for (size_t idx = 0; idx < src.size; ++idx) { + ObliviousAssign(ObliviousEqual(idx , max_index - 1), src.data[idx], lastEntry, + &lastEntry); + } + // Append actual last item to items vector + for (size_t i = 0; i < src.size; i++) { + ObliviousAssign(ObliviousEqual(i , select_count), lastEntry, items[i].entry, + &items[i].entry); + } + + this->data[0] = src.data[0]; + this->size = maxsize; + + std::transform(items.begin(), items.begin() + maxsize - 1, this->data + 1, + [](const Item &item) { return item.entry; }); + + // if (ObliviousDebugCheckEnabled()) { + // std::vector oblivious_results(data, data + size); + // RawSetPrune(src, maxsize); + // CheckEqualSummary( + // *this, WQSummary(oblivious_results.data(), oblivious_results.size())); + // } + } + + /*! + * \brief set current summary to be pruned summary of src + * assume data field is already allocated to be at least maxsize + * \param src source summary + * \param maxsize size we can afford in the pruned sketch + */ + inline void RawSetPrune(const WQSummary &src, size_t maxsize) { + if (src.size <= maxsize) { + this->CopyFrom(src); + return; + } + const RType begin = src.data[0].rmax; + const RType range = src.data[src.size - 1].rmin - src.data[0].rmax; + const size_t n = maxsize - 1; + data[0] = src.data[0]; + this->size = 1; + // lastidx is used to avoid duplicated records + size_t i = 1, lastidx = 0; + for (size_t k = 1; k < n; ++k) { + RType dx2 = 2 * ((k * range) / n + begin); + // find first i such that d < (rmax[i+1] + rmin[i+1]) / 2 + while (i < src.size - 1 && + dx2 >= src.data[i + 1].rmax + src.data[i + 1].rmin) + ++i; + //CHECK(i != src.size - 1); + if (dx2 < src.data[i].RMinNext() + src.data[i + 1].RMaxPrev()) { + if (i != lastidx) { + data[size++] = src.data[i]; + lastidx = i; + } + } else { + if (i + 1 != lastidx) { + data[size++] = src.data[i + 1]; + lastidx = i + 1; + } + } + } + if (lastidx != src.size - 1) { + data[size++] = src.data[src.size - 1]; + } + } + /*! 
+ * \brief set current summary to be pruned summary of src + * assume data field is already allocated to be at least maxsize + * \param src source summary + * \param maxsize size we can afford in the pruned sketch + */ + + inline void SetPrune(const WQSummary &src, size_t maxsize) { + if (ObliviousSetPruneEnabled()) + return ObliviousSetPrune(src, maxsize); + else + return RawSetPrune(src, maxsize); + } + + /*! + * \brief set current summary to be merged summary of sa and sb + * \param sa first input summary to be merged + * \param sb second input summary to be merged + */ + inline void ObliviousSetCombine(const WQSummary &sa, const WQSummary &sb) { + this->size = sa.size + sb.size; + if (this->size == 0) { + return; + } + + // TODO: need confirm. + if (sa.size == 0) { + this->CopyFrom(sb); + return; + } + if (sb.size == 0) { + this->CopyFrom(sa); + return; + } + using EntryWithPartyInfo = EntryWithPartyInfo; + const Entry kDummyEntryWithMaxValue{-1, -1, 0, + std::numeric_limits::max()}; + std::vector merged_party_entrys(this->size); + // Fill party info and build bitonic sequence. + // + std::transform(sa.data, sa.data + sa.size, merged_party_entrys.begin(), + [](const Entry &entry) { + bool is_dummy = ObliviousChoose( + ObliviousEqual(entry.value , std::numeric_limits::max()), true, + false); + return EntryWithPartyInfo{entry, true, is_dummy}; + }); + std::transform(sb.data, sb.data + sb.size, merged_party_entrys.begin() + sa.size, + [](const Entry &entry) { + bool is_dummy = ObliviousChoose( + ObliviousEqual(entry.value , std::numeric_limits::max()), true, false); + return EntryWithPartyInfo{entry, false, is_dummy}; + }); + // Build bitonic sequence. + std::reverse(merged_party_entrys.begin(), + merged_party_entrys.begin() + sa.size); + // Bitonic merge. + // ObliviousSort(merged_party_entrys.begin(), merged_party_entrys.end()); + ObliviousMerge(merged_party_entrys.begin(), merged_party_entrys.end()); + // Forward pass to compute rmin. + // Forward pass don`t need Oblivious + RType a_prev_rmin = 0; + RType b_prev_rmin = 0; + for (size_t idx = 0; idx < merged_party_entrys.size(); ++idx) { + bool equal_next = + (idx == merged_party_entrys.size() - 1) + ? false + : ObliviousEqual(merged_party_entrys[idx].entry.value, + merged_party_entrys[idx + 1].entry.value); + bool equal_prev = + idx == 0 ? false + : ObliviousEqual(merged_party_entrys[idx].entry.value, + merged_party_entrys[idx - 1].entry.value); + + // Save first. + RType next_aprev_rmin = ObliviousChoose( + merged_party_entrys[idx].is_party_a & + !merged_party_entrys[idx].is_dummy, + merged_party_entrys[idx].entry.RMinNext(), a_prev_rmin); + RType next_bprev_rmin = ObliviousChoose( + !merged_party_entrys[idx].is_party_a & + !merged_party_entrys[idx].is_dummy, + merged_party_entrys[idx].entry.RMinNext(), b_prev_rmin); + + // This is a. Need to add previous b->RMinNext(). + RType chosen_prev_rmin = ObliviousChoose( + merged_party_entrys[idx].is_party_a, b_prev_rmin, a_prev_rmin); + + // Update rmin. Skip for equal groups now. + RType rmin_to_add = ObliviousChoose( + equal_next || equal_prev, static_cast(0), chosen_prev_rmin); + merged_party_entrys[idx].entry.rmin += rmin_to_add; + + // Update prev_rmin. + a_prev_rmin = next_aprev_rmin; + b_prev_rmin = next_bprev_rmin; + } + + // Backward pass to compute rmax. 
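+    // Worked example with assumed inputs: party A holds values {1, 3} and
+    // party B holds {2}, each with weight 1, i.e.
+    //   A: (rmin=0, rmax=1, wmin=1, v=1), (rmin=1, rmax=2, wmin=1, v=3)
+    //   B: (rmin=0, rmax=1, wmin=1, v=2)
+    // The forward pass above lifted B's v=2 entry to rmin = 1 by adding A's
+    // RMinNext() at v=1; the backward pass below must lift its rmax to 2 by
+    // adding A's RMaxPrev() at v=3, so the merged summary matches what
+    // RawSetCombine would produce: (0,1,1,v=1), (1,2,1,v=2), (2,3,1,v=3).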
+    // Backward pass algorithm:
+    // 1. Find the real (non-dummy) data[last].rmax for sa and sb and assign
+    //    them to the prev_rmax variables.
+    // 2. Use is_dummy to control the backward computation data flow.
+    RType a_prev_rmax = sa.data[sa.size - 1].rmax;
+    RType b_prev_rmax = sb.data[sb.size - 1].rmax;
+    for (size_t idx = 0; idx < sa.size; idx++) {
+      a_prev_rmax = ObliviousChoose(
+          ObliviousGreater(sa.data[idx].rmax, a_prev_rmax), sa.data[idx].rmax,
+          a_prev_rmax);
+    }
+    for (size_t idx = 0; idx < sb.size; idx++) {
+      b_prev_rmax = ObliviousChoose(
+          ObliviousGreater(sb.data[idx].rmax, b_prev_rmax), sb.data[idx].rmax,
+          b_prev_rmax);
+    }
+    size_t duplicate_count = 0;
+    size_t dummy_count = 0;
+    Entry prevEntry = merged_party_entrys[merged_party_entrys.size() - 1].entry;
+    Entry nextEntry = kDummyEntryWithMaxValue;
+    for (ssize_t idx = merged_party_entrys.size() - 1; idx >= 0; --idx) {
+      bool equal_prev =
+          idx == 0 ? false
+                   : ObliviousEqual(merged_party_entrys[idx].entry.value,
+                                    merged_party_entrys[idx - 1].entry.value);
+      bool equal_next =
+          idx == merged_party_entrys.size() - 1
+              ? false
+              : ObliviousEqual(merged_party_entrys[idx].entry.value,
+                               merged_party_entrys[idx + 1].entry.value);
+      bool dummy_item = merged_party_entrys[idx].is_dummy;
+      prevEntry = idx == 0 ? kDummyEntryWithMaxValue
+                           : merged_party_entrys[idx - 1].entry;
+      nextEntry = idx == merged_party_entrys.size() - 1
+                      ? kDummyEntryWithMaxValue
+                      : merged_party_entrys[idx + 1].entry;
+      duplicate_count += ObliviousChoose(equal_next, 1, 0);
+      dummy_count += ObliviousChoose(merged_party_entrys[idx].is_dummy, 1, 0);
+
+      // Need to save first since the rmax will be overwritten.
+      RType next_aprev_rmax = ObliviousChoose(
+          merged_party_entrys[idx].is_party_a &
+              !merged_party_entrys[idx].is_dummy,
+          merged_party_entrys[idx].entry.RMaxPrev(), a_prev_rmax);
+      RType next_bprev_rmax = ObliviousChoose(
+          !merged_party_entrys[idx].is_party_a &
+              !merged_party_entrys[idx].is_dummy,
+          merged_party_entrys[idx].entry.RMaxPrev(), b_prev_rmax);
+      // Add peer RMaxPrev.
+      RType rmax_to_add = ObliviousChoose(merged_party_entrys[idx].is_party_a,
+                                          b_prev_rmax, a_prev_rmax);
+      // Handle equal values.
+      // Handle dummies.
+      RType rmin_to_add = ObliviousChoose(
+          equal_prev & !dummy_item, prevEntry.rmin, static_cast<RType>(0));
+      RType wmin_to_add = ObliviousChoose(
+          equal_prev & !dummy_item, prevEntry.wmin, static_cast<RType>(0));
+      rmax_to_add = ObliviousChoose(equal_prev & !dummy_item, prevEntry.rmax,
+                                    rmax_to_add);
+      // Update.
+      merged_party_entrys[idx].entry.rmax += rmax_to_add;
+      merged_party_entrys[idx].entry.rmin += rmin_to_add;
+      merged_party_entrys[idx].entry.wmin += wmin_to_add;
+
+      // Copy rmin, rmax, wmin from previous if values are equal.
+      // Value is ok to be infinite now since this is a two-party merge; at
+      // most two items share a given value.
+      ObliviousAssign(equal_next & !dummy_item, nextEntry,
+                      merged_party_entrys[idx].entry,
+                      &merged_party_entrys[idx].entry);
+      ObliviousAssign(equal_next & !dummy_item,
+                      std::numeric_limits<DType>::max(),
+                      merged_party_entrys[idx].entry.value,
+                      &merged_party_entrys[idx].entry.value);
+
+      a_prev_rmax = next_aprev_rmax;
+      b_prev_rmax = next_bprev_rmax;
+    }
+    // Bitonic sort to push duplicates to the end of the list.
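+    // At this point entries detected as duplicates carry
+    // value == std::numeric_limits<DType>::max(), just like dummy items, so
+    // the sort below pushes them behind every real entry while the summary
+    // keeps its fixed, data-independent length of sa.size + sb.size;
+    // duplicate_count and dummy_count are bookkeeping only and are not used
+    // to shrink the array here.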
+ std::transform(merged_party_entrys.begin(), merged_party_entrys.end(), + this->data, [](const EntryWithPartyInfo &party_entry) { + return party_entry.entry; + }); + //LOG(DEBUG) << __func__ << " BEGIN 3" << std::endl; + ObliviousSort(this->data, this->data + this->size); + // std::sort(this->data, this->data + this->size); + //LOG(DEBUG) << __func__ << " PASSED 3" << std::endl; + // Need to confirm shrink. + if (ObliviousDebugCheckEnabled()) { + std::vector oblivious_results(this->data, this->data + this->size); + RawSetCombine(sa, sb); + CheckEqualSummary( + *this, WQSummary(oblivious_results.data(), oblivious_results.size())); + } + } + + /*! + * \brief set current summary to be merged summary of sa and sb + * \param sa first input summary to be merged + * \param sb second input summary to be merged + */ + inline void RawSetCombine(const WQSummary &sa, const WQSummary &sb) { + if (sa.size == 0) { + this->CopyFrom(sb); + return; + } + if (sb.size == 0) { + this->CopyFrom(sa); + return; + } + //CHECK(sa.size > 0 && sb.size > 0); + const Entry *a = sa.data, *a_end = sa.data + sa.size; + const Entry *b = sb.data, *b_end = sb.data + sb.size; + // extended rmin value + RType aprev_rmin = 0, bprev_rmin = 0; + Entry *dst = this->data; + while (a != a_end && b != b_end) { + // duplicated value entry + if (a->value == b->value) { + *dst = Entry(a->rmin + b->rmin, a->rmax + b->rmax, a->wmin + b->wmin, + a->value); + aprev_rmin = a->RMinNext(); + bprev_rmin = b->RMinNext(); + ++dst; + ++a; + ++b; + } else if (a->value < b->value) { + *dst = Entry(a->rmin + bprev_rmin, a->rmax + b->RMaxPrev(), a->wmin, + a->value); + aprev_rmin = a->RMinNext(); + ++dst; + ++a; + } else { + *dst = Entry(b->rmin + aprev_rmin, b->rmax + a->RMaxPrev(), b->wmin, + b->value); + bprev_rmin = b->RMinNext(); + ++dst; + ++b; + } + } + if (a != a_end) { + RType brmax = (b_end - 1)->rmax; + do { + *dst = Entry(a->rmin + bprev_rmin, a->rmax + brmax, a->wmin, a->value); + ++dst; + ++a; + } while (a != a_end); + } + if (b != b_end) { + RType armax = (a_end - 1)->rmax; + do { + *dst = Entry(b->rmin + aprev_rmin, b->rmax + armax, b->wmin, b->value); + ++dst; + ++b; + } while (b != b_end); + } + this->size = dst - data; + const RType tol = 10; + RType err_mingap, err_maxgap, err_wgap; + this->FixError(&err_mingap, &err_maxgap, &err_wgap); + //if (err_mingap > tol || err_maxgap > tol || err_wgap > tol) { + //LOG(INFO) << "mingap=" << err_mingap << ", maxgap=" << err_maxgap + // << ", wgap=" << err_wgap; + //} + //CHECK(size <= sa.size + sb.size) << "bug in combine"; + } + /*! 
+ * \brief set current summary to be merged summary of sa and sb + * \param sa first input summary to be merged + * \param sb second input summary to be merged + */ + inline void SetCombine(const WQSummary &sa, const WQSummary &sb) { + if (ObliviousSetCombineEnabled()) + return ObliviousSetCombine(sa, sb); + else + return RawSetCombine(sa, sb); + } + // helper function to print the current content of sketch + inline void Print() const { + // for (size_t i = 0; i < this->size; ++i) { + // LOG(INFO) << "[" << i << "] rmin=" << data[i].rmin + // << ", rmax=" << data[i].rmax << ", wmin=" << data[i].wmin + // << ", v=" << data[i].value; + // } + } + + inline void CheckAndPrint() const { + //CheckValid(kRtEps); + //Print(); + } + + // try to fix rounding error + // and re-establish invariance + inline void FixError(RType *err_mingap, RType *err_maxgap, + RType *err_wgap) const { + *err_mingap = 0; + *err_maxgap = 0; + *err_wgap = 0; + RType prev_rmin = 0, prev_rmax = 0; + for (size_t i = 0; i < this->size; ++i) { + if (data[i].rmin < prev_rmin) { + data[i].rmin = prev_rmin; + *err_mingap = std::max(*err_mingap, prev_rmin - data[i].rmin); + } else { + prev_rmin = data[i].rmin; + } + if (data[i].rmax < prev_rmax) { + data[i].rmax = prev_rmax; + *err_maxgap = std::max(*err_maxgap, prev_rmax - data[i].rmax); + } + RType rmin_next = data[i].RMinNext(); + if (data[i].rmax < rmin_next) { + data[i].rmax = rmin_next; + *err_wgap = std::max(*err_wgap, data[i].rmax - rmin_next); + } + prev_rmax = data[i].rmax; + } + } + // check consistency of the summary + inline bool Check(const char *msg) const { + const float tol = 10.0f; + for (size_t i = 0; i < this->size; ++i) { + if (data[i].rmin + data[i].wmin > data[i].rmax + tol || + data[i].rmin < -1e-6f || data[i].rmax < -1e-6f) { +// LOG(INFO) << "---------- WQSummary::Check did not pass ----------"; + this->Print(); + return false; + } + } + return true; + } +}; + +/*! 
\brief try to do efficient pruning */ +template +struct WXQSummary : public WQSummary { + // redefine entry type + using Entry = typename WQSummary::Entry; + // constructor + WXQSummary(Entry *data, size_t size) : WQSummary(data, size) {} + // check if the block is large chunk + inline static bool CheckLarge(const Entry &e, RType chunk) { + return e.RMinNext() > e.RMaxPrev() + chunk; + } + // set prune + inline void SetPrune(const WQSummary &src, size_t maxsize) { + if (ObliviousSetPruneEnabled()) { + return WQSummary::ObliviousSetPrune(src, maxsize); + } + + size_t max_index = 0; + // find actually max item + for (size_t idx = 0; idx < src.size; idx++) { + bool is_valid = !ObliviousEqual(src.data[idx].value, + std::numeric_limits::max()); + max_index = ObliviousChoose(is_valid, idx, max_index); + } + max_index += 1; + + if (max_index <= maxsize) { + this->CopyFromSize(src, max_index); + return; + } + RType begin = src.data[0].rmax; + // n is number of points exclude the min/max points + size_t n = maxsize - 2, nbig = 0; + // these is the range of data exclude the min/max point + RType range = src.data[max_index - 1].rmin - begin; + // RType range = src.data[src.size - 1].rmin - begin; + // prune off zero weights + if (range == 0.0f || maxsize <= 2) { + // special case, contain only two effective data pts + this->data[0] = src.data[0]; + this->data[1] = src.data[max_index - 1]; + this->size = 2; + return; + } else { + range = std::max(range, static_cast(1e-3f)); + } + // Get a big enough chunk size, bigger than range / n + // (multiply by 2 is a safe factor) + const RType chunk = 2 * range / n; + // minimized range + RType mrange = 0; + { + // first scan, grab all the big chunk + // moving block index, exclude the two ends. + size_t bid = 0; + for (size_t i = 1; i < max_index - 1; ++i) { + // detect big chunk data point in the middle + // always save these data points. + if (CheckLarge(src.data[i], chunk)) { + if (bid != i - 1) { + // accumulate the range of the rest points + mrange += src.data[i].RMaxPrev() - src.data[bid].RMinNext(); + } + bid = i; + ++nbig; + } + } + if (bid != max_index - 2) { + mrange += src.data[max_index - 1].RMaxPrev() - src.data[bid].RMinNext(); + } + } + // assert: there cannot be more than n big data points + //if (nbig >= n) { + // // see what was the case + // LOG(INFO) << " check quantile stats, nbig=" << nbig << ", n=" << n; + // LOG(INFO) << " srcsize=" << src.size << ", maxsize=" << maxsize + // << ", range=" << range << ", chunk=" << chunk; + // src.Print(); + // CHECK(nbig < n) << "quantile: too many large chunk"; + //} + this->data[0] = src.data[0]; + this->size = 1; + // The counter on the rest of points, to be selected equally from small + // chunks. 
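+    // Illustrative numbers: with maxsize = 8 the interior budget starts at
+    // n = maxsize - 2 = 6; if the scan above flagged nbig = 2 large chunks
+    // (each kept verbatim), the adjustment below leaves n = 4 points to be
+    // spread evenly over mrange, the range left after excluding those large
+    // chunks.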
+ n = n - nbig; + // find the rest of point + size_t bid = 0, k = 1, lastidx = 0; + for (size_t end = 1; end < max_index; ++end) { + if (end == max_index - 1 || CheckLarge(src.data[end], chunk)) { + if (bid != end - 1) { + size_t i = bid; + RType maxdx2 = src.data[end].RMaxPrev() * 2; + for (; k < n; ++k) { + RType dx2 = 2 * ((k * mrange) / n + begin); + if (dx2 >= maxdx2) break; + while (i < end && + dx2 >= src.data[i + 1].rmax + src.data[i + 1].rmin) + ++i; + if (i == end) break; + if (dx2 < src.data[i].RMinNext() + src.data[i + 1].RMaxPrev()) { + if (i != lastidx) { + this->data[this->size++] = src.data[i]; + lastidx = i; + } + } else { + if (i + 1 != lastidx) { + this->data[this->size++] = src.data[i + 1]; + lastidx = i + 1; + } + } + } + } + if (lastidx != end) { + this->data[this->size++] = src.data[end]; + lastidx = end; + } + bid = end; + // shift base by the gap + begin += src.data[bid].RMinNext() - src.data[bid].RMaxPrev(); + } + } + } +}; +/*! + * \brief traditional GK summary + */ +template +struct GKSummary { + /*! \brief an entry in the sketch summary */ + struct Entry { + /*! \brief minimum rank */ + RType rmin; + /*! \brief maximum rank */ + RType rmax; + /*! \brief the value of data */ + DType value; + // constructor + Entry() = default; + // constructor + Entry(RType rmin, RType rmax, DType value) + : rmin(rmin), rmax(rmax), value(value) {} + }; + /*! \brief input data queue before entering the summary */ + struct Queue { + // the input queue + std::vector queue; + // end of the queue + size_t qtail; + // push data to the queue + inline void Push(DType x, RType w) { queue[qtail++] = x; } + inline void MakeSummary(GKSummary *out) { + std::sort(queue.begin(), queue.begin() + qtail); + out->size = qtail; + for (size_t i = 0; i < qtail; ++i) { + out->data[i] = Entry(i + 1, i + 1, queue[i]); + } + } + }; + /*! \brief data field */ + Entry *data; + /*! \brief number of elements in the summary */ + size_t size; + GKSummary(Entry *data, size_t size) : data(data), size(size) {} + /*! \brief the maximum error of the summary */ + inline RType MaxError() const { + RType res = 0; + for (size_t i = 1; i < size; ++i) { + res = std::max(data[i].rmax - data[i - 1].rmin, res); + } + return res; + } + /*! \return maximum rank in the summary */ + inline RType MaxRank() const { return data[size - 1].rmax; } + /*! + * \brief copy content from src + * \param src source sketch + */ + inline void CopyFrom(const GKSummary &src) { + size = src.size; + std::memcpy(data, src.data, sizeof(Entry) * size); + } + inline void CopyFromSize(const GKSummary &src, const size_t insize) { + size = insize; + std::memcpy(data, src.data, sizeof(Entry) * size); + } + inline void CheckValid(RType eps) const { + // assume always valid + } + /*! \brief used for debug purpose, print the summary */ + inline void Print() const { + // for (size_t i = 0; i < size; ++i) { + // LOG(INFO) << "x=" << data[i].value << "\t" + // << "[" << data[i].rmin << "," << data[i].rmax << "]"; + // } + } + /*! 
+ * \brief set current summary to be pruned summary of src + * assume data field is already allocated to be at least maxsize + * \param src source summary + * \param maxsize size we can afford in the pruned sketch + */ + inline void SetPrune(const GKSummary &src, size_t maxsize) { + if (src.size <= maxsize) { + this->CopyFrom(src); + return; + } + const RType max_rank = src.MaxRank(); + this->size = maxsize; + data[0] = src.data[0]; + size_t n = maxsize - 1; + RType top = 1; + for (size_t i = 1; i < n; ++i) { + RType k = (i * max_rank) / n; + while (k > src.data[top + 1].rmax) ++top; + // assert src.data[top].rmin <= k + // because k > src.data[top].rmax >= src.data[top].rmin + if ((k - src.data[top].rmin) < (src.data[top + 1].rmax - k)) { + data[i] = src.data[top]; + } else { + data[i] = src.data[top + 1]; + } + } + data[n] = src.data[src.size - 1]; + } + inline void SetCombine(const GKSummary &sa, const GKSummary &sb) { + if (sa.size == 0) { + this->CopyFrom(sb); + return; + } + if (sb.size == 0) { + this->CopyFrom(sa); + return; + } + //CHECK(sa.size > 0 && sb.size > 0) << "invalid input for merge"; + const Entry *a = sa.data, *a_end = sa.data + sa.size; + const Entry *b = sb.data, *b_end = sb.data + sb.size; + this->size = sa.size + sb.size; + RType aprev_rmin = 0, bprev_rmin = 0; + Entry *dst = this->data; + while (a != a_end && b != b_end) { + if (a->value < b->value) { + *dst = Entry(bprev_rmin + a->rmin, a->rmax + b->rmax - 1, a->value); + aprev_rmin = a->rmin; + ++dst; + ++a; + } else { + *dst = Entry(aprev_rmin + b->rmin, b->rmax + a->rmax - 1, b->value); + bprev_rmin = b->rmin; + ++dst; + ++b; + } + } + if (a != a_end) { + RType bprev_rmax = (b_end - 1)->rmax; + do { + *dst = Entry(bprev_rmin + a->rmin, bprev_rmax + a->rmax, a->value); + ++dst; + ++a; + } while (a != a_end); + } + if (b != b_end) { + RType aprev_rmax = (a_end - 1)->rmax; + do { + *dst = Entry(aprev_rmin + b->rmin, aprev_rmax + b->rmax, b->value); + ++dst; + ++b; + } while (b != b_end); + } + //CHECK(dst == data + size) << "bug in combine"; + } +}; + +/*! + * \brief template for all quantile sketch algorithm + * that uses merge/prune scheme + * \tparam DType type of data content + * \tparam RType type of rank + * \tparam TSummary actual summary data structure it uses + */ +template +class QuantileSketchTemplate { + public: + /*! \brief type of summary type */ + using Summary = TSummary; + /*! \brief the entry type */ + using Entry = typename Summary::Entry; + /*! \brief same as summary, but use STL to backup the space */ + struct SummaryContainer : public Summary { + std::vector space; + SummaryContainer(const SummaryContainer &src) : Summary(nullptr, src.size) { + this->space = src.space; + this->data = dmlc::BeginPtr(this->space); + } + SummaryContainer() : Summary(nullptr, 0) {} + /*! \brief reserve space for summary */ + inline void Reserve(size_t size) { + if (size > space.size()) { + space.resize(size); + this->data = dmlc::BeginPtr(space); + } + } + /*! 
+ * \brief set the space to be merge of all Summary arrays + * \param begin beginning position in the summary array + * \param end ending position in the Summary array + */ + inline void SetMerge(const Summary *begin, const Summary *end) { + //CHECK(begin < end) << "can not set combine to empty instance"; + size_t len = end - begin; + if (len == 1) { + this->Reserve(begin[0].size); + this->CopyFrom(begin[0]); + } else if (len == 2) { + this->Reserve(begin[0].size + begin[1].size); + this->SetMerge(begin[0], begin[1]); + } else { + // recursive merge + SummaryContainer lhs, rhs; + lhs.SetCombine(begin, begin + len / 2); + rhs.SetCombine(begin + len / 2, end); + this->Reserve(lhs.size + rhs.size); + this->SetCombine(lhs, rhs); + } + } + /*! + * \brief do elementwise combination of summary array + * this[i] = combine(this[i], src[i]) for each i + * \param src the source summary + * \param max_nbyte maximum number of byte allowed in here + */ + inline void Reduce(const Summary &src, size_t max_nbyte) { + this->Reserve((max_nbyte - sizeof(this->size)) / sizeof(Entry)); + SummaryContainer temp; + temp.Reserve(this->size + src.size); + temp.SetCombine(*this, src); + this->SetPrune(temp, space.size()); + } + /*! \brief return the number of bytes this data structure cost in + * serialization */ + inline static size_t CalcMemCost(size_t nentry) { + return sizeof(size_t) + sizeof(Entry) * nentry; + } + /*! \brief save the data structure into stream */ + template + inline void Save(TStream &fo) const { // NOLINT(*) + fo.Write(&(this->size), sizeof(this->size)); + if (this->size != 0) { + fo.Write(this->data, this->size * sizeof(Entry)); + } + } + /*! \brief load data structure from input stream */ + template + inline void Load(TStream &fi) { // NOLINT(*) + //CHECK_EQ( + fi.Read(&this->size, sizeof(this->size)); + //, sizeof(this->size)); + this->Reserve(this->size); + if (this->size != 0) { + //CHECK_EQ( + fi.Read(this->data, this->size * sizeof(Entry)); + //, + // this->size * sizeof(Entry)); + } + } + }; + /*! + * \brief initialize the quantile sketch, given the performance specification + * \param maxn maximum number of data points can be feed into sketch + * \param eps accuracy level of summary + */ + inline void Init(size_t maxn, double eps) { + LimitSizeLevel(maxn, eps, &nlevel, &limit_size); + // lazy reserve the space, if there is only one value, no need to allocate + // space + inqueue.queue.resize(1); + inqueue.qtail = 0; + data.clear(); + level.clear(); + } + + inline static void LimitSizeLevel(size_t maxn, double eps, size_t *out_nlevel, + size_t *out_limit_size) { + size_t &nlevel = *out_nlevel; + size_t &limit_size = *out_limit_size; + nlevel = 1; + while (true) { + limit_size = static_cast(ceil(nlevel / eps)) + 1; + size_t n = (1ULL << nlevel); + if (n * limit_size >= maxn) break; + ++nlevel; + } + // check invariant + size_t n = (1ULL << nlevel); + // CHECK(n * limit_size >= maxn) << "invalid init parameter"; + // CHECK(nlevel <= limit_size * eps) << "invalid init parameter"; + } + + /*! + * \brief add an element to a sketch + * \param x The element added to the sketch + * \param w The weight of the element. 
+ */ + inline void Push(DType x, RType w = 1) { + if (w == static_cast(0)) return; + if (inqueue.qtail == inqueue.queue.size()) { + // jump from lazy one value to limit_size * 2 + if (inqueue.queue.size() == 1) { + inqueue.queue.resize(limit_size * 2); + } else { + temp.Reserve(limit_size * 2); + inqueue.MakeSummary(&temp); + // cleanup queue + inqueue.qtail = 0; + this->PushTemp(); + } + } + inqueue.Push(x, w); + } + + inline void PushSummary(const Summary &summary) { + temp.Reserve(limit_size * 2); + temp.SetPrune(summary, limit_size * 2); + PushTemp(); + } + + /*! \brief push up temp */ + inline void PushTemp() { + temp.Reserve(limit_size * 2); + for (size_t l = 1; true; ++l) { + this->InitLevel(l + 1); + // check if level l is empty + if (level[l].size == 0) { + level[l].SetPrune(temp, limit_size); + break; + } else { + // level 0 is actually temp space + level[0].SetPrune(temp, limit_size); + temp.SetCombine(level[0], level[l]); + if (temp.size > limit_size) { + // try next level + level[l].size = 0; + } else { + // if merged record is still smaller, no need to send to next level + level[l].CopyFrom(temp); + break; + } + } + } + } + /*! \brief get the summary after finalize */ + inline void GetSummary(SummaryContainer *out) { + if (level.size() != 0) { + out->Reserve(limit_size * 2); + } else { + out->Reserve(inqueue.queue.size()); + } + inqueue.MakeSummary(out); + if (level.size() != 0) { + level[0].SetPrune(*out, limit_size); + for (size_t l = 1; l < level.size(); ++l) { + if (level[l].size == 0) continue; + if (level[0].size == 0) { + level[0].CopyFrom(level[l]); + } else { + out->SetCombine(level[0], level[l]); + level[0].SetPrune(*out, limit_size); + } + } + // filter out all the dummy item + size_t final_size = 0; + for (size_t idx = 0; idx < level[0].size; idx++) { + bool is_valid = !ObliviousEqual(out->data[idx].value, + std::numeric_limits::max()); + final_size += is_valid; + } + out->CopyFromSize(level[0], final_size); + } else { + if (out->size > limit_size) { + temp.Reserve(limit_size); + temp.SetPrune(*out, limit_size); + // filter out all the dummy item + size_t final_size = 0; + for (size_t idx = 0; idx < out->size; idx++) { + bool is_valid = !ObliviousEqual(out->data[idx].value, + std::numeric_limits::max()); + final_size += is_valid; + } + out->CopyFromSize(temp, final_size); + } + } + } + // used for debug, check if the sketch is valid + inline void CheckValid(RType eps) const { + for (size_t l = 1; l < level.size(); ++l) { + level[l].CheckValid(eps); + } + } + // initialize level space to at least nlevel + inline void InitLevel(size_t nlevel) { + if (level.size() >= nlevel) return; + data.resize(limit_size * nlevel); + level.resize(nlevel, Summary(nullptr, 0)); + for (size_t l = 0; l < level.size(); ++l) { + level[l].data = dmlc::BeginPtr(data) + l * limit_size; + } + } + // input data queue + typename Summary::Queue inqueue; + // number of levels + size_t nlevel; + // size of summary in each level + size_t limit_size; + // the level of each summaries + std::vector level; + // content of the summary + std::vector data; + // temporal summary, used for temp-merge + SummaryContainer temp; +}; + +/*! + * \brief Quantile sketch use WQSummary + * \tparam DType type of data content + * \tparam RType type of rank + */ +template +class WQuantileSketch + : public QuantileSketchTemplate > {}; + +/*! 
+ * \brief Quantile sketch use WXQSummary
+ * \tparam DType type of data content
+ * \tparam RType type of rank
+ */
+template <typename DType, typename RType = float>
+class WXQuantileSketch
+    : public QuantileSketchTemplate<DType, RType, WXQSummary<DType, RType> > {};
+/*!
+ * \brief Quantile sketch use GKSummary
+ * \tparam DType type of data content
+ * \tparam RType type of rank
+ */
+template <typename DType, typename RType = float>
+class GKQuantileSketch
+    : public QuantileSketchTemplate<DType, RType, GKSummary<DType, RType> > {};
+}  // namespace common
+}  // namespace xgboost
+#endif  // XGBOOST_COMMON_QUANTILE_H_
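+// A minimal usage sketch (illustrative only: the include path, parameter
+// values, and the common instantiation DType = RType = float are assumptions);
+// it relies on the Init/Push/GetSummary API declared above:
+//
+//   #include <vector>
+//   #include "xgboost/common/quantile.h"
+//
+//   void BuildSketch(const std::vector<float> &values,
+//                    const std::vector<float> &weights) {
+//     xgboost::common::WQuantileSketch<float, float> sketch;
+//     // With maxn = 1e6 and eps = 0.01, LimitSizeLevel picks the smallest
+//     // nlevel such that (1 << nlevel) * (ceil(nlevel / eps) + 1) >= maxn,
+//     // i.e. nlevel = 10 and limit_size = 1001.
+//     sketch.Init(1000000, 0.01);
+//     for (size_t i = 0; i < values.size(); ++i) {
+//       sketch.Push(values[i], weights[i]);  // prune/combine happen lazily
+//     }
+//     xgboost::common::WQuantileSketch<float, float>::SummaryContainer out;
+//     sketch.GetSummary(&out);  // dummy (value == max) entries are dropped
+//     // out.data[0 .. out.size) now holds (rmin, rmax, wmin, value) entries.
+//   }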