From 3f92c534bd3b3c5cadbc4c4982198a7632073337 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Fri, 27 Sep 2019 20:00:10 +0800 Subject: [PATCH 01/11] check the shape for mat, csr and csc --- src/c_api.cpp | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/c_api.cpp b/src/c_api.cpp index 17cb3ff4bd6f..1a0fb398e048 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -249,17 +249,19 @@ class Booster { boosting_->RollbackOneIter(); } - void PredictSingleRow(int num_iteration, int predict_type, + void PredictSingleRow(int num_iteration, int predict_type, int ncol, std::function>(int row_idx)> get_row_fun, const Config& config, double* out_result, int64_t* out_len) { + if (ncol != boosting_->MaxFeatureIdx() + 1) { + Log::Fatal("The number of feature in data (%d) is not same as in training (%d).", ncol, boosting_->MaxFeatureIdx() + 1); + } std::lock_guard lock(mutex_); if (single_row_predictor_[predict_type].get() == nullptr || !single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, boosting_.get())) { single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, boosting_.get(), config, num_iteration)); } - auto one_row = get_row_fun(0); auto pred_wrt_ptr = out_result; single_row_predictor_[predict_type]->predict_function(one_row, pred_wrt_ptr); @@ -268,10 +270,13 @@ class Booster { } - void Predict(int num_iteration, int predict_type, int nrow, + void Predict(int num_iteration, int predict_type, int nrow, int ncol, std::function>(int row_idx)> get_row_fun, const Config& config, double* out_result, int64_t* out_len) { + if (ncol != boosting_->MaxFeatureIdx() + 1) { + Log::Fatal("The number of feature in data (%d) is not same as in training (%d).", ncol, boosting_->MaxFeatureIdx() + 1); + } std::lock_guard lock(mutex_); bool is_predict_leaf = false; bool is_raw_score = false; @@ -1299,7 +1304,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, int data_type, int64_t nindptr, int64_t nelem, - int64_t, + int64_t num_col, int predict_type, int num_iteration, const char* parameter, @@ -1315,7 +1320,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); int nrow = static_cast(nindptr - 1); - ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, + ref_booster->Predict(num_iteration, predict_type, nrow, static_cast(num_col), get_row_fun, config, out_result, out_len); API_END(); } @@ -1328,7 +1333,7 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, int data_type, int64_t nindptr, int64_t nelem, - int64_t, + int64_t num_col, int predict_type, int num_iteration, const char* parameter, @@ -1343,7 +1348,7 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, } Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); - ref_booster->PredictSingleRow(num_iteration, predict_type, get_row_fun, config, out_result, out_len); + ref_booster->PredictSingleRow(num_iteration, predict_type, static_cast(num_col), get_row_fun, config, out_result, out_len); API_END(); } @@ -1395,7 +1400,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle, } return one_row; }; - ref_booster->Predict(num_iteration, predict_type, static_cast(num_row), get_row_fun, config, + ref_booster->Predict(num_iteration, predict_type, static_cast(num_row), ncol, get_row_fun, config, out_result, out_len); API_END(); } @@ -1420,7 +1425,7 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle, } Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major); - ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, + ref_booster->Predict(num_iteration, predict_type, nrow, ncol, get_row_fun, config, out_result, out_len); API_END(); } @@ -1444,7 +1449,7 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle, } Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major); - ref_booster->PredictSingleRow(num_iteration, predict_type, get_row_fun, config, out_result, out_len); + ref_booster->PredictSingleRow(num_iteration, predict_type, ncol, get_row_fun, config, out_result, out_len); API_END(); } @@ -1468,7 +1473,7 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle, } Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowPairFunctionFromDenseRows(data, ncol, data_type); - ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, config, out_result, out_len); + ref_booster->Predict(num_iteration, predict_type, nrow, ncol, get_row_fun, config, out_result, out_len); API_END(); } From b184edb38fc4063bdde03125911b63972c46a179 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Sat, 28 Sep 2019 10:26:41 +0800 Subject: [PATCH 02/11] guess from csr --- src/c_api.cpp | 34 ++++++++++++++++++++++++++++++---- src/io/parser.cpp | 2 -- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/src/c_api.cpp b/src/c_api.cpp index 1a0fb398e048..8146e9b79181 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -253,7 +253,7 @@ class Booster { std::function>(int row_idx)> get_row_fun, const Config& config, double* out_result, int64_t* out_len) { - if (ncol != boosting_->MaxFeatureIdx() + 1) { + if (ncol > 0 && ncol != boosting_->MaxFeatureIdx() + 1) { Log::Fatal("The number of feature in data (%d) is not same as in training (%d).", ncol, boosting_->MaxFeatureIdx() + 1); } std::lock_guard lock(mutex_); @@ -274,8 +274,8 @@ class Booster { std::function>(int row_idx)> get_row_fun, const Config& config, double* out_result, int64_t* out_len) { - if (ncol != boosting_->MaxFeatureIdx() + 1) { - Log::Fatal("The number of feature in data (%d) is not same as in training (%d).", ncol, boosting_->MaxFeatureIdx() + 1); + if (ncol > 0 && ncol != boosting_->MaxFeatureIdx() + 1) { + Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", ncol, boosting_->MaxFeatureIdx() + 1); } std::lock_guard lock(mutex_); bool is_predict_leaf = false; @@ -431,6 +431,8 @@ RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int d std::function>(int row_idx)> RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type); +int GuessNumColFromCSR(const void* indptr, int indptr_type, const int32_t* indices, int64_t nindptr); + std::function>(int idx)> RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem); @@ -1320,7 +1322,11 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); int nrow = static_cast(nindptr - 1); - ref_booster->Predict(num_iteration, predict_type, nrow, static_cast(num_col), get_row_fun, + int ncol = static_cast(num_col); + if (ncol <= 0) { + ncol = GuessNumColFromCSR(indptr, indptr_type, indices, nindptr); + } + ref_booster->Predict(num_iteration, predict_type, nrow, ncol, get_row_fun, config, out_result, out_len); API_END(); } @@ -1348,6 +1354,7 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, } Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); + // For single row, we cannot guess its num_col. ref_booster->PredictSingleRow(num_iteration, predict_type, static_cast(num_col), get_row_fun, config, out_result, out_len); API_END(); } @@ -1666,6 +1673,25 @@ RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) { }; } +int GuessNumColFromCSR(const void* indptr, int indptr_type, const int32_t* indices, int64_t nindptr) { + const int64_t used_row = std::min(static_cast(20), nindptr - 1); + int ncol = 0; + if (indptr_type == C_API_DTYPE_INT32) { + const int32_t* ptr_indptr = reinterpret_cast(indptr); + for (int64_t i = 0; i < used_row; ++i) { + auto col_idx = indices[ptr_indptr[i + 1]]; + ncol = std::max(ncol, col_idx + 1); + } + } else if (indptr_type == C_API_DTYPE_INT64) { + const int64_t* ptr_indptr = reinterpret_cast(indptr); + for(int64_t i = 0; i < used_row; ++i) { + auto col_idx = indices[ptr_indptr[i + 1]]; + ncol = std::max(ncol, col_idx + 1); + } + } + return ncol; +} + std::function>(int idx)> RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) { if (data_type == C_API_DTYPE_FLOAT32) { diff --git a/src/io/parser.cpp b/src/io/parser.cpp index 45df64f49881..d080ad6513b4 100644 --- a/src/io/parser.cpp +++ b/src/io/parser.cpp @@ -142,10 +142,8 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features type = DataType::LIBSVM; } else if (tab_cnt == tab_cnt2 && tab_cnt > 0) { type = DataType::TSV; - CHECK(tab_cnt == tab_cnt2); } else if (comma_cnt == comma_cnt2 && comma_cnt > 0) { type = DataType::CSV; - CHECK(comma_cnt == comma_cnt2); } } if (type == DataType::INVALID) { From a38f9ed3f5a75f3178d43093b8f3d4acc2127e2b Mon Sep 17 00:00:00 2001 From: guolinke Date: Sat, 28 Sep 2019 11:47:31 +0800 Subject: [PATCH 03/11] support file checking --- include/LightGBM/c_api.h | 4 +- include/LightGBM/dataset.h | 3 +- src/application/predictor.hpp | 4 +- src/c_api.cpp | 48 +++++--------- src/io/dataset.cpp | 4 +- src/io/dataset_loader.cpp | 6 +- src/io/parser.cpp | 121 ++++++++++++++++++++++++---------- src/io/parser.hpp | 17 ++--- 8 files changed, 125 insertions(+), 82 deletions(-) diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index e9f65827791f..33565d5e9a13 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -683,7 +683,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle, * \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64`` * \param nindptr Number of rows in the matrix + 1 * \param nelem Number of nonzero elements in the matrix - * \param num_col Number of columns; when it's set to 0, then guess from data + * \param num_col Number of columns * \param predict_type What should be predicted * - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed); * - ``C_API_PREDICT_RAW_SCORE``: raw score; @@ -726,7 +726,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle, * \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64`` * \param nindptr Number of rows in the matrix + 1 * \param nelem Number of nonzero elements in the matrix - * \param num_col Number of columns; when it's set to 0, then guess from data + * \param num_col Number of columns * \param predict_type What should be predicted * - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed); * - ``C_API_PREDICT_RAW_SCORE``: raw score; diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h index c7147a32fe33..603aa7d57617 100644 --- a/include/LightGBM/dataset.h +++ b/include/LightGBM/dataset.h @@ -265,7 +265,7 @@ class Parser { virtual void ParseOneLine(const char* str, std::vector>* out_features, double* out_label) const = 0; - virtual int TotalColumns() const = 0; + virtual int NumFeatures() const = 0; /*! * \brief Create a object of parser, will auto choose the format depend on file @@ -290,6 +290,7 @@ class Dataset { void Construct( std::vector>* bin_mappers, + int num_total_features, int** sample_non_zero_indices, const int* num_per_col, size_t total_sample_cnt, diff --git a/src/application/predictor.hpp b/src/application/predictor.hpp index 9d9acdadee0e..e188962d7efd 100644 --- a/src/application/predictor.hpp +++ b/src/application/predictor.hpp @@ -140,7 +140,9 @@ class Predictor { if (parser == nullptr) { Log::Fatal("Could not recognize the data format of data file %s", data_filename); } - + if (parser->NumFeatures() != boosting_->MaxFeatureIdx() + 1) { + Log::Fatal("The number of feature in data (%d) is not same as in training (%d).", parser->NumFeatures(), boosting_->MaxFeatureIdx() + 1); + } TextReader predict_data_reader(data_filename, header); std::unordered_map feature_names_map_; bool need_adjust = false; diff --git a/src/c_api.cpp b/src/c_api.cpp index 8146e9b79181..dc648f2fd0d2 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -253,7 +253,7 @@ class Booster { std::function>(int row_idx)> get_row_fun, const Config& config, double* out_result, int64_t* out_len) { - if (ncol > 0 && ncol != boosting_->MaxFeatureIdx() + 1) { + if (ncol != boosting_->MaxFeatureIdx() + 1) { Log::Fatal("The number of feature in data (%d) is not same as in training (%d).", ncol, boosting_->MaxFeatureIdx() + 1); } std::lock_guard lock(mutex_); @@ -274,7 +274,7 @@ class Booster { std::function>(int row_idx)> get_row_fun, const Config& config, double* out_result, int64_t* out_len) { - if (ncol > 0 && ncol != boosting_->MaxFeatureIdx() + 1) { + if (ncol != boosting_->MaxFeatureIdx() + 1) { Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", ncol, boosting_->MaxFeatureIdx() + 1); } std::lock_guard lock(mutex_); @@ -431,8 +431,6 @@ RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int d std::function>(int row_idx)> RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type); -int GuessNumColFromCSR(const void* indptr, int indptr_type, const int32_t* indices, int64_t nindptr); - std::function>(int idx)> RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem); @@ -654,7 +652,7 @@ int LGBM_DatasetCreateFromMats(int32_t nmat, DatasetLoader loader(config, nullptr, 1, nullptr); ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr(&sample_values).data(), Common::Vector2Ptr(&sample_idx).data(), - static_cast(sample_values.size()), + ncol, Common::VectorSize(sample_values).data(), sample_cnt, total_nrow)); } else { @@ -694,6 +692,9 @@ int LGBM_DatasetCreateFromCSR(const void* indptr, const DatasetHandle reference, DatasetHandle* out) { API_BEGIN(); + if (num_col <= 0) { + Log::Fatal("The number of columns should greater than zero."); + } auto param = Config::Str2Map(parameters); Config config; config.Set(param); @@ -725,7 +726,7 @@ int LGBM_DatasetCreateFromCSR(const void* indptr, DatasetLoader loader(config, nullptr, 1, nullptr); ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr(&sample_values).data(), Common::Vector2Ptr(&sample_idx).data(), - static_cast(sample_values.size()), + static_cast(num_col), Common::VectorSize(sample_values).data(), sample_cnt, nrow)); } else { @@ -755,9 +756,10 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr, const DatasetHandle reference, DatasetHandle* out) { API_BEGIN(); - + if (num_col <= 0) { + Log::Fatal("The number of columns should greater than zero."); + } auto get_row_fun = *static_cast>&)>*>(get_row_funptr); - auto param = Config::Str2Map(parameters); Config config; config.Set(param); @@ -790,7 +792,7 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr, DatasetLoader loader(config, nullptr, 1, nullptr); ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr(&sample_values).data(), Common::Vector2Ptr(&sample_idx).data(), - static_cast(sample_values.size()), + static_cast(num_col), Common::VectorSize(sample_values).data(), sample_cnt, nrow)); } else { @@ -1313,6 +1315,9 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, int64_t* out_len, double* out_result) { API_BEGIN(); + if (num_col <= 0) { + Log::Fatal("The number of columns should greater than zero."); + } auto param = Config::Str2Map(parameter); Config config; config.Set(param); @@ -1324,7 +1329,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, int nrow = static_cast(nindptr - 1); int ncol = static_cast(num_col); if (ncol <= 0) { - ncol = GuessNumColFromCSR(indptr, indptr_type, indices, nindptr); + Log::Fatal("The number of columns should greater than zero."); } ref_booster->Predict(num_iteration, predict_type, nrow, ncol, get_row_fun, config, out_result, out_len); @@ -1346,6 +1351,9 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, int64_t* out_len, double* out_result) { API_BEGIN(); + if (num_col <= 0) { + Log::Fatal("The number of columns should greater than zero."); + } auto param = Config::Str2Map(parameter); Config config; config.Set(param); @@ -1354,7 +1362,6 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, } Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); - // For single row, we cannot guess its num_col. ref_booster->PredictSingleRow(num_iteration, predict_type, static_cast(num_col), get_row_fun, config, out_result, out_len); API_END(); } @@ -1673,25 +1680,6 @@ RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) { }; } -int GuessNumColFromCSR(const void* indptr, int indptr_type, const int32_t* indices, int64_t nindptr) { - const int64_t used_row = std::min(static_cast(20), nindptr - 1); - int ncol = 0; - if (indptr_type == C_API_DTYPE_INT32) { - const int32_t* ptr_indptr = reinterpret_cast(indptr); - for (int64_t i = 0; i < used_row; ++i) { - auto col_idx = indices[ptr_indptr[i + 1]]; - ncol = std::max(ncol, col_idx + 1); - } - } else if (indptr_type == C_API_DTYPE_INT64) { - const int64_t* ptr_indptr = reinterpret_cast(indptr); - for(int64_t i = 0; i < used_row; ++i) { - auto col_idx = indices[ptr_indptr[i + 1]]; - ncol = std::max(ncol, col_idx + 1); - } - } - return ncol; -} - std::function>(int idx)> RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) { if (data_type == C_API_DTYPE_FLOAT32) { diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp index 3967fabb98a7..fab8d1c7f0d2 100644 --- a/src/io/dataset.cpp +++ b/src/io/dataset.cpp @@ -214,11 +214,13 @@ std::vector> FastFeatureBundling(const std::vector>* bin_mappers, + int num_total_features, int** sample_non_zero_indices, const int* num_per_col, size_t total_sample_cnt, const Config& io_config) { - num_total_features_ = static_cast(bin_mappers->size()); + num_total_features_ = num_total_features; + CHECK(num_total_features_ == static_cast(bin_mappers->size())); sparse_threshold_ = io_config.sparse_threshold; // get num_features std::vector used_features; diff --git a/src/io/dataset_loader.cpp b/src/io/dataset_loader.cpp index 265adc27ff41..085082f0b15b 100644 --- a/src/io/dataset_loader.cpp +++ b/src/io/dataset_loader.cpp @@ -692,7 +692,7 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values, } } auto dataset = std::unique_ptr(new Dataset(num_data)); - dataset->Construct(&bin_mappers, sample_indices, num_per_col, total_sample_size, config_); + dataset->Construct(&bin_mappers, num_col, sample_indices, num_per_col, total_sample_size, config_); dataset->set_feature_names(feature_names_); return dataset.release(); } @@ -864,7 +864,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, if (feature_names_.empty()) { // -1 means doesn't use this feature - dataset->num_total_features_ = std::max(static_cast(sample_values.size()), parser->TotalColumns() - 1); + dataset->num_total_features_ = std::max(static_cast(sample_values.size()), parser->NumFeatures()); dataset->used_feature_map_ = std::vector(dataset->num_total_features_, -1); } else { dataset->used_feature_map_ = std::vector(feature_names_.size(), -1); @@ -1018,7 +1018,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, } } sample_values.clear(); - dataset->Construct(&bin_mappers, Common::Vector2Ptr(&sample_indices).data(), + dataset->Construct(&bin_mappers, dataset->num_total_features_, Common::Vector2Ptr(&sample_indices).data(), Common::VectorSize(sample_indices).data(), sample_data.size(), config_); } diff --git a/src/io/parser.cpp b/src/io/parser.cpp index d080ad6513b4..6f17c70db779 100644 --- a/src/io/parser.cpp +++ b/src/io/parser.cpp @@ -89,47 +89,52 @@ void GetLine(std::stringstream* ss, std::string* line, const VirtualFileReader* } } -Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx) { +std::vector ReadKLineFromFile(const char* filename, bool header, int k) { auto reader = VirtualFileReader::Make(filename); if (!reader->Init()) { Log::Fatal("Data file %s doesn't exist", filename); } - std::string line1, line2; - size_t buffer_size = 64 * 1024; + std::vector ret; + std::string cur_line; + const size_t buffer_size = 1024 * 1024; auto buffer = std::vector(buffer_size); size_t read_len = reader->Read(buffer.data(), buffer_size); if (read_len <= 0) { Log::Fatal("Data file %s couldn't be read", filename); } - - std::stringstream tmp_file(std::string(buffer.data(), read_len)); + std::string read_str = std::string(buffer.data(), read_len); + std::stringstream tmp_file(read_str); if (header) { if (!tmp_file.eof()) { - GetLine(&tmp_file, &line1, reader.get(), &buffer, buffer_size); + GetLine(&tmp_file, &cur_line, reader.get(), &buffer, buffer_size); } } - if (!tmp_file.eof()) { - GetLine(&tmp_file, &line1, reader.get(), &buffer, buffer_size); - } else { - Log::Fatal("Data file %s should have at least one line", filename); + for (int i = 0; i < k; ++i) { + if (!tmp_file.eof()) { + GetLine(&tmp_file, &cur_line, reader.get(), &buffer, buffer_size); + ret.push_back(cur_line); + } else { + break; + } } - if (!tmp_file.eof()) { - GetLine(&tmp_file, &line2, reader.get(), &buffer, buffer_size); - } else { + if (ret.empty()) { + Log::Fatal("Data file %s should have at least one line", filename); + } else if (ret.size() == 1) { Log::Warning("Data file %s only has one line", filename); } - int comma_cnt = 0, comma_cnt2 = 0; - int tab_cnt = 0, tab_cnt2 = 0; - int colon_cnt = 0, colon_cnt2 = 0; - // Get some statistic from 2 line - GetStatistic(line1.c_str(), &comma_cnt, &tab_cnt, &colon_cnt); - GetStatistic(line2.c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2); - - + return ret; +} +DataType GetDataType(const std::vector& lines, int* num_col) { DataType type = DataType::INVALID; - if (line2.size() == 0) { - // if only have one line on file + if (lines.empty()) { + return type; + } + int comma_cnt = 0; + int tab_cnt = 0; + int colon_cnt = 0; + GetStatistic(lines[0].c_str(), &comma_cnt, &tab_cnt, &colon_cnt); + if (lines.size() == 1) { if (colon_cnt > 0) { type = DataType::LIBSVM; } else if (tab_cnt > 0) { @@ -137,28 +142,72 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features } else if (comma_cnt > 0) { type = DataType::CSV; } - } else { - if (colon_cnt > 0 || colon_cnt2 > 0) { - type = DataType::LIBSVM; - } else if (tab_cnt == tab_cnt2 && tab_cnt > 0) { - type = DataType::TSV; - } else if (comma_cnt == comma_cnt2 && comma_cnt > 0) { - type = DataType::CSV; + } + int comma_cnt2 = 0; + int tab_cnt2 = 0; + int colon_cnt2 = 0; + GetStatistic(lines[1].c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2); + if (colon_cnt > 0 || colon_cnt2 > 0) { + type = DataType::LIBSVM; + } else if (tab_cnt == tab_cnt2 && tab_cnt > 0) { + type = DataType::TSV; + } else if (comma_cnt == comma_cnt2 && comma_cnt > 0) { + type = DataType::CSV; + } + + // valid the type + for (size_t i = 2; i < lines.size(); ++i) { + GetStatistic(lines[i].c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2); + if (type == DataType::TSV) { + if (tab_cnt2 != tab_cnt) { + type = DataType::INVALID; + break; + } + } else if (type == DataType::CSV) { + if (comma_cnt != comma_cnt2) { + type = DataType::INVALID; + break; + } } } + if (type == DataType::LIBSVM) { + int max_col_idx = 0; + for (size_t i = 0; i < lines.size(); ++i) { + auto str = Common::Trim(lines[i]); + auto colon_pos = str.find_last_of(":"); + auto space_pos = str.find_last_of(" \f\t\v"); + auto sub_str = str.substr(space_pos + 1, space_pos - colon_pos - 1); + int cur_idx = 0; + Common::Atoi(sub_str.c_str(), &cur_idx); + max_col_idx = std::max(cur_idx, max_col_idx); + } + *num_col = max_col_idx + 1; + } else if (type == DataType::CSV) { + *num_col = comma_cnt + 1; + } else if (type == DataType::TSV) { + *num_col = tab_cnt + 1; + } + return type; +} + +Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx) { + const int n_read_line = 20; + auto lines = ReadKLineFromFile(filename, header, n_read_line); + int num_col = 0; + DataType type = GetDataType(lines, &num_col); if (type == DataType::INVALID) { Log::Fatal("Unknown format of training data"); } std::unique_ptr ret; if (type == DataType::LIBSVM) { - label_idx = GetLabelIdxForLibsvm(line1, num_features, label_idx); - ret.reset(new LibSVMParser(label_idx)); + label_idx = GetLabelIdxForLibsvm(lines[0], num_features, label_idx); + ret.reset(new LibSVMParser(label_idx, num_col)); } else if (type == DataType::TSV) { - label_idx = GetLabelIdxForTSV(line1, num_features, label_idx); - ret.reset(new TSVParser(label_idx, tab_cnt + 1)); + label_idx = GetLabelIdxForTSV(lines[0], num_features, label_idx); + ret.reset(new TSVParser(label_idx, num_col)); } else if (type == DataType::CSV) { - label_idx = GetLabelIdxForCSV(line1, num_features, label_idx); - ret.reset(new CSVParser(label_idx, comma_cnt + 1)); + label_idx = GetLabelIdxForCSV(lines[0], num_features, label_idx); + ret.reset(new CSVParser(label_idx, num_col)); } if (label_idx < 0) { diff --git a/src/io/parser.hpp b/src/io/parser.hpp index 46698172f120..6bfe94a3f036 100644 --- a/src/io/parser.hpp +++ b/src/io/parser.hpp @@ -43,8 +43,8 @@ class CSVParser: public Parser { } } - inline int TotalColumns() const override { - return total_columns_; + inline int NumFeatures() const override { + return total_columns_ - (label_idx_ >= 0); } private: @@ -79,8 +79,8 @@ class TSVParser: public Parser { } } - inline int TotalColumns() const override { - return total_columns_; + inline int NumFeatures() const override { + return total_columns_ - (label_idx_ >= 0); } private: @@ -90,8 +90,8 @@ class TSVParser: public Parser { class LibSVMParser: public Parser { public: - explicit LibSVMParser(int label_idx) - :label_idx_(label_idx) { + explicit LibSVMParser(int label_idx, int total_columns) + :label_idx_(label_idx), total_columns_(total_columns) { if (label_idx > 0) { Log::Fatal("Label should be the first column in a LibSVM file"); } @@ -119,12 +119,13 @@ class LibSVMParser: public Parser { } } - inline int TotalColumns() const override { - return -1; + inline int NumFeatures() const override { + return total_columns_; } private: int label_idx_ = 0; + int total_columns_ = -1; }; } // namespace LightGBM From 22d340ebaf8641710732459cb7883ca0313d3f50 Mon Sep 17 00:00:00 2001 From: guolinke Date: Sat, 28 Sep 2019 11:50:23 +0800 Subject: [PATCH 04/11] better error msg --- src/application/predictor.hpp | 2 +- src/c_api.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/application/predictor.hpp b/src/application/predictor.hpp index e188962d7efd..704ea0e46d6d 100644 --- a/src/application/predictor.hpp +++ b/src/application/predictor.hpp @@ -141,7 +141,7 @@ class Predictor { Log::Fatal("Could not recognize the data format of data file %s", data_filename); } if (parser->NumFeatures() != boosting_->MaxFeatureIdx() + 1) { - Log::Fatal("The number of feature in data (%d) is not same as in training (%d).", parser->NumFeatures(), boosting_->MaxFeatureIdx() + 1); + Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", parser->NumFeatures(), boosting_->MaxFeatureIdx() + 1); } TextReader predict_data_reader(data_filename, header); std::unordered_map feature_names_map_; diff --git a/src/c_api.cpp b/src/c_api.cpp index dc648f2fd0d2..017bb3673afa 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -254,7 +254,7 @@ class Booster { const Config& config, double* out_result, int64_t* out_len) { if (ncol != boosting_->MaxFeatureIdx() + 1) { - Log::Fatal("The number of feature in data (%d) is not same as in training (%d).", ncol, boosting_->MaxFeatureIdx() + 1); + Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", ncol, boosting_->MaxFeatureIdx() + 1); } std::lock_guard lock(mutex_); if (single_row_predictor_[predict_type].get() == nullptr || From ad22c8ee564f60bb9873cce48b17da88cd48ee97 Mon Sep 17 00:00:00 2001 From: guolinke Date: Sat, 28 Sep 2019 11:51:17 +0800 Subject: [PATCH 05/11] grammar --- src/c_api.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/c_api.cpp b/src/c_api.cpp index 017bb3673afa..6d63c79e0d6f 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -693,7 +693,7 @@ int LGBM_DatasetCreateFromCSR(const void* indptr, DatasetHandle* out) { API_BEGIN(); if (num_col <= 0) { - Log::Fatal("The number of columns should greater than zero."); + Log::Fatal("The number of columns should be greater than zero."); } auto param = Config::Str2Map(parameters); Config config; @@ -757,7 +757,7 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr, DatasetHandle* out) { API_BEGIN(); if (num_col <= 0) { - Log::Fatal("The number of columns should greater than zero."); + Log::Fatal("The number of columns should be greater than zero."); } auto get_row_fun = *static_cast>&)>*>(get_row_funptr); auto param = Config::Str2Map(parameters); @@ -1316,7 +1316,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, double* out_result) { API_BEGIN(); if (num_col <= 0) { - Log::Fatal("The number of columns should greater than zero."); + Log::Fatal("The number of columns should be greater than zero."); } auto param = Config::Str2Map(parameter); Config config; @@ -1329,7 +1329,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, int nrow = static_cast(nindptr - 1); int ncol = static_cast(num_col); if (ncol <= 0) { - Log::Fatal("The number of columns should greater than zero."); + Log::Fatal("The number of columns should be greater than zero."); } ref_booster->Predict(num_iteration, predict_type, nrow, ncol, get_row_fun, config, out_result, out_len); @@ -1352,7 +1352,7 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, double* out_result) { API_BEGIN(); if (num_col <= 0) { - Log::Fatal("The number of columns should greater than zero."); + Log::Fatal("The number of columns should be greater than zero."); } auto param = Config::Str2Map(parameter); Config config; From 39ac72b333162fa8938b795856f3b2666548febe Mon Sep 17 00:00:00 2001 From: guolinke Date: Sat, 28 Sep 2019 11:52:46 +0800 Subject: [PATCH 06/11] clean code --- src/c_api.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/c_api.cpp b/src/c_api.cpp index 6d63c79e0d6f..c01c634be8c0 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -1327,11 +1327,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); int nrow = static_cast(nindptr - 1); - int ncol = static_cast(num_col); - if (ncol <= 0) { - Log::Fatal("The number of columns should be greater than zero."); - } - ref_booster->Predict(num_iteration, predict_type, nrow, ncol, get_row_fun, + ref_booster->Predict(num_iteration, predict_type, nrow, static_cast(num_col), get_row_fun, config, out_result, out_len); API_END(); } From f7e24fe0819678b8a7df17f426e64778b8d80a44 Mon Sep 17 00:00:00 2001 From: guolinke Date: Sat, 28 Sep 2019 11:58:13 +0800 Subject: [PATCH 07/11] code clean --- src/io/parser.cpp | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/io/parser.cpp b/src/io/parser.cpp index 6f17c70db779..0d61c5f53482 100644 --- a/src/io/parser.cpp +++ b/src/io/parser.cpp @@ -92,7 +92,7 @@ void GetLine(std::stringstream* ss, std::string* line, const VirtualFileReader* std::vector ReadKLineFromFile(const char* filename, bool header, int k) { auto reader = VirtualFileReader::Make(filename); if (!reader->Init()) { - Log::Fatal("Data file %s doesn't exist", filename); + Log::Fatal("Data file %s doesn't exist.", filename); } std::vector ret; std::string cur_line; @@ -100,7 +100,7 @@ std::vector ReadKLineFromFile(const char* filename, bool header, in auto buffer = std::vector(buffer_size); size_t read_len = reader->Read(buffer.data(), buffer_size); if (read_len <= 0) { - Log::Fatal("Data file %s couldn't be read", filename); + Log::Fatal("Data file %s couldn't be read.", filename); } std::string read_str = std::string(buffer.data(), read_len); std::stringstream tmp_file(read_str); @@ -118,9 +118,9 @@ std::vector ReadKLineFromFile(const char* filename, bool header, in } } if (ret.empty()) { - Log::Fatal("Data file %s should have at least one line", filename); + Log::Fatal("Data file %s should have at least one line.", filename); } else if (ret.size() == 1) { - Log::Warning("Data file %s only has one line", filename); + Log::Warning("Data file %s only has one line.", filename); } return ret; } @@ -154,22 +154,20 @@ DataType GetDataType(const std::vector& lines, int* num_col) { } else if (comma_cnt == comma_cnt2 && comma_cnt > 0) { type = DataType::CSV; } - - // valid the type - for (size_t i = 2; i < lines.size(); ++i) { - GetStatistic(lines[i].c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2); - if (type == DataType::TSV) { - if (tab_cnt2 != tab_cnt) { + if (type == DataType::TSV || type == DataType::CSV) { + // valid the type + for (size_t i = 2; i < lines.size(); ++i) { + GetStatistic(lines[i].c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2); + if (type == DataType::TSV && tab_cnt2 != tab_cnt) { type = DataType::INVALID; break; - } - } else if (type == DataType::CSV) { - if (comma_cnt != comma_cnt2) { + } else if (type == DataType::CSV && comma_cnt != comma_cnt2) { type = DataType::INVALID; break; } } } + if (type == DataType::LIBSVM) { int max_col_idx = 0; for (size_t i = 0; i < lines.size(); ++i) { @@ -196,7 +194,7 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features int num_col = 0; DataType type = GetDataType(lines, &num_col); if (type == DataType::INVALID) { - Log::Fatal("Unknown format of training data"); + Log::Fatal("Unknown format of training data."); } std::unique_ptr ret; if (type == DataType::LIBSVM) { @@ -211,7 +209,7 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features } if (label_idx < 0) { - Log::Info("Data file %s doesn't contain a label column", filename); + Log::Info("Data file %s doesn't contain a label column.", filename); } return ret.release(); } From 6d1b87fcfb7b46e54daf224bb99ccbf5623916a8 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Sun, 29 Sep 2019 17:56:07 +0800 Subject: [PATCH 08/11] check range for CSR --- src/c_api.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/c_api.cpp b/src/c_api.cpp index c01c634be8c0..fbf272d3dab1 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -694,6 +694,8 @@ int LGBM_DatasetCreateFromCSR(const void* indptr, API_BEGIN(); if (num_col <= 0) { Log::Fatal("The number of columns should be greater than zero."); + } else if (num_col >= INT32_MAX) { + Log::Fatal("The number of columns should be smaller than INT32_MAX."); } auto param = Config::Str2Map(parameters); Config config; @@ -758,6 +760,8 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr, API_BEGIN(); if (num_col <= 0) { Log::Fatal("The number of columns should be greater than zero."); + } else if (num_col >= INT32_MAX) { + Log::Fatal("The number of columns should be smaller than INT32_MAX."); } auto get_row_fun = *static_cast>&)>*>(get_row_funptr); auto param = Config::Str2Map(parameters); @@ -1317,6 +1321,8 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, API_BEGIN(); if (num_col <= 0) { Log::Fatal("The number of columns should be greater than zero."); + } else if (num_col >= INT32_MAX) { + Log::Fatal("The number of columns should be smaller than INT32_MAX."); } auto param = Config::Str2Map(parameter); Config config; @@ -1349,6 +1355,8 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, API_BEGIN(); if (num_col <= 0) { Log::Fatal("The number of columns should be greater than zero."); + } else if (num_col >= INT32_MAX) { + Log::Fatal("The number of columns should be smaller than INT32_MAX."); } auto param = Config::Str2Map(parameter); Config config; From 1a3e67d725f705e998d72ec55a0ddebc61bf7f02 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Sun, 29 Sep 2019 19:49:05 +0800 Subject: [PATCH 09/11] Update test_.py --- tests/c_api_test/test_.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/c_api_test/test_.py b/tests/c_api_test/test_.py index d712e4f4fce0..fa152d284ccb 100644 --- a/tests/c_api_test/test_.py +++ b/tests/c_api_test/test_.py @@ -141,9 +141,9 @@ def load_from_csc(filename, reference): c_array(ctypes.c_int, csr.indices), csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)), dtype_float64, - len(csr.indptr), - len(csr.data), - csr.shape[0], + ctypes.c_int64(len(csr.indptr)), + ctypes.c_int64(len(csr.data)), + ctypes.c_int64(csr.shape[0]), c_str('max_bin=15'), ref, ctypes.byref(handle)) From ac981fb196b874b133dafd9254062ee5062c5df7 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Mon, 30 Sep 2019 00:27:29 +0800 Subject: [PATCH 10/11] Update test_.py --- tests/c_api_test/test_.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/c_api_test/test_.py b/tests/c_api_test/test_.py index fa152d284ccb..d98761eddc78 100644 --- a/tests/c_api_test/test_.py +++ b/tests/c_api_test/test_.py @@ -105,9 +105,9 @@ def load_from_csr(filename, reference): c_array(ctypes.c_int, csr.indices), csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)), dtype_float64, - len(csr.indptr), - len(csr.data), - csr.shape[1], + ctypes.c_int64(len(csr.indptr)), + ctypes.c_int64(len(csr.data)), + ctypes.c_int64(csr.shape[1]), c_str('max_bin=15'), ref, ctypes.byref(handle)) From 4b1f0c085b12a49be64fee05d14cbcd74c9012e6 Mon Sep 17 00:00:00 2001 From: StrikerRUS Date: Thu, 3 Oct 2019 01:43:34 +0300 Subject: [PATCH 11/11] added tests --- tests/python_package_test/test_basic.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index ea40e98d65c1..303256dc330d 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -6,6 +6,8 @@ import lightgbm as lgb import numpy as np + +from scipy import sparse from sklearn.datasets import load_breast_cancer, dump_svmlight_file, load_svmlight_file from sklearn.model_selection import train_test_split @@ -53,6 +55,7 @@ def test(self): # check saved model persistence bst = lgb.Booster(params, model_file="model.txt") + os.remove("model.txt") pred_from_model_file = bst.predict(X_test) self.assertEqual(len(pred_from_matr), len(pred_from_model_file)) for preds in zip(pred_from_matr, pred_from_model_file): @@ -67,6 +70,25 @@ def test(self): # scores likely to be different, but prediction should still be the same self.assertEqual(preds[0] > 0, preds[1] > 0) + # test that shape is checked during prediction + bad_X_test = X_test[:, 1:] + bad_shape_error_msg = "The number of features in data*" + np.testing.assert_raises_regex(lgb.basic.LightGBMError, bad_shape_error_msg, + bst.predict, bad_X_test) + np.testing.assert_raises_regex(lgb.basic.LightGBMError, bad_shape_error_msg, + bst.predict, sparse.csr_matrix(bad_X_test)) + np.testing.assert_raises_regex(lgb.basic.LightGBMError, bad_shape_error_msg, + bst.predict, sparse.csc_matrix(bad_X_test)) + with open(tname, "w+b") as f: + dump_svmlight_file(bad_X_test, y_test, f) + np.testing.assert_raises_regex(lgb.basic.LightGBMError, bad_shape_error_msg, + bst.predict, tname) + with open(tname, "w+b") as f: + dump_svmlight_file(X_test, y_test, f, zero_based=False) + np.testing.assert_raises_regex(lgb.basic.LightGBMError, bad_shape_error_msg, + bst.predict, tname) + os.remove(tname) + def test_chunked_dataset(self): X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=2)