Skip to content

Commit

Permalink
Implement extend method for meta info.
Browse files Browse the repository at this point in the history
* Implement extend for host device vector.
  • Loading branch information
trivialfis committed Jun 19, 2020
1 parent a6d9a06 commit 939377f
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 1 deletion.
10 changes: 10 additions & 0 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,16 @@ class MetaInfo {
*/
void SetInfo(const char* key, std::string const& interface_str);

/*!
 * \brief Extend with other MetaInfo.
 *
 * Appends the labels, weights, label lower/upper bounds and base margin of
 * \p that to the corresponding vectors of this object, and concatenates the
 * group pointers (the incoming boundaries are shifted by the last boundary
 * already stored here).  If this object already has a non-zero column count it
 * must match the column count of \p that, otherwise a check fails.
 *
 * \param that The other MetaInfo object.
 *
 * \param accumulate_rows Whether rows need to be accumulated in this function. If
 * client code knows number of rows in advance, set this parameter to false.
 */
void Extend(MetaInfo const& that, bool accumulate_rows);

private:
/*! \brief argsort of labels */
mutable std::vector<size_t> label_order_cache_;
Expand Down
5 changes: 5 additions & 0 deletions include/xgboost/host_device_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@

#include <initializer_list>
#include <vector>
#include <type_traits>

#include "span.h"

Expand Down Expand Up @@ -83,6 +84,8 @@ enum GPUAccess {

template <typename T>
class HostDeviceVector {
// NOTE(review): the trait checked here is std::is_standard_layout, which is a
// weaker requirement than POD (a POD type must also be trivial), so the
// diagnostic text is stricter than the condition that is actually enforced.
static_assert(std::is_standard_layout<T>::value, "HostDeviceVector admits only POD types");

public:
// Construct from an element count `size` and a value `v`.
// NOTE(review): presumably every element is initialised to `v` and
// `device = -1` means "no device assigned" -- confirm against the implementation.
explicit HostDeviceVector(size_t size = 0, T v = T(), int device = -1);
// Construct from a braced list of elements; `device` has the same meaning as above.
HostDeviceVector(std::initializer_list<T> init, int device = -1);
Expand Down Expand Up @@ -117,6 +120,8 @@ class HostDeviceVector {
// Copy the elements of `other` into this vector.
// NOTE(review): presumably mirrors the initializer-list overload below -- confirm.
void Copy(const std::vector<T>& other);
// Copy the listed elements into this vector's host storage, starting at its
// first element.
void Copy(std::initializer_list<T> other);

// Grow this vector by other.Size() elements and copy the contents of `other`
// into the newly added tail.  `other` itself is taken by const reference.
void Extend(const HostDeviceVector<T>& other);

// Mutable access to the data as a std::vector<T>.
std::vector<T>& HostVector();
// Read-only access to the data as a std::vector<T>.
const std::vector<T>& ConstHostVector() const;
// Const overload of HostVector(); simply forwards to ConstHostVector().
const std::vector<T>& HostVector() const {return ConstHostVector(); }
Expand Down
8 changes: 8 additions & 0 deletions src/common/host_device_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@ void HostDeviceVector<T>::Copy(std::initializer_list<T> other) {
std::copy(other.begin(), other.end(), HostVector().begin());
}

/*!
 * \brief Append every element of \p other to the end of this vector
 *        (implementation in host_device_vector.cc).
 *
 * The vector is grown by other.Size() elements and the contents of \p other are
 * copied into the new tail.
 *
 * \param other Vector whose elements are appended.  It may be the same object
 *              as `*this`, in which case the original contents are duplicated.
 */
template <typename T>
void HostDeviceVector<T>::Extend(HostDeviceVector const& other) {
  auto const ori_size = this->Size();
  // Read the source length *before* growing: if `other` aliases `*this`, its
  // size changes with the resize below and copying the full (already doubled)
  // range would write past the end of the buffer.
  auto const n_appended = other.Size();

  auto& h_vec = this->HostVector();
  h_vec.resize(ori_size + n_appended);

  // Fetch the source only after the resize so that, in the aliasing case, the
  // iterators refer to the storage as it is after any reallocation.
  auto const& src = other.ConstHostVector();
  std::copy(src.cbegin(), src.cbegin() + n_appended, h_vec.begin() + ori_size);
}

template <typename T>
bool HostDeviceVector<T>::HostCanRead() const {
return true;
Expand Down
24 changes: 24 additions & 0 deletions src/common/host_device_vector.cu
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,25 @@ class HostDeviceVectorImpl {
}
}

/*!
 * \brief Append the contents of \p other to the end of this vector.
 *
 * This vector is first resized to hold the extra elements (new slots are
 * value-initialised with T()).  The data is then copied on the host when this
 * vector is host-writable and \p other is host-readable; otherwise it is copied
 * device-to-device with cudaMemcpyAsync, which requires both vectors to report
 * the same device index.
 *
 * \param other Source vector; may alias `this`.
 */
void Extend(HostDeviceVectorImpl* other) {
  auto const ori_size = this->Size();
  // Cache the source length before resizing.  When `other == this` the size
  // reported by `other` doubles after Resize(), which previously made the host
  // path fail its CHECK and the device path copy past the end of the buffer.
  auto const other_size = other->Size();
  this->Resize(ori_size + other_size, T());
  if (HostCanWrite() && other->HostCanRead()) {
    auto& h_vec = this->HostVector();
    auto& other_vec = other->HostVector();
    CHECK_EQ(h_vec.size(), ori_size + other_size);
    // Bound the copy by the cached length rather than by cend().
    std::copy(other_vec.cbegin(), other_vec.cbegin() + other_size,
              h_vec.begin() + ori_size);
  } else {
    auto ptr = other->ConstDevicePointer();
    SetDevice();
    CHECK_EQ(this->DeviceIdx(), other->DeviceIdx());
    // No stream argument is passed to the asynchronous copy.
    dh::safe_cuda(cudaMemcpyAsync(this->DevicePointer() + ori_size,
                                  ptr,
                                  other_size * sizeof(T),
                                  cudaMemcpyDeviceToDevice));
  }
}

std::vector<T>& HostVector() {
LazySyncHost(GPUAccess::kNone);
return data_h_;
Expand Down Expand Up @@ -326,6 +345,11 @@ void HostDeviceVector<T>::Copy(std::initializer_list<T> other) {
impl_->Copy(other);
}

// Public entry point for appending `other`; the work is done by the
// implementation object, which receives the implementation of `other`.
template <typename T>
void HostDeviceVector<T>::Extend(HostDeviceVector const& other) {
  auto const rhs_impl = other.impl_;
  this->impl_->Extend(rhs_impl);
}

// Mutable host accessor; forwards to the implementation object.
template <typename T>
std::vector<T>& HostDeviceVector<T>::HostVector() {
  return this->impl_->HostVector();
}

Expand Down
38 changes: 38 additions & 0 deletions src/data/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,44 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
}
}

// Append the meta information of `that` to this object.
//
// - num_row_ is increased by that.num_row_ only when `accumulate_rows` is set.
// - If this object already has a non-zero num_col_, it must equal that.num_col_;
//   num_col_ is then taken from `that`.
// - labels, weights, label bounds and base margin are each moved to the device
//   index reported by the matching vector of `that` and then extended with it.
// - Group pointers are concatenated, with the incoming boundaries shifted by
//   the last boundary already stored here.
void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows) {
  if (accumulate_rows) {
    num_row_ += that.num_row_;
  }
  if (num_col_ != 0) {
    CHECK_EQ(this->num_col_, that.num_col_)
        << "Number of columns must be consistent across batches.";
  }
  num_col_ = that.num_col_;

  labels_.SetDevice(that.labels_.DeviceIdx());
  labels_.Extend(that.labels_);

  weights_.SetDevice(that.weights_.DeviceIdx());
  weights_.Extend(that.weights_);

  labels_lower_bound_.SetDevice(that.labels_lower_bound_.DeviceIdx());
  labels_lower_bound_.Extend(that.labels_lower_bound_);

  labels_upper_bound_.SetDevice(that.labels_upper_bound_.DeviceIdx());
  labels_upper_bound_.Extend(that.labels_upper_bound_);

  base_margin_.SetDevice(that.base_margin_.DeviceIdx());
  base_margin_.Extend(that.base_margin_);

  // Nothing stored yet: adopt the incoming group pointers verbatim.
  if (group_ptr_.empty()) {
    group_ptr_ = that.group_ptr_;
    return;
  }

  CHECK_NE(that.group_ptr_.size(), 0);
  // Work on a copy so the incoming boundaries can be shifted in place.
  auto shifted = that.group_ptr_;
  auto const offset = group_ptr_.back();
  for (auto it = shifted.begin() + 1; it != shifted.end(); ++it) {
    *it += offset;
  }
  // The first incoming entry is skipped when appending.
  group_ptr_.insert(group_ptr_.end(), shifted.begin() + 1, shifted.end());
}

void MetaInfo::Validate(int32_t device) const {
if (group_ptr_.size() != 0 && weights_.Size() != 0) {
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
Expand Down
32 changes: 32 additions & 0 deletions tests/cpp/data/test_metainfo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,35 @@ TEST(MetaInfo, Validate) {
EXPECT_THROW(info.Validate(1), dmlc::Error);
#endif // defined(XGBOOST_USE_CUDA)
}

// Extending one host-resident MetaInfo with another must accumulate the row
// count, append the labels, keep both label vectors off the device, and
// concatenate the group pointers into one monotonically increasing sequence.
TEST(MetaInfo, HostExtend) {
  xgboost::MetaInfo lhs, rhs;
  size_t const kRows = 100;
  lhs.labels_.Resize(kRows);
  lhs.num_row_ = kRows;
  rhs.labels_.Resize(kRows);
  rhs.num_row_ = kRows;
  ASSERT_TRUE(lhs.labels_.HostCanRead());
  ASSERT_TRUE(rhs.labels_.HostCanRead());

  size_t const per_group = 10;
  std::vector<xgboost::bst_group_t> groups;
  for (size_t g = 0; g < kRows / per_group; ++g) {
    groups.emplace_back(per_group);
  }
  lhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());
  rhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());

  lhs.Extend(rhs, true);
  ASSERT_EQ(lhs.num_row_, kRows * 2);
  ASSERT_TRUE(lhs.labels_.HostCanRead());
  ASSERT_TRUE(rhs.labels_.HostCanRead());
  ASSERT_FALSE(lhs.labels_.DeviceCanRead());
  ASSERT_FALSE(rhs.labels_.DeviceCanRead());
  // The labels themselves must have been appended, not just the row counter.
  ASSERT_EQ(lhs.labels_.Size(), kRows * 2);

  // One boundary per group plus the leading zero.
  size_t const n_groups = kRows * 2 / per_group;
  ASSERT_EQ(lhs.group_ptr_.size(), n_groups + 1);
  ASSERT_EQ(lhs.group_ptr_.front(), 0);
  ASSERT_EQ(lhs.group_ptr_.back(), kRows * 2);
  // Inclusive upper bound so the final boundary is verified as well.
  for (size_t i = 0; i <= n_groups; ++i) {
    ASSERT_EQ(lhs.group_ptr_.at(i), per_group * i);
  }
}
22 changes: 21 additions & 1 deletion tests/cpp/data/test_metainfo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ TEST(MetaInfo, FromInterface) {

TEST(MetaInfo, Group) {
cudaSetDevice(0);

MetaInfo info;

thrust::device_vector<uint32_t> d_uint;
Expand Down Expand Up @@ -105,4 +104,25 @@ TEST(MetaInfo, Group) {
info = MetaInfo();
EXPECT_ANY_THROW(info.SetInfo("group", float_str.c_str()));
}

// Extending a MetaInfo whose labels were set through the array-interface string
// built from device data must accumulate rows and append the labels without
// making either label vector host-readable.
TEST(MetaInfo, DeviceExtend) {
  dh::safe_cuda(cudaSetDevice(0));
  size_t const kRows = 100;
  MetaInfo lhs, rhs;

  thrust::device_vector<float> d_data;
  std::string str = PrepareData<float>("<f4", &d_data, kRows);
  lhs.SetInfo("label", str.c_str());
  rhs.SetInfo("label", str.c_str());
  ASSERT_FALSE(rhs.labels_.HostCanRead());
  lhs.num_row_ = kRows;
  rhs.num_row_ = kRows;

  lhs.Extend(rhs, true);
  ASSERT_EQ(lhs.num_row_, kRows * 2);

  // Neither side may have been pulled back to the host by the extension.
  ASSERT_FALSE(lhs.labels_.HostCanRead());
  ASSERT_FALSE(rhs.labels_.HostCanRead());

  // The label vector itself must have grown by the size of rhs.
  ASSERT_EQ(lhs.labels_.Size(), kRows * 2);
}
} // namespace xgboost

0 comments on commit 939377f

Please sign in to comment.