Skip to content

Commit

Permalink
Allow int64 array interface for groups
Browse files Browse the repository at this point in the history
  • Loading branch information
RAMitchell committed Jun 12, 2020
1 parent fed51ac commit db6e51b
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 21 deletions.
43 changes: 29 additions & 14 deletions src/data/data.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

namespace xgboost {

void CopyInfoImpl(ArrayInterface column, HostDeviceVector<float>* out) {
void CopyFloatImpl(ArrayInterface column, HostDeviceVector<float>* out) {
auto SetDeviceToPtr = [](void* ptr) {
cudaPointerAttributes attr;
dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
Expand All @@ -34,6 +34,30 @@ void CopyInfoImpl(ArrayInterface column, HostDeviceVector<float>* out) {
});
}

void CopyGroupImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
CHECK(column.type[1] == 'i' || column.type[1] == 'u')
<< "Expected integer metainfo";
auto SetDeviceToPtr = [](void* ptr) {
cudaPointerAttributes attr;
dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
int32_t ptr_device = attr.device;
dh::safe_cuda(cudaSetDevice(ptr_device));
return ptr_device;
};
auto ptr_device = SetDeviceToPtr(column.data);
dh::TemporaryArray<bst_group_t> temp(column.num_rows);
auto d_tmp = temp.data();

dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) {
d_tmp[idx] = column.GetElement(idx);
});
auto length = column.num_rows;
out->resize(length + 1);
out->at(0) = 0;
thrust::copy(temp.data(), temp.data() + length, out->begin() + 1);
std::partial_sum(out->begin(), out->end(), out->begin());
}

void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()});
auto const& j_arr = get<Array>(j_interface);
Expand All @@ -47,22 +71,13 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
<< "Meta info should be a single column.";

if (key == "label") {
CopyInfoImpl(array_interface, &labels_);
CopyFloatImpl(array_interface, &labels_);
} else if (key == "weight") {
CopyInfoImpl(array_interface, &weights_);
CopyFloatImpl(array_interface, &weights_);
} else if (key == "base_margin") {
CopyInfoImpl(array_interface, &base_margin_);
CopyFloatImpl(array_interface, &base_margin_);
} else if (key == "group") {
// Ranking is not performed on device.
thrust::device_ptr<uint32_t> p_src{
reinterpret_cast<uint32_t*>(array_interface.data)};

auto length = array_interface.num_rows;
group_ptr_.resize(length + 1);
group_ptr_[0] = 0;
thrust::copy(p_src, p_src + length, group_ptr_.begin() + 1);
std::partial_sum(group_ptr_.begin(), group_ptr_.end(), group_ptr_.begin());

CopyGroupImpl(array_interface, &group_ptr_);
return;
} else {
LOG(FATAL) << "Unknown metainfo: " << key;
Expand Down
30 changes: 23 additions & 7 deletions tests/cpp/data/test_metainfo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ std::string PrepareData(std::string typestr, thrust::device_vector<T>* out, cons

std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
column["shape"] = Array(j_shape);
column["strides"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(4)))});
column["strides"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(sizeof(T))))});
column["version"] = Integer(static_cast<Integer::Int>(1));
column["typestr"] = String(typestr);

Expand Down Expand Up @@ -78,16 +78,32 @@ TEST(MetaInfo, FromInterface) {

TEST(MetaInfo, Group) {
cudaSetDevice(0);
thrust::device_vector<uint32_t> d_data;
std::string str = PrepareData<uint32_t>("<u4", &d_data);

MetaInfo info;

info.SetInfo("group", str.c_str());
auto const& h_group = info.group_ptr_;
ASSERT_EQ(h_group.size(), d_data.size() + 1);
thrust::device_vector<uint32_t> d_uint;
std::string uint_str = PrepareData<uint32_t>("<u4", &d_uint);
info.SetInfo("group", uint_str.c_str());
auto& h_group = info.group_ptr_;
ASSERT_EQ(h_group.size(), d_uint.size() + 1);
for (size_t i = 1; i < h_group.size(); ++i) {
ASSERT_EQ(h_group[i], d_data[i-1] + h_group[i-1]) << "i: " << i;
ASSERT_EQ(h_group[i], d_uint[i - 1] + h_group[i - 1]) << "i: " << i;
}

thrust::device_vector<int64_t> d_int64;
std::string int_str = PrepareData<int64_t>("<i8", &d_int64);
info = MetaInfo();
info.SetInfo("group", int_str.c_str());
h_group = info.group_ptr_;
ASSERT_EQ(h_group.size(), d_uint.size() + 1);
for (size_t i = 1; i < h_group.size(); ++i) {
ASSERT_EQ(h_group[i], d_uint[i - 1] + h_group[i - 1]) << "i: " << i;
}

// Incorrect type
thrust::device_vector<float> d_float;
std::string float_str = PrepareData<float>("<f4", &d_float);
info = MetaInfo();
EXPECT_ANY_THROW(info.SetInfo("group", float_str.c_str()));
}
} // namespace xgboost

0 comments on commit db6e51b

Please sign in to comment.