Skip to content

Commit

Permalink
Range check for DCG position discount lookup (#4069)
Browse files Browse the repository at this point in the history
* Add check to prevent out of index lookup in the position discount table. Add debug logging to report number of queries found in the data.

* Change debug logging location so that we can print the data file name as well.

* Revert "Change debug logging location so that we can print the data file name as well."

This reverts commit 3981b34.

* Add data file name to debug logging.

* Move log line to a place where it is output even when query IDs are read from a separate file.

* Also add the out-of-range check to rank metrics.

* Perform check after number of queries is initialized.

* Update
  • Loading branch information
ashok-ponnuswami-msft authored Mar 17, 2021
1 parent e9f50a5 commit 4580393
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 1 deletion.
8 changes: 8 additions & 0 deletions include/LightGBM/metric.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ class DCGCalculator {
static double CalMaxDCGAtK(data_size_t k,
const label_t* label, data_size_t num_data);


/*!
* \brief Check the metadata for NDCG and lambdarank
* \param metadata Metadata
* \param num_queries Number of queries
*/
static void CheckMetadata(const Metadata& metadata, data_size_t num_queries);

/*!
* \brief Check the label range for NDCG and lambdarank
* \param label Pointer of label
Expand Down
4 changes: 4 additions & 0 deletions src/io/metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,10 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
// re-load query weight
LoadQueryWeights();
}
if (num_queries_ > 0) {
Log::Debug("Number of queries in %s: %i. Average number of rows per query: %f.",
data_filename_.c_str(), static_cast<int>(num_queries_), static_cast<double>(num_data_) / num_queries_);
}
}

void Metadata::SetInitScore(const double* init_score, data_size_t len) {
Expand Down
13 changes: 13 additions & 0 deletions src/metric/dcg_calculator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,19 @@ void DCGCalculator::CalDCG(const std::vector<data_size_t>& ks, const label_t* la
}
}

void DCGCalculator::CheckMetadata(const Metadata& metadata, data_size_t num_queries) {
const data_size_t* query_boundaries = metadata.query_boundaries();
if (num_queries > 0 && query_boundaries != nullptr) {
for (data_size_t i = 0; i < num_queries; i++) {
data_size_t num_rows = query_boundaries[i + 1] - query_boundaries[i];
if (num_rows > kMaxPosition) {
Log::Fatal("Number of rows %i exceeds upper limit of %i for a query", static_cast<int>(num_rows), static_cast<int>(kMaxPosition));
}
}
}
}


void DCGCalculator::CheckLabel(const label_t* label, data_size_t num_data) {
for (data_size_t i = 0; i < num_data; ++i) {
label_t delta = std::fabs(label[i] - static_cast<int>(label[i]));
Expand Down
3 changes: 2 additions & 1 deletion src/metric/rank_metric.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ class NDCGMetric:public Metric {
num_data_ = num_data;
// get label
label_ = metadata.label();
num_queries_ = metadata.num_queries();
DCGCalculator::CheckMetadata(metadata, num_queries_);
DCGCalculator::CheckLabel(label_, num_data_);
// get query boundaries
query_boundaries_ = metadata.query_boundaries();
if (query_boundaries_ == nullptr) {
Log::Fatal("The NDCG metric requires query information");
}
num_queries_ = metadata.num_queries();
// get query weights
query_weights_ = metadata.query_weights();
if (query_weights_ == nullptr) {
Expand Down
1 change: 1 addition & 0 deletions src/objective/rank_objective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class LambdarankNDCG : public RankingObjective {

void Init(const Metadata& metadata, data_size_t num_data) override {
RankingObjective::Init(metadata, num_data);
DCGCalculator::CheckMetadata(metadata, num_queries_);
DCGCalculator::CheckLabel(label_, num_data_);
inverse_max_dcgs_.resize(num_queries_);
#pragma omp parallel for schedule(static)
Expand Down

0 comments on commit 4580393

Please sign in to comment.