Skip to content

Commit

Permalink
Prevent copying SimpleDMatrix. (#5453)
Browse files Browse the repository at this point in the history
* Explicitly default the destructor of SimpleDMatrix so that only the implicit copy
operations remain available; these are deleted because of the unique_ptr member.

* Remove commented code.
* Remove warning for calling host function (std::max).
* Remove warning for initialization order.
* Remove warning for unused variables.
  • Loading branch information
trivialfis authored Apr 1, 2020
1 parent e86030c commit 29c6ad9
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 50 deletions.
5 changes: 5 additions & 0 deletions src/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ inline std::vector<std::string> Split(const std::string& s, char delim) {
return ret;
}

/*!
 * \brief Pick the larger of two values using only operator<.
 *
 * When neither argument compares less than the other, `a` is returned.
 * Declared with XGBOOST_DEVICE so it can stand in for std::max at call sites
 * carrying the same qualifier.
 *
 * \param a First candidate.
 * \param b Second candidate.
 * \return `b` if `a < b`, otherwise `a`.
 */
template <typename T>
XGBOOST_DEVICE T Max(T a, T b) {
const bool second_is_larger = a < b;
return second_is_larger ? b : a;
}

// simple routine to convert any data to string
template<typename T>
inline std::string ToString(const T& data) {
Expand Down
6 changes: 4 additions & 2 deletions src/common/compressed_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include <cstddef>
#include <algorithm>

#include "common.h"

#ifdef __CUDACC__
#include "device_helpers.cuh"
#endif // __CUDACC__
Expand All @@ -29,12 +31,12 @@ inline void ClearBit(CompressedByteT *byte, int bit_idx) {
*byte &= ~(1 << bit_idx);
}
static const int kPadding = 4; // Assign padding so we can read slightly off
// the beginning of the array
// the beginning of the array

// The number of bits required to represent a given unsigned range.
// Computed as ceil(log2(num_symbols)) and clamped so the result is at least 1.
// NOTE(review): this hunk shows both sides of the diff — the std::max return
// appears to be the removed line and the common::Max return its replacement;
// presumably only the latter exists in the resulting file. Confirm upstream.
inline XGBOOST_DEVICE size_t SymbolBits(size_t num_symbols) {
auto bits = std::ceil(log2(static_cast<double>(num_symbols)));
return std::max(static_cast<size_t>(bits), size_t(1));
return common::Max(static_cast<size_t>(bits), size_t(1));
}
} // namespace detail

Expand Down
4 changes: 2 additions & 2 deletions src/common/hist_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,9 @@ struct Index {
std::vector<uint8_t> data_;
std::vector<uint32_t> offset_; // size of this field is equal to number of features
void* data_ptr_;
uint32_t* offset_ptr_;
size_t p_;
BinTypeSize binTypeSize_;
size_t p_;
uint32_t* offset_ptr_;
Func func_;
};

Expand Down
2 changes: 1 addition & 1 deletion src/common/quantile.h
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ class QuantileSketchTemplate {
// check invariant
size_t n = (1ULL << nlevel);
CHECK(n * limit_size >= maxn) << "invalid init parameter";
CHECK(nlevel <= std::max(1, static_cast<int>(limit_size * eps)))
CHECK(nlevel <= std::max(static_cast<size_t>(1), static_cast<size_t>(limit_size * eps)))
<< "invalid init parameter";
}

Expand Down
19 changes: 0 additions & 19 deletions src/data/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,25 +427,6 @@ DMatrix* DMatrix::Load(const std::string& uri,
return dmat;
}


/*
DMatrix* DMatrix::Create(std::unique_ptr<DataSource<SparsePage>>&& source,
const std::string& cache_prefix) {
if (cache_prefix.length() == 0) {
// Data split mode is fixed to be row right now.
rabit::Allreduce<rabit::op::Max>(&source->info.num_col_, 1);
return new data::SimpleDMatrix(std::move(source));
} else {
#if DMLC_ENABLE_STD_THREAD
return new data::SparsePageDMatrix(std::move(source), cache_prefix);
#else
LOG(FATAL) << "External memory is not enabled in mingw";
return nullptr;
#endif // DMLC_ENABLE_STD_THREAD
}
}
*/

template <typename AdapterT>
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size ) {
Expand Down
1 change: 1 addition & 0 deletions src/data/simple_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class SimpleDMatrix : public DMatrix {
explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread);

explicit SimpleDMatrix(dmlc::Stream* in_stream);
~SimpleDMatrix() override = default;

void SaveToLocalFile(const std::string& fname);

Expand Down
1 change: 0 additions & 1 deletion src/data/sparse_page_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class SparsePageDMatrix : public DMatrix {
row_source_.reset(new data::SparsePageSource(adapter, missing, nthread,
cache_prefix, page_size));
}
// Set number of threads but keep old value so we can reset it after
~SparsePageDMatrix() override = default;

MetaInfo& Info() override;
Expand Down
2 changes: 0 additions & 2 deletions tests/cpp/common/test_hist_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,6 @@ TEST(hist_util, IndexBinData) {
size_t constexpr kRows = 100;
size_t constexpr kCols = 10;

size_t bin_id = 0;
for (auto max_bin : kBinSizes) {
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
common::GHistIndexMatrix hmat;
Expand Down Expand Up @@ -434,7 +433,6 @@ TEST(hist_util, SparseIndexBinData) {
size_t constexpr kRows = 100;
size_t constexpr kCols = 10;

size_t bin_id = 0;
for (auto max_bin : bin_sizes) {
auto p_fmat = RandomDataGenerator(kRows, kCols, 0.2).GenerateDMatix();
common::GHistIndexMatrix hmat;
Expand Down
46 changes: 23 additions & 23 deletions tests/cpp/data/test_simple_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,32 +72,32 @@ TEST(SimpleDMatrix, Empty) {

data::CSRAdapter csr_adapter(row_ptr.data(), feature_idx.data(), data.data(),
0, 0, 0);
data::SimpleDMatrix dmat(&csr_adapter,
std::numeric_limits<float>::quiet_NaN(), 1);
CHECK_EQ(dmat.Info().num_nonzero_, 0);
CHECK_EQ(dmat.Info().num_row_, 0);
CHECK_EQ(dmat.Info().num_col_, 0);
for (auto &batch : dmat.GetBatches<SparsePage>()) {
std::unique_ptr<data::SimpleDMatrix> dmat(new data::SimpleDMatrix(
&csr_adapter, std::numeric_limits<float>::quiet_NaN(), 1));
CHECK_EQ(dmat->Info().num_nonzero_, 0);
CHECK_EQ(dmat->Info().num_row_, 0);
CHECK_EQ(dmat->Info().num_col_, 0);
for (auto &batch : dmat->GetBatches<SparsePage>()) {
CHECK_EQ(batch.Size(), 0);
}

data::DenseAdapter dense_adapter(nullptr, 0, 0);
dmat = data::SimpleDMatrix(&dense_adapter,
std::numeric_limits<float>::quiet_NaN(), 1);
CHECK_EQ(dmat.Info().num_nonzero_, 0);
CHECK_EQ(dmat.Info().num_row_, 0);
CHECK_EQ(dmat.Info().num_col_, 0);
for (auto &batch : dmat.GetBatches<SparsePage>()) {
dmat.reset( new data::SimpleDMatrix(&dense_adapter,
std::numeric_limits<float>::quiet_NaN(), 1) );
CHECK_EQ(dmat->Info().num_nonzero_, 0);
CHECK_EQ(dmat->Info().num_row_, 0);
CHECK_EQ(dmat->Info().num_col_, 0);
for (auto &batch : dmat->GetBatches<SparsePage>()) {
CHECK_EQ(batch.Size(), 0);
}

data::CSCAdapter csc_adapter(nullptr, nullptr, nullptr, 0, 0);
dmat = data::SimpleDMatrix(&csc_adapter,
std::numeric_limits<float>::quiet_NaN(), 1);
CHECK_EQ(dmat.Info().num_nonzero_, 0);
CHECK_EQ(dmat.Info().num_row_, 0);
CHECK_EQ(dmat.Info().num_col_, 0);
for (auto &batch : dmat.GetBatches<SparsePage>()) {
dmat.reset(new data::SimpleDMatrix(
&csc_adapter, std::numeric_limits<float>::quiet_NaN(), 1));
CHECK_EQ(dmat->Info().num_nonzero_, 0);
CHECK_EQ(dmat->Info().num_row_, 0);
CHECK_EQ(dmat->Info().num_col_, 0);
for (auto &batch : dmat->GetBatches<SparsePage>()) {
CHECK_EQ(batch.Size(), 0);
}
}
Expand All @@ -109,11 +109,11 @@ TEST(SimpleDMatrix, MissingData) {

data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2,
3, 2);
data::SimpleDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(),
1);
CHECK_EQ(dmat.Info().num_nonzero_, 2);
dmat = data::SimpleDMatrix(&adapter, 1.0, 1);
CHECK_EQ(dmat.Info().num_nonzero_, 1);
std::unique_ptr<data::SimpleDMatrix> dmat{new data::SimpleDMatrix{
&adapter, std::numeric_limits<float>::quiet_NaN(), 1}};
CHECK_EQ(dmat->Info().num_nonzero_, 2);
dmat.reset(new data::SimpleDMatrix(&adapter, 1.0, 1));
CHECK_EQ(dmat->Info().num_nonzero_, 1);
}

TEST(SimpleDMatrix, EmptyRow) {
Expand Down

0 comments on commit 29c6ad9

Please sign in to comment.