Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented memory allocation strategy for different prediction types #60

Merged
merged 1 commit into from
Oct 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/main/java/com/microsoft/ml/lightgbm/PredictionType.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ public String getDescription() {
return description;
}

/**
 * Compares two prediction types by their underlying native type code.
 * <p>
 * NOTE(review): this <em>overloads</em> rather than overrides
 * {@link Object#equals(Object)}, so comparisons through an {@code Object}-typed
 * reference (e.g. inside collections) still fall back to identity equality, and
 * there is no matching {@code hashCode}. Kept as an overload for backward
 * compatibility with existing callers; made null-safe so a null argument
 * returns {@code false} instead of throwing {@link NullPointerException}.
 *
 * @param that the prediction type to compare against; may be {@code null}
 * @return {@code true} if both instances wrap the same native type code
 */
public boolean equals(PredictionType that) {
    return that != null && this.type == that.type;
}

@Override
public String toString() {
return "PredictionType{" +
Expand Down
103 changes: 74 additions & 29 deletions src/main/java/io/github/metarank/lightgbm4j/LGBMBooster.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import com.microsoft.ml.lightgbm.*;

import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Locale;

import static com.microsoft.ml.lightgbm.lightgbmlib.*;
Expand All @@ -18,6 +16,7 @@ public class LGBMBooster implements AutoCloseable {


private static volatile boolean nativeLoaded = false;

static {
try {
LGBMBooster.loadNative();
Expand All @@ -28,6 +27,7 @@ public class LGBMBooster implements AutoCloseable {

/**
* Called from tests.
*
* @return true if JNI libraries were loaded successfully.
*/
public static boolean isNativeLoaded() {
Expand All @@ -37,6 +37,7 @@ public static boolean isNativeLoaded() {
/**
     * Loads all corresponding native libraries for the current platform. Called from the class initializer,
* so usually there is no need to call it directly.
*
* @throws IOException
*/
public synchronized static void loadNative() throws IOException {
Expand Down Expand Up @@ -78,7 +79,7 @@ private static void loadNative(String path, String name) throws IOException {
} else {
extractResource(path, name, libFile);
}
System.out.println("Extracted file: exists=" + libFile.exists() + " path="+ libFile);
System.out.println("Extracted file: exists=" + libFile.exists() + " path=" + libFile);
try {
System.load(libFile.toString());
} catch (UnsatisfiedLinkError err) {
Expand Down Expand Up @@ -108,6 +109,7 @@ private static void copyStream(InputStream source, OutputStream target) throws I

/**
* Constructor is private because you need to have a JNI handle for native LightGBM instance.
*
* @param iterations
* @param handle
*/
Expand All @@ -118,6 +120,7 @@ private static void copyStream(InputStream source, OutputStream target) throws I

/**
* Load an existing booster from model file.
*
* @param file Filename of model
* @return Booster instance.
* @throws LGBMException
Expand All @@ -137,6 +140,7 @@ public static LGBMBooster createFromModelfile(String file) throws LGBMException

/**
* Load an existing booster from string.
*
* @param model Model string
* @return Booster instance.
* @throws LGBMException
Expand All @@ -156,6 +160,7 @@ public static LGBMBooster loadModelFromString(String model) throws LGBMException

/**
* Deallocate all native memory for the LightGBM model.
*
* @throws LGBMException
*/
@Override
Expand All @@ -168,10 +173,11 @@ public void close() throws LGBMException {

/**
* Make prediction for a new float[] dataset.
* @param input input matrix, as a 1D array. Size should be rows * cols.
* @param rows number of rows
* @param cols number of cols
* @param isRowMajor is the 1d encoding a row-major?
*
* @param input input matrix, as a 1D array. Size should be rows * cols.
* @param rows number of rows
* @param cols number of cols
* @param isRowMajor is the 1d encoding a row-major?
* @param predictionType the prediction type
* @return array of predictions
* @throws LGBMException
Expand All @@ -182,7 +188,8 @@ public double[] predictForMat(float[] input, int rows, int cols, boolean isRowMa
floatArray_setitem(dataBuffer, i, input[i]);
}
SWIGTYPE_p_long_long outLength = new_int64_tp();
SWIGTYPE_p_double outBuffer = new_doubleArray(2L * rows);
long outSize = outBufferSize(rows, cols, predictionType);
SWIGTYPE_p_double outBuffer = new_doubleArray(outSize);
int result = LGBM_BoosterPredictForMat(
voidpp_value(handle),
float_to_voidp_ptr(dataBuffer),
Expand All @@ -203,7 +210,7 @@ public double[] predictForMat(float[] input, int rows, int cols, boolean isRowMa
throw new LGBMException(LGBM_GetLastError());
} else {
long length = int64_tp_value(outLength);
double[] values = new double[(int)length];
double[] values = new double[(int) length];
for (int i = 0; i < length; i++) {
values[i] = doubleArray_getitem(outBuffer, i);
}
Expand All @@ -213,12 +220,14 @@ public double[] predictForMat(float[] input, int rows, int cols, boolean isRowMa
return values;
}
}

/**
* Make prediction for a new double[] dataset.
* @param input input matrix, as a 1D array. Size should be rows * cols.
* @param rows number of rows
* @param cols number of cols
* @param isRowMajor is the 1 d encoding a row-major?
*
* @param input input matrix, as a 1D array. Size should be rows * cols.
* @param rows number of rows
* @param cols number of cols
     * @param isRowMajor     is the 1d encoding row-major?
* @param predictionType the prediction type
* @return array of predictions
* @throws LGBMException
Expand All @@ -230,7 +239,8 @@ public double[] predictForMat(double[] input, int rows, int cols, boolean isRowM
doubleArray_setitem(dataBuffer, i, input[i]);
}
SWIGTYPE_p_long_long outLength = new_int64_tp();
SWIGTYPE_p_double outBuffer = new_doubleArray(2L * rows);
long outSize = outBufferSize(rows, cols, predictionType);
SWIGTYPE_p_double outBuffer = new_doubleArray(outSize);
int result = LGBM_BoosterPredictForMat(
voidpp_value(handle),
double_to_voidp_ptr(dataBuffer),
Expand All @@ -251,7 +261,7 @@ public double[] predictForMat(double[] input, int rows, int cols, boolean isRowM
throw new LGBMException(LGBM_GetLastError());
} else {
long length = int64_tp_value(outLength);
double[] values = new double[(int)length];
double[] values = new double[(int) length];
for (int i = 0; i < length; i++) {
values[i] = doubleArray_getitem(outBuffer, i);
}
Expand All @@ -264,7 +274,8 @@ public double[] predictForMat(double[] input, int rows, int cols, boolean isRowM

/**
* Create a new boosting learner.
* @param dataset a LGBMDataset with the training data.
*
* @param dataset a LGBMDataset with the training data.
* @param parameters Parameters in format ‘key1=value1 key2=value2’
* @return
* @throws LGBMException
Expand All @@ -281,6 +292,7 @@ public static LGBMBooster create(LGBMDataset dataset, String parameters) throws

/**
* Update the model for one iteration.
*
* @return true if there are no more splits possible, so training is finished.
* @throws LGBMException
*/
Expand All @@ -305,12 +317,13 @@ public enum FeatureImportanceType {

/**
* Save model to string.
* @param startIteration Start index of the iteration that should be saved
* @param numIteration Index of the iteration that should be saved, 0 and negative means save all
*
* @param startIteration Start index of the iteration that should be saved
* @param numIteration Index of the iteration that should be saved, 0 and negative means save all
* @param featureImportance Type of feature importance, can be FeatureImportanceType.SPLIT or FeatureImportanceType.GAIN
* @return
*/
public String saveModelToString(int startIteration, int numIteration, FeatureImportanceType featureImportance) {
public String saveModelToString(int startIteration, int numIteration, FeatureImportanceType featureImportance) {
SWIGTYPE_p_long_long outLength = new_int64_tp();
String result = LGBM_BoosterSaveModelToStringSWIG(
voidpp_value(handle),
Expand All @@ -326,6 +339,7 @@ public String saveModelToString(int startIteration, int numIteration, FeatureImp

/**
* Get names of features.
*
* @return a list of feature names.
*/
public String[] getFeatureNames() {
Expand All @@ -337,6 +351,7 @@ public String[] getFeatureNames() {

/**
* Add new validation data to booster.
*
* @param dataset dataset to validate
* @throws LGBMException
*/
Expand All @@ -349,6 +364,7 @@ public void addValidData(LGBMDataset dataset) throws LGBMException {

/**
* Get evaluation for training data and validation data.
*
* @param dataIndex Index of data, 0: training data, 1: 1st validation data, 2: 2nd validation data and so on
* @return
* @throws LGBMException
Expand All @@ -363,7 +379,7 @@ public double[] getEval(int dataIndex) throws LGBMException {
throw new LGBMException(LGBM_GetLastError());
} else {
double[] evals = new double[intp_value(outLength)];
for (int i=0; i < evals.length; i++) {
for (int i = 0; i < evals.length; i++) {
evals[i] = doubleArray_getitem(outBuffer, i);
}
delete_intp(outLength);
Expand All @@ -374,6 +390,7 @@ public double[] getEval(int dataIndex) throws LGBMException {

/**
* Get names of evaluation datasets.
*
* @return array of eval dataset names.
* @throws LGBMException
*/
Expand All @@ -386,7 +403,8 @@ public String[] getEvalNames() throws LGBMException {

/**
* Get model feature importance.
* @param numIteration Number of iterations for which feature importance is calculated, 0 or less means use all
*
* @param numIteration Number of iterations for which feature importance is calculated, 0 or less means use all
* @param importanceType GAIN or SPLIT
* @return Result array with feature importance
* @throws LGBMException
Expand All @@ -405,7 +423,7 @@ public double[] featureImportance(int numIteration, FeatureImportanceType import
throw new LGBMException(LGBM_GetLastError());
} else {
double[] importance = new double[numFeatures];
for (int i=0; i < numFeatures; i++) {
for (int i = 0; i < numFeatures; i++) {
importance[i] = doubleArray_getitem(outBuffer, i);
}
delete_doubleArray(outBuffer);
Expand All @@ -415,6 +433,7 @@ public double[] featureImportance(int numIteration, FeatureImportanceType import

/**
* Get number of features.
*
* @return number of features
* @throws LGBMException
*/
Expand All @@ -434,7 +453,8 @@ public int getNumFeature() throws LGBMException {
/**
* Make prediction for a new double[] row dataset. This method re-uses the internal predictor structure from previous calls
* and is optimized for single row invocation.
* @param data input vector
*
* @param data input vector
* @param predictionType the prediction type
* @return score
* @throws LGBMException
Expand All @@ -445,7 +465,8 @@ public double predictForMatSingleRow(double[] data, PredictionType predictionTyp
doubleArray_setitem(dataBuffer, i, data[i]);
}
SWIGTYPE_p_long_long outLength = new_int64_tp();
SWIGTYPE_p_double outBuffer = new_doubleArray(1);
long outBufferSize = outBufferSize(1, data.length, predictionType);
SWIGTYPE_p_double outBuffer = new_doubleArray(outBufferSize);

int result = LGBM_BoosterPredictForMatSingleRow(
voidpp_value(handle),
Expand All @@ -459,15 +480,15 @@ public double predictForMatSingleRow(double[] data, PredictionType predictionTyp
"",
outLength,
outBuffer
);
);
if (result < 0) {
delete_doubleArray(dataBuffer);
delete_doubleArray(outBuffer);
delete_int64_tp(outLength);
throw new LGBMException(LGBM_GetLastError());
} else {
long length = int64_tp_value(outLength);
double[] values = new double[(int)length];
double[] values = new double[(int) length];
for (int i = 0; i < length; i++) {
values[i] = doubleArray_getitem(outBuffer, i);
}
Expand All @@ -477,10 +498,12 @@ public double predictForMatSingleRow(double[] data, PredictionType predictionTyp
return values[0];
}
}

/**
* Make prediction for a new float[] row dataset. This method re-uses the internal predictor structure from previous calls
* and is optimized for single row invocation.
* @param data input vector
*
* @param data input vector
* @param predictionType the prediction type
* @return score
* @throws LGBMException
Expand All @@ -491,7 +514,8 @@ public double predictForMatSingleRow(float[] data, PredictionType predictionType
floatArray_setitem(dataBuffer, i, data[i]);
}
SWIGTYPE_p_long_long outLength = new_int64_tp();
SWIGTYPE_p_double outBuffer = new_doubleArray(1);
long outBufferSize = outBufferSize(1, data.length, predictionType);
SWIGTYPE_p_double outBuffer = new_doubleArray(outBufferSize);

int result = LGBM_BoosterPredictForMatSingleRow(
voidpp_value(handle),
Expand All @@ -513,7 +537,7 @@ public double predictForMatSingleRow(float[] data, PredictionType predictionType
throw new LGBMException(LGBM_GetLastError());
} else {
long length = int64_tp_value(outLength);
double[] values = new double[(int)length];
double[] values = new double[(int) length];
for (int i = 0; i < length; i++) {
values[i] = doubleArray_getitem(outBuffer, i);
}
Expand All @@ -536,4 +560,25 @@ private int importanceType(FeatureImportanceType tpe) {
}
return importanceType;
}

/**
* Calculates the output buffer size for the different prediction types. See the notes at:
* <a href="https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForMat">predictForMat</a> &
* <a href="https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForMatSingleRow">predictForMatSingleRow</a>
* for more info.
*
* @param rows the number of rows in the input data
* @param cols the number of columns in the input data
* @param predictionType the type of prediction we are trying to achieve
* @return number of elements in the output result (size)
*/
/**
 * Computes how many doubles to allocate for the native prediction output buffer,
 * which varies by prediction type. See the buffer-size notes at:
 * <a href="https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForMat">predictForMat</a> &
 * <a href="https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForMatSingleRow">predictForMatSingleRow</a>.
 *
 * @param rows           number of rows in the input data
 * @param cols           number of columns in the input data
 * @param predictionType the kind of prediction being requested
 * @return number of elements (doubles) the output buffer must hold
 */
private long outBufferSize(int rows, int cols, PredictionType predictionType) {
    // Base allocation: 2 doubles per input row (covers NORMAL and RAW_SCORE).
    final long base = 2L * rows;
    if (PredictionType.C_API_PREDICT_CONTRIB.equals(predictionType)) {
        // Feature contributions: one value per feature plus a bias term, per row.
        return base * (cols + 1);
    }
    if (PredictionType.C_API_PREDICT_LEAF_INDEX.equals(predictionType)) {
        // Leaf index output: one leaf id per boosting iteration, per row.
        return base * iterations;
    }
    // C_API_PREDICT_NORMAL and C_API_PREDICT_RAW_SCORE need only the base size.
    return base;
}
}
Loading