Skip to content

Commit

Permalink
add qid for dmlc#2748
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliang09 authored and hcho3 committed Jun 29, 2018
1 parent 8bec8d5 commit 9d29c5d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
2 changes: 2 additions & 0 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class MetaInfo {
std::vector<bst_uint> group_ptr_;
/*! \brief weights of each instance, optional */
std::vector<bst_float> weights_;
/*! \brief session-id of each instance, optional */
std::vector<bst_float> qids;
/*!
* \brief initialized margins,
* if specified, xgboost will start from this init margin
Expand Down
24 changes: 24 additions & 0 deletions src/data/simple_csr_source.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ void SimpleCSRSource::CopyFrom(DMatrix* src) {
}

void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
// use sessionID get gourp info
bst_float lastSessionId = -1.0;
bst_uint groupSizeAll = 0;
this->Clear();
while (parser->Next()) {
const dmlc::RowBlock<uint32_t>& batch = parser->Value();
Expand All @@ -35,6 +38,22 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
if (batch.weight != nullptr) {
info.weights_.insert(info.weights_.end(), batch.weight, batch.weight + batch.size);
}
if (batch.qid != nullptr) {
info.qids.insert(info.qids.end(), batch.qid, batch.qid + batch.size);
// get group
for (size_t i = 0; i < batch.size; ++i) {
bst_float curGroupId = batch.qid[i];
if (lastSessionId == -1) {
info.group_ptr.push_back(0);
}
else if (lastSessionId != curGroupId) {
info.group_ptr.push_back(groupSizeAll);
}
lastSessionId = curGroupId;
groupSizeAll++;
}
}

// Remove the assertion on batch.index, which can be null in the case that the data in this
// batch is entirely sparse. Although it's true that this indicates a likely issue with the
// user's data workflows, passing XGBoost entirely sparse data should not cause it to fail.
Expand All @@ -56,6 +75,11 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
page_.offset.push_back(page_.offset[top - 1] + batch.offset[i + 1] - batch.offset[0]);
}
}
if (lastSessionId != -1) {
if (groupSizeAll > info.group_ptr.back()) {
info.group_ptr.push_back(groupSizeAll);
}
}
this->info.num_nonzero_ = static_cast<uint64_t>(page_.data.size());
}

Expand Down

0 comments on commit 9d29c5d

Please sign in to comment.