2020#include < vector>
2121#include < string>
2222
23- // helper functions
24- // set handle
23+ // helper functions
24+ // set handle
2525void setHandle (JNIEnv *jenv, jlongArray jhandle, void * handle) {
2626 long out = (long ) handle;
2727 jenv->SetLongArrayRegion (jhandle, 0 , 1 , &out);
2828}
2929
30+ // global JVM
31+ static JavaVM* global_jvm = nullptr ;
32+
33+ // overrides JNI on load
34+ jint JNI_OnLoad (JavaVM *vm, void *reserved) {
35+ global_jvm = vm;
36+ return JNI_VERSION_1_6;
37+ }
38+
39+ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext (
40+ DataIterHandle data_handle,
41+ XGBCallbackSetData* set_function,
42+ DataHolderHandle set_function_handle) {
43+ jobject jiter = static_cast <jobject>(data_handle);
44+ JNIEnv* jenv;
45+ int jni_status = global_jvm->GetEnv ((void **)&jenv, JNI_VERSION_1_6);
46+ if (jni_status == JNI_EDETACHED) {
47+ global_jvm->AttachCurrentThread (reinterpret_cast <void **>(&jenv), nullptr );
48+ } else {
49+ CHECK (jni_status == JNI_OK);
50+ }
51+ try {
52+ jclass iterClass = jenv->FindClass (" java/util/Iterator" );
53+ jmethodID hasNext = jenv->GetMethodID (iterClass,
54+ " hasNext" , " ()Z" );
55+ jmethodID next = jenv->GetMethodID (iterClass,
56+ " next" , " ()Ljava/lang/Object;" );
57+ int ret_value;
58+ if (jenv->CallBooleanMethod (jiter, hasNext)) {
59+ ret_value = 1 ;
60+ jobject batch = jenv->CallObjectMethod (jiter, next);
61+ jclass batchClass = jenv->GetObjectClass (batch);
62+ jlongArray joffset = (jlongArray)jenv->GetObjectField (
63+ batch, jenv->GetFieldID (batchClass, " rowOffset" , " [J" ));
64+ jfloatArray jlabel = (jfloatArray)jenv->GetObjectField (
65+ batch, jenv->GetFieldID (batchClass, " label" , " [F" ));
66+ jfloatArray jweight = (jfloatArray)jenv->GetObjectField (
67+ batch, jenv->GetFieldID (batchClass, " weight" , " [F" ));
68+ jintArray jindex = (jintArray)jenv->GetObjectField (
69+ batch, jenv->GetFieldID (batchClass, " featureIndex" , " [I" ));
70+ jfloatArray jvalue = (jfloatArray)jenv->GetObjectField (
71+ batch, jenv->GetFieldID (batchClass, " featureValue" , " [F" ));
72+ XGBoostBatchCSR cbatch;
73+ cbatch.size = jenv->GetArrayLength (joffset) - 1 ;
74+ cbatch.offset = jenv->GetLongArrayElements (joffset, 0 );
75+ if (jlabel != nullptr ) {
76+ cbatch.label = jenv->GetFloatArrayElements (jlabel, 0 );
77+ CHECK_EQ (jenv->GetArrayLength (jlabel), static_cast <long >(cbatch.size ))
78+ << " batch.label.length must equal batch.numRows()" ;
79+ } else {
80+ cbatch.label = nullptr ;
81+ }
82+ if (jweight != nullptr ) {
83+ cbatch.weight = jenv->GetFloatArrayElements (jweight, 0 );
84+ CHECK_EQ (jenv->GetArrayLength (jweight), static_cast <long >(cbatch.size ))
85+ << " batch.weight.length must equal batch.numRows()" ;
86+ } else {
87+ cbatch.weight = nullptr ;
88+ }
89+ long max_elem = cbatch.offset [cbatch.size ];
90+ cbatch.index = jenv->GetIntArrayElements (jindex, 0 );
91+ cbatch.value = jenv->GetFloatArrayElements (jvalue, 0 );
92+ CHECK_EQ (jenv->GetArrayLength (jindex), max_elem)
93+ << " batch.index.length must equal batch.offset.back()" ;
94+ CHECK_EQ (jenv->GetArrayLength (jvalue), max_elem)
95+ << " batch.index.length must equal batch.offset.back()" ;
96+ // cbatch is ready
97+ CHECK_EQ ((*set_function)(set_function_handle, cbatch), 0 )
98+ << XGBGetLastError ();
99+ // release the elements.
100+ jenv->ReleaseLongArrayElements (joffset, cbatch.offset , 0 );
101+ jenv->DeleteLocalRef (joffset);
102+ if (jlabel != nullptr ) {
103+ jenv->ReleaseFloatArrayElements (jlabel, cbatch.label , 0 );
104+ jenv->DeleteLocalRef (jlabel);
105+ }
106+ if (jweight != nullptr ) {
107+ jenv->ReleaseFloatArrayElements (jweight, cbatch.weight , 0 );
108+ jenv->DeleteLocalRef (jweight);
109+ }
110+ jenv->ReleaseIntArrayElements (jindex, cbatch.index , 0 );
111+ jenv->DeleteLocalRef (jindex);
112+ jenv->ReleaseFloatArrayElements (jvalue, cbatch.value , 0 );
113+ jenv->DeleteLocalRef (jvalue);
114+ jenv->DeleteLocalRef (batch);
115+ jenv->DeleteLocalRef (batchClass);
116+ ret_value = 1 ;
117+ } else {
118+ ret_value = 0 ;
119+ }
120+ jenv->DeleteLocalRef (iterClass);
121+ // only detach if it is a async call.
122+ if (jni_status == JNI_EDETACHED) {
123+ global_jvm->DetachCurrentThread ();
124+ }
125+ return ret_value;
126+ } catch (dmlc::Error e) {
127+ // only detach if it is a async call.
128+ if (jni_status == JNI_EDETACHED) {
129+ global_jvm->DetachCurrentThread ();
130+ }
131+ LOG (FATAL) << e.what ();
132+ return -1 ;
133+ }
134+ }
135+
136+ /*
137+ * Class: ml_dmlc_xgboost4j_XgboostJNI
138+ * Method: XGBGetLastError
139+ * Signature: ()Ljava/lang/String;
140+ */
30141JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
31142 (JNIEnv *jenv, jclass jcls) {
32143 jstring jresult = 0 ;
@@ -37,6 +148,32 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
37148 return jresult;
38149}
39150
151+ /*
152+ * Class: ml_dmlc_xgboost4j_XgboostJNI
153+ * Method: XGDMatrixCreateFromDataIter
154+ * Signature: (Ljava/util/Iterator;Ljava/lang/String;[J)I
155+ */
156+ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromDataIter
157+ (JNIEnv *jenv, jclass jcls, jobject jiter, jstring jcache_info, jlongArray jout) {
158+ DMatrixHandle result;
159+ const char * cache_info = nullptr ;
160+ if (jcache_info != nullptr ) {
161+ cache_info = jenv->GetStringUTFChars (jcache_info, 0 );
162+ }
163+ int ret = XGDMatrixCreateFromDataIter (
164+ jiter, XGBoost4jCallbackDataIterNext, cache_info, &result);
165+ if (cache_info) {
166+ jenv->ReleaseStringUTFChars (jcache_info, cache_info);
167+ }
168+ setHandle (jenv, jout, result);
169+ return ret;
170+ }
171+
172+ /*
173+ * Class: ml_dmlc_xgboost4j_XgboostJNI
174+ * Method: XGDMatrixCreateFromFile
175+ * Signature: (Ljava/lang/String;I[J)I
176+ */
40177JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
41178 (JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) {
42179 DMatrixHandle result;
0 commit comments