Skip to content

Commit 86871d4

Browse files
committed
[JVM] Add Iterator loading API
1 parent 770b345 commit 86871d4

File tree

10 files changed

+451
-5
lines changed

10 files changed

+451
-5
lines changed

include/xgboost/c_api.h

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
#endif
1313

1414
// XGBoost C API will include APIs in Rabit C API
15+
/* Standard headers already carry their own C-linkage guards, so the include
 * must not be wrapped in a linkage block: a bare `{ ... }` at file scope is
 * not valid C when this header is used from a C translation unit.
 * <stdio.h> also provides size_t, used by the declarations in this header. */
#include <stdio.h>
1518
#include <rabit/c_api.h>
1619

1720
#if defined(_MSC_VER) || defined(_WIN32)
@@ -26,6 +29,51 @@ typedef unsigned long bst_ulong; // NOLINT(*)
2629
typedef void *DMatrixHandle;
2730
/*! \brief handle to Booster */
2831
typedef void *BoosterHandle;
32+
/*! \brief handle to a data iterator */
33+
typedef void *DataIterHandle;
34+
/*! \brief handle to an internal data holder. */
35+
typedef void *DataHolderHandle;
36+
37+
/*! \brief Mini batch used in XGBoost Data Iteration */
38+
typedef struct {
39+
/*! \brief number of rows in the minibatch */
40+
size_t size;
41+
/*! \brief CSR row offsets (size + 1 entries); offset[size] is the total number of entries */
42+
long* offset; // NOLINT(*)
43+
/*! \brief label of each instance, can be NULL */
44+
float* label;
45+
/*! \brief weight of each instance, can be NULL */
46+
float* weight;
47+
/*! \brief feature index */
48+
int* index;
49+
/*! \brief feature values */
50+
float* value;
51+
} XGBoostBatchCSR;
52+
53+
54+
/*!
55+
* \brief Callback used to hand a mini batch over to the data holder.
56+
* \param handle The handle to the callback.
57+
* \param batch The data content to be set.
58+
*/
59+
XGB_EXTERN_C typedef int XGBCallbackSetData(
60+
DataHolderHandle handle, XGBoostBatchCSR batch);
61+
62+
/*!
63+
* \brief The data reading callback function.
64+
* The iterator will be able to give subset of batch in the data.
65+
*
66+
* If there is data, the function will call set_function to set the data.
67+
*
68+
* \param data_handle The handle to the callback.
69+
* \param set_function The callback to invoke with the next batch, if there is one.
70+
* \param set_function_handle The handle to be passed to set function.
71+
* \return 0 if the end has been reached and no batch was set, nonzero otherwise.
72+
*/
73+
XGB_EXTERN_C typedef int XGBCallbackDataIterNext(
74+
DataIterHandle data_handle,
75+
XGBCallbackSetData* set_function,
76+
DataHolderHandle set_function_handle);
2977

3078
/*!
3179
* \brief get string message of the last error
@@ -50,6 +98,20 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
5098
int silent,
5199
DMatrixHandle *out);
52100

101+
/*!
102+
* \brief Create a DMatrix from a data iterator.
103+
* \param data_handle The handle to the data.
104+
* \param callback The callback to get the data.
105+
* \param cache_info Additional information about cache file, can be null.
106+
* \param out The created DMatrix
107+
* \return 0 when success, -1 when failure happens.
108+
*/
109+
XGB_DLL int XGDMatrixCreateFromDataIter(
110+
DataIterHandle data_handle,
111+
XGBCallbackDataIterNext* callback,
112+
const char* cache_info,
113+
DMatrixHandle *out);
114+
53115
/*!
54116
* \brief create a matrix content from csr format
55117
* \param indptr pointer to row headers

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DMatrix.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package ml.dmlc.xgboost4j;
1717

1818
import java.io.IOException;
19+
import java.util.Iterator;
1920

2021
import org.apache.commons.logging.Log;
2122
import org.apache.commons.logging.LogFactory;
@@ -47,6 +48,33 @@ public static enum SparseType {
4748
CSC;
4849
}
4950

51+
/**
 * Create DMatrix from iterator.
 *
 * @param iter The data iterator of mini batch to provide the data.
 * @param cache_info Cache path information, used for external memory setting, can be null.
 * @throws XGBoostError native error
 * @throws NullPointerException if iter is null
 */
public DMatrix(Iterator<DataBatch> iter, String cache_info) throws XGBoostError {
  if (iter == null) {
    throw new NullPointerException("iter: null");
  }
  // The native side fills out[0] with the handle of the created DMatrix.
  long[] out = new long[1];
  JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromDataIter(iter, cache_info, out));
  handle = out[0];
}
71+
72+
/**
73+
* Create DMatrix by loading libsvm file from dataPath
74+
*
75+
* @param dataPath The path to the data.
76+
* @throws XGBoostError
77+
*/
5078
public DMatrix(String dataPath) throws XGBoostError {
5179
if (dataPath == null) {
5280
throw new NullPointerException("dataPath: null");
@@ -56,6 +84,14 @@ public DMatrix(String dataPath) throws XGBoostError {
5684
handle = out[0];
5785
}
5886

87+
/**
88+
* Create DMatrix from Sparse matrix in CSR/CSC format.
89+
* @param headers The row index of the matrix.
90+
* @param indices The indices of the present (non-missing) entries.
91+
* @param data The data content.
92+
* @param st Type of sparsity.
93+
* @throws XGBoostError
94+
*/
5995
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
6096
long[] out = new long[1];
6197
if (st == SparseType.CSR) {
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package ml.dmlc.xgboost4j;
2+
3+
/**
 * A mini-batch of data that can be converted to DMatrix.
 * The data is in sparse matrix CSR format.
 *
 * Usually this object is not needed.
 *
 * This class is used to support advanced creation of DMatrix from an Iterator of DataBatch.
 *
 * The native loader reads the fields below by name, so they must not be renamed
 * without updating the JNI code as well.
 */
public class DataBatch {
  /** The offset of each row in the sparse matrix; holds numRows() + 1 entries. */
  long[] rowOffset = null;
  /** weight of each data point, can be null */
  float[] weight = null;
  /** label of each data point, can be null */
  float[] label = null;
  /** index of each feature(column) in the sparse matrix */
  int[] featureIndex = null;
  /** value of each non-missing entry in the sparse matrix */
  float[] featureValue = null;

  /**
   * Get number of rows in the data batch.
   *
   * @return Number of rows in the data batch; 0 when no row offsets have been set.
   */
  public int numRows() {
    // A batch without offsets (fresh instance) or with an empty offset array
    // holds no rows; never report a negative count or fail with an NPE.
    if (rowOffset == null || rowOffset.length == 0) {
      return 0;
    }
    return rowOffset.length - 1;
  }

  /**
   * Shallow copy a DataBatch: the copy shares the underlying arrays with this batch.
   *
   * @return a copy of the batch
   */
  public DataBatch shallowCopy() {
    DataBatch b = new DataBatch();
    b.rowOffset = this.rowOffset;
    b.weight = this.weight;
    b.label = this.label;
    b.featureIndex = this.featureIndex;
    b.featureValue = this.featureValue;
    return b;
  }
}

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
package ml.dmlc.xgboost4j;
1717

18+
1819
/**
1920
* xgboost JNI functions
2021
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
@@ -26,6 +27,8 @@ class XgboostJNI {
2627

2728
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
2829

30+
public final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter, String cache_info, long[] out);
31+
2932
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
3033
long[] out);
3134

jvm-packages/xgboost4j/src/native/xgboost4j.cpp

Lines changed: 139 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,124 @@
2020
#include <vector>
2121
#include <string>
2222

23-
//helper functions
24-
//set handle
23+
// helper functions
24+
// set handle
2525
/*!
 * \brief Store a native handle in the first slot of a Java long[] array.
 * \param jenv JNI environment of the calling thread.
 * \param jhandle Java long[] that receives the handle at index 0.
 * \param handle Native pointer to hand back to Java.
 */
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
  // jlong is 64 bits on every platform, whereas long is only 32 bits on
  // LLP64 targets (e.g. 64-bit Windows). Using jlong avoids truncating the
  // pointer and matches the buffer type SetLongArrayRegion expects.
  jlong out = reinterpret_cast<jlong>(handle);
  jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
}
2929

30+
// global JVM
31+
static JavaVM* global_jvm = nullptr;
32+
33+
// overrides JNI on load
34+
jint JNI_OnLoad(JavaVM *vm, void *reserved) {
35+
global_jvm = vm;
36+
return JNI_VERSION_1_6;
37+
}
38+
39+
/*!
 * \brief Data iterator callback that pulls the next DataBatch from a Java Iterator.
 *
 * Obtains a JNIEnv for the current thread (attaching the thread when it is
 * detached), asks the Java iterator for its next batch, exposes the batch
 * arrays as an XGBoostBatchCSR and passes it to set_function. Every pinned
 * array and local reference is released on both the success and the failure
 * path, and a thread attached here is detached again before returning.
 *
 * \param data_handle The java.util.Iterator (as jobject) that yields DataBatch objects.
 * \param set_function Callback that receives the CSR view of the batch.
 * \param set_function_handle Handle forwarded unchanged to set_function.
 * \return 1 when a batch was passed to set_function, 0 when the iterator is
 *         exhausted. Failed checks are reported through LOG(FATAL); -1 is
 *         returned only if that call returns.
 */
XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
    DataIterHandle data_handle,
    XGBCallbackSetData* set_function,
    DataHolderHandle set_function_handle) {
  jobject jiter = static_cast<jobject>(data_handle);
  JNIEnv* jenv = nullptr;
  const int jni_status = global_jvm->GetEnv(
      reinterpret_cast<void **>(&jenv), JNI_VERSION_1_6);
  if (jni_status == JNI_EDETACHED) {
    const int attach_status = global_jvm->AttachCurrentThread(
        reinterpret_cast<void **>(&jenv), nullptr);
    CHECK(attach_status == JNI_OK)
        << "cannot attach the current thread to the JVM";
  } else {
    CHECK(jni_status == JNI_OK);
  }

  // Everything acquired below is tracked here so that one routine can
  // release it no matter where a check fails.
  jclass iterClass = nullptr;
  jobject batch = nullptr;
  jclass batchClass = nullptr;
  jlongArray joffset = nullptr;
  jfloatArray jlabel = nullptr;
  jfloatArray jweight = nullptr;
  jintArray jindex = nullptr;
  jfloatArray jvalue = nullptr;
  XGBoostBatchCSR cbatch;
  cbatch.size = 0;
  cbatch.offset = nullptr;
  cbatch.label = nullptr;
  cbatch.weight = nullptr;
  cbatch.index = nullptr;
  cbatch.value = nullptr;

  // A non-null pointer in cbatch means the matching Java array is pinned.
  auto release_all = [&]() {
    if (cbatch.offset != nullptr) {
      jenv->ReleaseLongArrayElements(joffset, cbatch.offset, 0);
    }
    if (cbatch.label != nullptr) {
      jenv->ReleaseFloatArrayElements(jlabel, cbatch.label, 0);
    }
    if (cbatch.weight != nullptr) {
      jenv->ReleaseFloatArrayElements(jweight, cbatch.weight, 0);
    }
    if (cbatch.index != nullptr) {
      jenv->ReleaseIntArrayElements(jindex, cbatch.index, 0);
    }
    if (cbatch.value != nullptr) {
      jenv->ReleaseFloatArrayElements(jvalue, cbatch.value, 0);
    }
    if (joffset != nullptr) jenv->DeleteLocalRef(joffset);
    if (jlabel != nullptr) jenv->DeleteLocalRef(jlabel);
    if (jweight != nullptr) jenv->DeleteLocalRef(jweight);
    if (jindex != nullptr) jenv->DeleteLocalRef(jindex);
    if (jvalue != nullptr) jenv->DeleteLocalRef(jvalue);
    if (batch != nullptr) jenv->DeleteLocalRef(batch);
    if (batchClass != nullptr) jenv->DeleteLocalRef(batchClass);
    if (iterClass != nullptr) jenv->DeleteLocalRef(iterClass);
    // only detach if it is a async call.
    if (jni_status == JNI_EDETACHED) {
      global_jvm->DetachCurrentThread();
    }
  };

  try {
    iterClass = jenv->FindClass("java/util/Iterator");
    CHECK(iterClass != nullptr) << "cannot find class java.util.Iterator";
    jmethodID hasNext = jenv->GetMethodID(iterClass, "hasNext", "()Z");
    jmethodID next = jenv->GetMethodID(iterClass, "next", "()Ljava/lang/Object;");
    CHECK(hasNext != nullptr && next != nullptr)
        << "cannot resolve Iterator.hasNext()/next()";

    int ret_value = 0;
    const bool has_batch = jenv->CallBooleanMethod(jiter, hasNext);
    CHECK(!jenv->ExceptionCheck()) << "Iterator.hasNext() threw an exception";
    if (has_batch) {
      batch = jenv->CallObjectMethod(jiter, next);
      CHECK(!jenv->ExceptionCheck()) << "Iterator.next() threw an exception";
      CHECK(batch != nullptr) << "Iterator.next() returned null";
      batchClass = jenv->GetObjectClass(batch);

      // Read one array-typed field of the batch, failing loudly when the
      // field does not exist instead of using a null field id.
      auto array_field = [&](const char* name, const char* sig) -> jobject {
        jfieldID fid = jenv->GetFieldID(batchClass, name, sig);
        CHECK(fid != nullptr) << "batch has no field named " << name;
        return jenv->GetObjectField(batch, fid);
      };
      joffset = static_cast<jlongArray>(array_field("rowOffset", "[J"));
      jlabel = static_cast<jfloatArray>(array_field("label", "[F"));
      jweight = static_cast<jfloatArray>(array_field("weight", "[F"));
      jindex = static_cast<jintArray>(array_field("featureIndex", "[I"));
      jvalue = static_cast<jfloatArray>(array_field("featureValue", "[F"));
      CHECK(joffset != nullptr && jindex != nullptr && jvalue != nullptr)
          << "batch.rowOffset, batch.featureIndex and batch.featureValue must not be null";

      // rowOffset holds one more entry than there are rows; reject an empty
      // array so that the unsigned row count cannot wrap around.
      const jsize num_offset = jenv->GetArrayLength(joffset);
      CHECK(num_offset >= 1)
          << "batch.rowOffset must contain at least one element";
      cbatch.size = static_cast<size_t>(num_offset - 1);
      cbatch.offset = jenv->GetLongArrayElements(joffset, 0);
      CHECK(cbatch.offset != nullptr) << "cannot access batch.rowOffset";

      if (jlabel != nullptr) {
        CHECK_EQ(jenv->GetArrayLength(jlabel), static_cast<long>(cbatch.size))
            << "batch.label.length must equal batch.numRows()";
        cbatch.label = jenv->GetFloatArrayElements(jlabel, 0);
      }
      if (jweight != nullptr) {
        CHECK_EQ(jenv->GetArrayLength(jweight), static_cast<long>(cbatch.size))
            << "batch.weight.length must equal batch.numRows()";
        cbatch.weight = jenv->GetFloatArrayElements(jweight, 0);
      }

      // The last offset is the total number of stored entries.
      long max_elem = cbatch.offset[cbatch.size];
      CHECK_EQ(jenv->GetArrayLength(jindex), max_elem)
          << "batch.index.length must equal batch.offset.back()";
      CHECK_EQ(jenv->GetArrayLength(jvalue), max_elem)
          << "batch.value.length must equal batch.offset.back()";
      cbatch.index = jenv->GetIntArrayElements(jindex, 0);
      cbatch.value = jenv->GetFloatArrayElements(jvalue, 0);
      CHECK(cbatch.index != nullptr && cbatch.value != nullptr)
          << "cannot access batch.featureIndex / batch.featureValue";

      // cbatch is ready
      CHECK_EQ((*set_function)(set_function_handle, cbatch), 0)
          << XGBGetLastError();
      ret_value = 1;
    }
    release_all();
    return ret_value;
  } catch (const dmlc::Error& e) {
    // Release everything before reporting, so a failed batch leaks nothing.
    release_all();
    LOG(FATAL) << e.what();
    return -1;
  }
}
135+
136+
/*
137+
* Class: ml_dmlc_xgboost4j_XgboostJNI
138+
* Method: XGBGetLastError
139+
* Signature: ()Ljava/lang/String;
140+
*/
30141
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
31142
(JNIEnv *jenv, jclass jcls) {
32143
jstring jresult = 0;
@@ -37,6 +148,32 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
37148
return jresult;
38149
}
39150

151+
/*
152+
* Class: ml_dmlc_xgboost4j_XgboostJNI
153+
* Method: XGDMatrixCreateFromDataIter
154+
* Signature: (Ljava/util/Iterator;Ljava/lang/String;[J)I
155+
*/
156+
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromDataIter
157+
(JNIEnv *jenv, jclass jcls, jobject jiter, jstring jcache_info, jlongArray jout) {
158+
DMatrixHandle result;
159+
const char* cache_info = nullptr;
160+
if (jcache_info != nullptr) {
161+
cache_info = jenv->GetStringUTFChars(jcache_info, 0);
162+
}
163+
int ret = XGDMatrixCreateFromDataIter(
164+
jiter, XGBoost4jCallbackDataIterNext, cache_info, &result);
165+
if (cache_info) {
166+
jenv->ReleaseStringUTFChars(jcache_info, cache_info);
167+
}
168+
setHandle(jenv, jout, result);
169+
return ret;
170+
}
171+
172+
/*
173+
* Class: ml_dmlc_xgboost4j_XgboostJNI
174+
* Method: XGDMatrixCreateFromFile
175+
* Signature: (Ljava/lang/String;I[J)I
176+
*/
40177
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
41178
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) {
42179
DMatrixHandle result;

jvm-packages/xgboost4j/src/native/xgboost4j.h

Lines changed: 8 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)