diff --git a/lib/lightgbm/booster.rb b/lib/lightgbm/booster.rb
index c9dc1dd..1f49bcd 100644
--- a/lib/lightgbm/booster.rb
+++ b/lib/lightgbm/booster.rb
@@ -141,10 +141,22 @@ def num_trees
       out.read_int
     end
 
-    def predict(input, start_iteration: 0, num_iteration: nil, raw_score: false, pred_leaf: false, pred_contrib: false, **params)
-      num_iteration ||= best_iteration
+    # Predicts for raw (non-Dataset) data.
+    #
+    # When num_iteration is nil it resolves to best_iteration if
+    # start_iteration <= 0, and to -1 (no iteration limit in the LightGBM
+    # C API) otherwise. The default is nil so that a plain predict(data)
+    # call keeps using best_iteration, as it did before this change.
+    def predict(data, start_iteration: 0, num_iteration: nil, raw_score: false, pred_leaf: false, pred_contrib: false, **kwargs)
+      if num_iteration.nil?
+        if start_iteration <= 0
+          num_iteration = best_iteration
+        else
+          num_iteration = -1
+        end
+      end
 
-      if input.is_a?(Dataset)
+      if data.is_a?(Dataset)
         raise TypeError, "Cannot use Dataset instance for prediction, please use raw data instead"
       end
 
@@ -159,47 +171,15 @@ def predict(input, start_iteration: 0, num_iteration: nil, raw_score: false, pre
         predict_type = FFI::C_API_PREDICT_CONTRIB
       end
 
-      input =
-        if daru?(input)
-          input[*cached_feature_name].map_rows(&:to_a)
-        elsif input.is_a?(Hash) # sort feature.values to match the order of model.feature_name
-          sorted_feature_values(input)
-        elsif input.is_a?(Array) && input.first.is_a?(Hash) # on multiple elems, if 1st is hash, assume they all are
-          input.map(&method(:sorted_feature_values))
-        elsif rover?(input)
-          # TODO improve performance
-          input[cached_feature_name].to_numo.to_a
-        else
-          input.to_a
-        end
-
-      singular = !input.first.is_a?(Array)
-      input = [input] if singular
-
-      nrow = input.count
-      n_preds =
-        num_preds(
+      preds, nrow, singular =
+        preds_for_data(
+          data,
           start_iteration,
           num_iteration,
-          nrow,
-          predict_type
+          predict_type,
+          **kwargs
         )
 
-      flat_input = input.flatten
-      handle_missing(flat_input)
-      data = ::FFI::MemoryPointer.new(:double, input.count * input.first.count)
-      data.write_array_of_double(flat_input)
-
-      out_len = ::FFI::MemoryPointer.new(:int64)
-      out_result = ::FFI::MemoryPointer.new(:double, n_preds)
-      check_result FFI.LGBM_BoosterPredictForMat(handle_pointer, data, 1, input.count, input.first.count, 1, predict_type, start_iteration, num_iteration, params_str(params), out_len, out_result)
-
-      if n_preds != out_len.read_int64
-        raise Error, "Wrong length for predict results"
-      end
-
-      preds = out_result.read_array_of_double(out_len.read_int64)
-
       if pred_leaf
         preds = preds.map(&:to_i)
       end
@@ -287,6 +267,54 @@ def num_class
       out.read_int
     end
 
+    # Normalizes input to an array of rows, calls LGBM_BoosterPredictForMat,
+    # and returns [preds, nrow, singular]. singular is true when the first
+    # element of the normalized input is not an Array (a single row was given).
+    def preds_for_data(input, start_iteration, num_iteration, predict_type, **params)
+      input =
+        if daru?(input)
+          input[*cached_feature_name].map_rows(&:to_a)
+        elsif input.is_a?(Hash) # sort feature.values to match the order of model.feature_name
+          sorted_feature_values(input)
+        elsif input.is_a?(Array) && input.first.is_a?(Hash) # on multiple elems, if 1st is hash, assume they all are
+          input.map(&method(:sorted_feature_values))
+        elsif rover?(input)
+          # TODO improve performance
+          input[cached_feature_name].to_numo.to_a
+        else
+          input.to_a
+        end
+
+      singular = !input.first.is_a?(Array)
+      input = [input] if singular
+
+      nrow = input.count
+      n_preds =
+        num_preds(
+          start_iteration,
+          num_iteration,
+          nrow,
+          predict_type
+        )
+
+      flat_input = input.flatten
+      handle_missing(flat_input)
+      data = ::FFI::MemoryPointer.new(:double, input.count * input.first.count)
+      data.write_array_of_double(flat_input)
+
+      out_len = ::FFI::MemoryPointer.new(:int64)
+      out_result = ::FFI::MemoryPointer.new(:double, n_preds)
+      check_result FFI.LGBM_BoosterPredictForMat(handle_pointer, data, 1, input.count, input.first.count, 1, predict_type, start_iteration, num_iteration, params_str(params), out_len, out_result)
+
+      if n_preds != out_len.read_int64
+        raise Error, "Wrong length for predict results"
+      end
+
+      preds = out_result.read_array_of_double(out_len.read_int64)
+
+      [preds, nrow, singular]
+    end
+
     def num_preds(start_iteration, num_iteration, nrow, predict_type)
       out = ::FFI::MemoryPointer.new(:int64)
       check_result FFI.LGBM_BoosterCalcNumPredict(handle_pointer, nrow, predict_type, start_iteration, num_iteration, out)