From 83d124e399e16e4da2372331702a60b9a443c471 Mon Sep 17 00:00:00 2001 From: Mridul Jain Date: Tue, 19 Jul 2016 18:57:04 -0700 Subject: [PATCH 1/3] Exception handling JNI --- caffe-distri/include/common.hpp | 15 +- caffe-distri/include/util/socket.hpp | 2 +- caffe-distri/src/main/cpp/CaffeNet.cpp | 8 +- caffe-distri/src/main/cpp/common.cpp | 68 ++- caffe-distri/src/main/cpp/jni/JniCaffeNet.cpp | 532 ++++++++++++------ .../src/main/cpp/jni/JniFloatArray.cpp | 79 ++- .../src/main/cpp/jni/JniFloatBlob.cpp | 256 ++++++--- .../main/cpp/jni/JniFloatDataTransformer.cpp | 81 ++- caffe-distri/src/main/cpp/jni/JniMat.cpp | 195 ++++--- .../src/main/cpp/jni/JniMatVector.cpp | 168 ++++-- caffe-distri/src/main/cpp/util/socket.cpp | 25 +- .../com/yahoo/ml/jcaffe/CaffeNetTest.java | 128 ++++- .../com/yahoo/ml/jcaffe/FloatArrayTest.java | 43 ++ .../com/yahoo/ml/jcaffe/FloatBlobTest.java | 70 ++- .../java/com/yahoo/ml/jcaffe/MatTest.java | 284 ++++++---- .../com/yahoo/ml/jcaffe/MatVectorTest.java | 138 +++++ .../com/yahoo/ml/jcaffe/TransformTest.java | 256 +++++---- 17 files changed, 1648 insertions(+), 700 deletions(-) create mode 100644 caffe-distri/src/test/java/com/yahoo/ml/jcaffe/FloatArrayTest.java create mode 100644 caffe-distri/src/test/java/com/yahoo/ml/jcaffe/MatVectorTest.java diff --git a/caffe-distri/include/common.hpp b/caffe-distri/include/common.hpp index 3333156..bf124fa 100755 --- a/caffe-distri/include/common.hpp +++ b/caffe-distri/include/common.hpp @@ -15,13 +15,14 @@ using namespace caffe; #ifdef __cplusplus extern "C" { #endif - - bool SetNativeAddress(JNIEnv *env, jobject object, void* address); - void* GetNativeAddress(JNIEnv *env, jobject object); - - bool GetStringVector(vector& vec, JNIEnv *env, jobjectArray array, int length); - bool GetFloatBlobVector(vector< Blob* >& vec, JNIEnv *env, jobjectArray array, int length); - + + bool SetNativeAddress(JNIEnv *env, jobject object, void* address); + void* GetNativeAddress(JNIEnv *env, jobject object); + + bool GetStringVector(vector& vec, JNIEnv *env, jobjectArray array, int length); + bool GetFloatBlobVector(vector< Blob* >& vec, JNIEnv *env, jobjectArray array, int length); + void ThrowJavaException(const std::exception& ex, JNIEnv* env); + void ThrowCosJavaException(char* message, JNIEnv* env); #ifdef __cplusplus } #endif diff --git a/caffe-distri/include/util/socket.hpp b/caffe-distri/include/util/socket.hpp index 8708d85..3fbaa2b 100644 --- a/caffe-distri/include/util/socket.hpp +++ b/caffe-distri/include/util/socket.hpp @@ -53,7 +53,7 @@ class SocketChannel { public: SocketChannel(); ~SocketChannel(); - void Connect(string peer); + bool Connect(string peer); int client_fd; caffe::BlockingQueue receive_queue; int serving_fd; diff --git a/caffe-distri/src/main/cpp/CaffeNet.cpp b/caffe-distri/src/main/cpp/CaffeNet.cpp index 120a290..c6c5d4a 100755 --- a/caffe-distri/src/main/cpp/CaffeNet.cpp +++ b/caffe-distri/src/main/cpp/CaffeNet.cpp @@ -368,7 +368,7 @@ bool RDMACaffeNet::connect(vector& peer_addresses) { rdma_channels_, this->node_rank_)); // Pair devices for map-reduce synchronization - this->syncs_[0]->prepare(this->local_devices_, + this->syncs_[0]->Prepare(this->local_devices_, &this->syncs_); return true; @@ -382,7 +382,8 @@ bool SocketCaffeNet::connect(vector& peer_addresses) { if (i != this->node_rank_) { const char* addr = peer_addresses[i]; string addr_str(addr, strlen(addr)); - sockt_channels_[i]->Connect(addr_str); + if(!sockt_channels_[i]->Connect(addr_str)) + return false; } #ifndef CPU_ONLY @@ -559,9 +560,8 @@ void CaffeNet::predict(int solver_index, input_adapter_[solver_index]->feed(input_data, input_labels); //invoke network's Forward operation - const vector*> dummy_bottom_vec; CHECK(nets_[solver_index]); - nets_[solver_index]->Forward(dummy_bottom_vec); + nets_[solver_index]->Forward(); //grab the output blobs via names int num_features = output_blob_names.size(); diff --git a/caffe-distri/src/main/cpp/common.cpp b/caffe-distri/src/main/cpp/common.cpp index 54a5a9b..7410157 100755 --- a/caffe-distri/src/main/cpp/common.cpp +++ b/caffe-distri/src/main/cpp/common.cpp @@ -13,13 +13,13 @@ bool SetNativeAddress(JNIEnv *env, jobject object, void* address) { } /* Get a reference to JVM object class */ jclass claz = env->GetObjectClass(object); - if (claz == NULL) { + if (claz == NULL || env->ExceptionCheck()) { LOG(ERROR) << "unable to get object's class"; return false; } /* Locate init(long) method */ jmethodID methodId = env->GetMethodID(claz, "init", "(J)V"); - if (methodId == NULL) { + if (methodId == NULL || env->ExceptionCheck()) { LOG(ERROR) << "could not locate init() method"; return false; } @@ -40,13 +40,13 @@ void* GetNativeAddress(JNIEnv *env, jobject object) { } /* Get a reference to JVM object class */ jclass claz = env->GetObjectClass(object); - if (claz == NULL) { + if (claz == NULL || env->ExceptionCheck()) { LOG(ERROR) << "unable to get object's class"; return 0; } /* Getting the field id in the class */ jfieldID fieldId = env->GetFieldID(claz, "address", "J"); - if (fieldId == NULL) { + if (fieldId == NULL || env->ExceptionCheck()) { LOG(ERROR) << "could not locate field 'address'"; return 0; } @@ -57,11 +57,15 @@ void* GetNativeAddress(JNIEnv *env, jobject object) { bool GetStringVector(vector& vec, JNIEnv *env, jobjectArray array, int length) { for (int i = 0; i < length; i++) { jstring addr = (jstring)env->GetObjectArrayElement(array, i); - if (addr == NULL) { + if (addr == NULL || env->ExceptionCheck()) { LOG(INFO) << i << "-th string is NULL"; vec[i] = NULL; } else { const char *cStr = env->GetStringUTFChars(addr, NULL); + if (env->ExceptionCheck()) { + LOG(ERROR) << "GetStringUTFChars failed"; + return false; + } vec[i] = cStr; //Too many local refs could get created due to the loop, so delete them //CHECKME:cStr is also a local ref called in loop, but it's not clear if deleting it via DeleteLocalRef deletes the memory pointed by it too @@ -73,25 +77,45 @@ bool GetStringVector(vector& vec, JNIEnv *env, jobjectArray array, } bool GetFloatBlobVector(vector< Blob* >& vec, JNIEnv *env, jobjectArray array, int length) { - if (array == NULL) { - LOG(ERROR) << "array is NULL"; - return false; - } - + if (array == NULL) { + LOG(ERROR) << "array is NULL"; + return false; + } + for (int i = 0; i < length; i++) { - //get i-th FloatBlob object (JVM) - jobject object = env->GetObjectArrayElement(array, i); - if (object == NULL) { - LOG(WARNING) << i << "-th FloatBlob is NULL"; - vec[i] = NULL; - } else { - //find the native Blob object - vec[i] = (Blob*) GetNativeAddress(env, object); - } - //Too many local refs could get created due to the loop, so delete them - env->DeleteLocalRef(object); + //get i-th FloatBlob object (JVM) + jobject object = env->GetObjectArrayElement(array, i); + if (object == NULL || env->ExceptionCheck()) { + LOG(WARNING) << i << "-th FloatBlob is NULL"; + vec[i] = NULL; + } else { + try{ + //find the native Blob object + vec[i] = (Blob*) GetNativeAddress(env, object); + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return false; + } } + //Too many local refs could get created due to the loop, so delete them + env->DeleteLocalRef(object); + if (env->ExceptionCheck()) { + LOG(ERROR) << "DeleteLocalRef failed"; + return false; + } + } + + return true; +} - return true; +void ThrowJavaException(const std::exception& ex, JNIEnv *env) { + char exMsg[sizeof(typeid(ex).name())+sizeof(ex.what())+3]; + sprintf(exMsg, "%s : %s", typeid(ex).name(), ex.what()); + jclass exClass = env->FindClass("java/lang/Exception"); + env->ThrowNew(exClass, exMsg); } +void ThrowCosJavaException(char* message, JNIEnv *env) { + jclass exClass = env->FindClass("java/lang/Exception"); + env->ThrowNew(exClass, message); +} diff --git a/caffe-distri/src/main/cpp/jni/JniCaffeNet.cpp b/caffe-distri/src/main/cpp/jni/JniCaffeNet.cpp index c07aea6..cf54661 100755 --- a/caffe-distri/src/main/cpp/jni/JniCaffeNet.cpp +++ b/caffe-distri/src/main/cpp/jni/JniCaffeNet.cpp @@ -20,47 +20,58 @@ JNIEXPORT jboolean JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_allocate jboolean isCopy_solver = false; const char* solver_conf_file_chars = env->GetStringUTFChars(solver_conf_file, &isCopy_solver); - if (solver_conf_file_chars == NULL) { + if (solver_conf_file_chars == NULL || env->ExceptionCheck()) { LOG(ERROR) << "solver_conf_file_chars == NULL"; return false; } jboolean isCopy_model = false; const char* model_file_chars = env->GetStringUTFChars(model_file, &isCopy_model); - if (model_file_chars == NULL) { + if (model_file_chars == NULL || env->ExceptionCheck()) { LOG(ERROR) << "model_file_chars == NULL"; return false; } jboolean isCopy_state = false; const char* state_file_chars = env->GetStringUTFChars(state_file, &isCopy_state); - if (state_file_chars == NULL) { + if (state_file_chars == NULL || env->ExceptionCheck()) { LOG(ERROR) << "state_file_chars == NUL"; return false; } - if (cluster_size ==1) + try { + if (cluster_size ==1) native_ptr = new LocalCaffeNet(solver_conf_file_chars, model_file_chars, state_file_chars, num_local_devices, isTraining, start_device_id); - else { - switch (connection_type) { + else { + switch (connection_type) { #ifdef INFINIBAND - case com_yahoo_ml_jcaffe_CaffeNet_RDMA: - native_ptr = new RDMACaffeNet(solver_conf_file_chars, + case com_yahoo_ml_jcaffe_CaffeNet_RDMA: + native_ptr = new RDMACaffeNet(solver_conf_file_chars, model_file_chars, state_file_chars, num_local_devices, cluster_size, myRank, isTraining, start_device_id); - break; + break; #endif - case com_yahoo_ml_jcaffe_CaffeNet_SOCKET: - native_ptr = new SocketCaffeNet(solver_conf_file_chars, + case com_yahoo_ml_jcaffe_CaffeNet_SOCKET: + native_ptr = new SocketCaffeNet(solver_conf_file_chars, model_file_chars, state_file_chars, num_local_devices, cluster_size, myRank, isTraining, start_device_id); - break; + break; + } } + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return false; + } + + if (native_ptr == NULL) { + LOG(ERROR) << "unable to create CaffeNet object"; + return false; } + if (isCopy_solver) env->ReleaseStringUTFChars(solver_conf_file, solver_conf_file_chars); if (isCopy_model) @@ -68,6 +79,10 @@ JNIEXPORT jboolean JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_allocate if (isCopy_state) env->ReleaseStringUTFChars(state_file, state_file_chars); + if (env->ExceptionCheck()) { + LOG(ERROR) << "ReleaseStringUTFChars failed"; + return false; + } /* associate native object with JVM object */ return SetNativeAddress(env, object, native_ptr); } @@ -89,38 +104,56 @@ JNIEXPORT void JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_deallocate */ JNIEXPORT jobjectArray JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_localAddresses (JNIEnv *env, jobject object) { - CaffeNet* native_ptr = (CaffeNet*) GetNativeAddress(env, object); - - vector addrs; + if (object == NULL) { + LOG(ERROR) << "localAddresses object is NULL"; + return NULL; + } + + CaffeNet* native_ptr = NULL; + try { + native_ptr = (CaffeNet*) GetNativeAddress(env, object); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return NULL; + } + + vector addrs; + try { native_ptr->localAddresses(addrs); - - // Get a class reference for com.yahoo.ml.jcaffe..FloatBlob - jclass classString = env->FindClass("java/lang/String"); - - // Allocate a jobjectArray of com.yahoo.ml.jcaffe.FloatBlob - jsize len = addrs.size(); - jobjectArray outJNIArray = env->NewObjectArray(len, classString, NULL); - if (outJNIArray == NULL) { - LOG(ERROR) << "Unable to create a new array"; + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return NULL; + } + + // Get a class reference for com.yahoo.ml.jcaffe..FloatBlob + jclass classString = env->FindClass("java/lang/String"); + if (env->ExceptionCheck()) { + LOG(ERROR) << "FindClass failed"; + return NULL; + } + // Allocate a jobjectArray of com.yahoo.ml.jcaffe.FloatBlob + jsize len = addrs.size(); + jobjectArray outJNIArray = env->NewObjectArray(len, classString, NULL); + if (outJNIArray == NULL || env->ExceptionCheck()) { + LOG(ERROR) << "Unable to create a new array"; + return NULL; + } + //construct a set of JVM String object + int i; + for (i=0; iNewStringUTF(addrs[i].c_str()); + if (str == NULL || env->ExceptionCheck()) { + LOG(ERROR) << "Unable to create new String"; return NULL; } - //construct a set of JVM String object - int i; - for (i=0; iNewStringUTF(addrs[i].c_str()); - if (str == NULL) { - LOG(ERROR) << "Unable to create new String"; - return NULL; - } - env->SetObjectArrayElement(outJNIArray, i, str); - if (env->ExceptionOccurred()) { - LOG(ERROR) << "Unable to set Array Elements"; - return NULL; - } + env->SetObjectArrayElement(outJNIArray, i, str); + if (env->ExceptionOccurred()) { + LOG(ERROR) << "Unable to set Array Elements"; + return NULL; } - - return outJNIArray; + } + return outJNIArray; } /* @@ -130,9 +163,15 @@ JNIEXPORT jobjectArray JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_localAddresses */ JNIEXPORT jboolean JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_sync (JNIEnv *env, jobject object) { - CaffeNet* native_ptr = (CaffeNet*) GetNativeAddress(env, object); + CaffeNet* native_ptr = NULL; + try { + native_ptr = (CaffeNet*) GetNativeAddress(env, object); native_ptr->sync(); - return true; + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return false; + } + return true; } /* @@ -142,23 +181,44 @@ JNIEXPORT jboolean JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_sync */ JNIEXPORT jboolean JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_connect (JNIEnv *env, jobject object, jobjectArray address_array) { - CaffeNet* native_ptr = (CaffeNet*) GetNativeAddress(env, object); - - jsize length = (address_array==NULL? 0 : env->GetArrayLength(address_array)); - vector addresses(length); - if(!GetStringVector(addresses, env, address_array, length)) { - LOG(ERROR) << "Unable to retrieve StringVector"; + CaffeNet* native_ptr = NULL; + try { + native_ptr = (CaffeNet*) GetNativeAddress(env, object); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return false; + } + + jsize length = (address_array==NULL? 0 : env->GetArrayLength(address_array)); + vector addresses(length); + if(!GetStringVector(addresses, env, address_array, length)) { + LOG(ERROR) << "Unable to retrieve StringVector"; + return false; + } + + try { + if (!native_ptr->connect(addresses)) return false; + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return false; + } + + for (int i = 0; i < length; i++) { + if(addresses[i] != NULL){ + jstring addr = (jstring)env->GetObjectArrayElement(address_array, i); + if (env->ExceptionCheck()) { + LOG(ERROR) << "GetObjectArrayElement failed"; + return false; + } + env->ReleaseStringUTFChars(addr, addresses[i]); + if (env->ExceptionCheck()) { + LOG(ERROR) << "ReleaseStringUTFChar failed"; + return false; + } } - - native_ptr->connect(addresses); - for (int i = 0; i < length; i++) { - if(addresses[i] != NULL){ - jstring addr = (jstring)env->GetObjectArrayElement(address_array, i); - env->ReleaseStringUTFChars(addr, addresses[i]); - } - } - return true; + } + return true; } /* @@ -168,9 +228,18 @@ JNIEXPORT jboolean JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_connect */ JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_deviceID (JNIEnv *env, jobject object, jint solver_index) { - - CaffeNet* native_ptr = (CaffeNet*) GetNativeAddress(env, object); + if (solver_index < 0) { + LOG(ERROR) << "Solver index invalid"; + return -1; + } + CaffeNet* native_ptr = NULL; + try { + native_ptr = (CaffeNet*) GetNativeAddress(env, object); return native_ptr->deviceID(solver_index); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return -1; + } } /* @@ -180,9 +249,18 @@ JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_deviceID */ JNIEXPORT jboolean JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_init (JNIEnv *env, jobject object, jint solver_index, jboolean enableNN) { - CaffeNet* native_ptr = (CaffeNet*) GetNativeAddress(env, object); - + CaffeNet* native_ptr = NULL; + if (solver_index < 0) { + LOG(ERROR) << "Solver index invalid"; + return false; + } + try { + native_ptr = (CaffeNet*) GetNativeAddress(env, object); return native_ptr->init(solver_index, enableNN); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return false; + } } /* @@ -192,86 +270,125 @@ JNIEXPORT jboolean JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_init */ JNIEXPORT jobjectArray JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_predict (JNIEnv *env, jobject object, jint solver_index, jobjectArray input_data, jobject input_labels, jobjectArray output_blobnames) { - CaffeNet* native_ptr = (CaffeNet*) GetNativeAddress(env, object); - - size_t length = env->GetArrayLength(input_data); - vector< Blob* > data_vec(length); - - if(!GetFloatBlobVector(data_vec, env, input_data, length)) { - LOG(ERROR) << "Could not get FoatBlob vector"; - return NULL; - } - - length = env->GetArrayLength(output_blobnames); - if (length==0) return NULL; - vector output_blobnames_chars(length); - - if(!GetStringVector(output_blobnames_chars, env, output_blobnames, length)){ - LOG(ERROR) << "Could not get String vector"; - return NULL; - } - - /* Get a reference to JVM object class */ - jclass claz = env->GetObjectClass(input_labels); - if (claz == NULL) { - LOG(ERROR) << "unable to get input_label's class (FloatArray)"; - return 0; - } - /* Getting the field id in the class */ - jfieldID fieldId = env->GetFieldID(claz, "arrayAddress", "J"); - if (fieldId == NULL) { - LOG(ERROR) << "could not locate field 'arrayAddress'"; - return 0; - } - - jfloat* labels = (jfloat*) env->GetLongField(input_labels, fieldId); - if (labels==NULL) { - LOG(ERROR) << "labels are NULL"; - return NULL; - } - vector* > results(length); + CaffeNet* native_ptr = NULL; + try { + native_ptr = (CaffeNet*) GetNativeAddress(env, object); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return NULL; + } + + if (input_data == NULL) { + LOG(ERROR) << "data is NULL"; + ThrowCosJavaException((char*)"data is NULL", env); + return NULL; + } + + size_t length = env->GetArrayLength(input_data); + if (env->ExceptionCheck()) { + LOG(ERROR) << "GetArrayLength failed"; + return NULL; + } + vector< Blob* > data_vec(length); + + if(!GetFloatBlobVector(data_vec, env, input_data, length)) { + LOG(ERROR) << "Could not get FoatBlob vector"; + return NULL; + } + + length = env->GetArrayLength(output_blobnames); + if (env->ExceptionCheck()) { + LOG(ERROR) << "GetArrayLength failed"; + return NULL; + } + if (length==0) return NULL; + vector output_blobnames_chars(length); + + if(!GetStringVector(output_blobnames_chars, env, output_blobnames, length)){ + LOG(ERROR) << "Could not get String vector"; + return NULL; + } + + if (input_labels == NULL) { + LOG(ERROR) << "labels is NULL"; + ThrowCosJavaException((char*)"labels is NULL", env); + return NULL; + } + + /* Get a reference to JVM object class */ + jclass claz = env->GetObjectClass(input_labels); + if (claz == NULL || env->ExceptionCheck()) { + LOG(ERROR) << "unable to get input_label's class (FloatArray)"; + return NULL; + } + /* Getting the field id in the class */ + jfieldID fieldId = env->GetFieldID(claz, "arrayAddress", "J"); + if (fieldId == NULL || env->ExceptionCheck()) { + LOG(ERROR) << "could not locate field 'arrayAddress'"; + return NULL; + } + + jfloat* labels = (jfloat*) env->GetLongField(input_labels, fieldId); + if (labels==NULL || env->ExceptionCheck()) { + LOG(ERROR) << "labels are NULL"; + return NULL; + } + vector* > results(length); + try { native_ptr->predict(solver_index, data_vec, labels, output_blobnames_chars, results); - - // Get a class reference for com.yahoo.ml.jcaffe.FloatBlob - jclass classFloatBlob = env->FindClass("com/yahoo/ml/jcaffe/FloatBlob"); - if (env->ExceptionOccurred()) { - LOG(ERROR) << "Unable to find class FloatBlob"; + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return NULL; + } + + // Get a class reference for com.yahoo.ml.jcaffe.FloatBlob + jclass classFloatBlob = env->FindClass("com/yahoo/ml/jcaffe/FloatBlob"); + if (env->ExceptionOccurred()) { + LOG(ERROR) << "Unable to find class FloatBlob"; + return NULL; + } + jmethodID midFloatBlobInit = env->GetMethodID(classFloatBlob, "", "(JZ)V"); + if (midFloatBlobInit == NULL || env->ExceptionCheck()) { + LOG(ERROR) << "Unable to locate method init"; + return NULL; + } + // Allocate a jobjectArray of com.yahoo.ml.jcaffe.FloatBlob + jobjectArray outJNIArray = env->NewObjectArray(length, classFloatBlob, NULL); + if (outJNIArray == NULL || env->ExceptionCheck()) { + LOG(ERROR) << "Unable to allocate a new array"; + return NULL; + } + //construct a set of JVM FloatBlob object from native Blob + for (int i=0; i object + jobject obj = env->NewObject(classFloatBlob, midFloatBlobInit, results[i], false); + if (obj == NULL || env->ExceptionCheck()) { + LOG(ERROR) << "Unable to construct new object"; return NULL; } - jmethodID midFloatBlobInit = env->GetMethodID(classFloatBlob, "", "(JZ)V"); - if (midFloatBlobInit == NULL) { - LOG(ERROR) << "Unable to locate method init"; - return NULL; - } - // Allocate a jobjectArray of com.yahoo.ml.jcaffe.FloatBlob - jobjectArray outJNIArray = env->NewObjectArray(length, classFloatBlob, NULL); - if (outJNIArray == NULL) { - LOG(ERROR) << "Unable to allocate a new array"; + env->SetObjectArrayElement(outJNIArray, i, obj); + if (env->ExceptionOccurred()) { + LOG(ERROR) << "Unable to set Array Elements"; return NULL; } - //construct a set of JVM FloatBlob object from native Blob - for (int i=0; i object - jobject obj = env->NewObject(classFloatBlob, midFloatBlobInit, results[i], false); - if (obj == NULL) { - LOG(ERROR) << "Unable to construct new object"; - return NULL; - } - env->SetObjectArrayElement(outJNIArray, i, obj); - if (env->ExceptionOccurred()) { - LOG(ERROR) << "Unable to set Array Elements"; - return NULL; - } - } - - //release JNI objects - for (int i = 0; i < length; i++) { - if (output_blobnames_chars[i] != NULL) { - jstring output_blobname = (jstring)env->GetObjectArrayElement(output_blobnames, i); - env->ReleaseStringUTFChars(output_blobname, output_blobnames_chars[i]); - } + } + + //release JNI objects + for (int i = 0; i < length; i++) { + if (output_blobnames_chars[i] != NULL) { + jstring output_blobname = (jstring)env->GetObjectArrayElement(output_blobnames, i); + if (env->ExceptionCheck()) { + LOG(ERROR) << "GetObjectArrayElement failed"; + return NULL; + } + env->ReleaseStringUTFChars(output_blobname, output_blobnames_chars[i]); + if (env->ExceptionCheck()) { + LOG(ERROR) << "ReleaseStringUTFChars failed"; + return NULL; + } } - return outJNIArray; + } + return outJNIArray; } /* @@ -281,37 +398,59 @@ JNIEXPORT jobjectArray JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_predict */ JNIEXPORT jboolean JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_train (JNIEnv *env, jobject object, jint solver_index, jobjectArray input_data, jobject input_labels) { - CaffeNet* native_ptr = (CaffeNet*) GetNativeAddress(env, object); - - size_t length = (input_data != NULL? env->GetArrayLength(input_data) : 0); - vector< Blob* > data_vec(length); - /* Get a reference to JVM object class */ - jclass claz = env->GetObjectClass(input_labels); - if (claz == NULL) { - LOG(ERROR) << "unable to get input_label's class (FloatArray)"; - return 0; - } - /* Getting the field id in the class */ - jfieldID fieldId = env->GetFieldID(claz, "arrayAddress", "J"); - if (fieldId == NULL) { - LOG(ERROR) << "could not locate field 'arrayAddress'"; - return 0; - } - - jfloat* labels = (jfloat*) env->GetLongField(input_labels, fieldId); - if (labels==NULL) { - LOG(ERROR) << "labels are NULL"; - return false; - } - - if(!GetFloatBlobVector(data_vec, env, input_data, length)) { - LOG(ERROR) << "Could not retrieve FloatBlobVector"; - return false; - } - + CaffeNet* native_ptr = NULL; + try { + native_ptr = (CaffeNet*) GetNativeAddress(env, object); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return false; + } + + if (input_data == NULL) { + LOG(ERROR) << "data is NULL"; + ThrowCosJavaException((char*)"data is NULL", env); + return false; + } + + size_t length = (input_data != NULL? env->GetArrayLength(input_data) : 0); + vector< Blob* > data_vec(length); + /* Get a reference to JVM object class */ + jclass claz = env->GetObjectClass(input_labels); + if (claz == NULL || env->ExceptionCheck()) { + LOG(ERROR) << "unable to get input_label's class (FloatArray)"; + return false; + } + /* Getting the field id in the class */ + jfieldID fieldId = env->GetFieldID(claz, "arrayAddress", "J"); + if (fieldId == NULL || env->ExceptionCheck()) { + LOG(ERROR) << "could not locate field 'arrayAddress'"; + return false; + } + + if (input_labels == NULL) { + LOG(ERROR) << "labels is NULL"; + ThrowCosJavaException((char*)"label is NULL", env); + return false; + } + + jfloat* labels = (jfloat*) env->GetLongField(input_labels, fieldId); + if (labels==NULL || env->ExceptionCheck()) { + LOG(ERROR) << "labels are NULL"; + return false; + } + + if(!GetFloatBlobVector(data_vec, env, input_data, length)) { + LOG(ERROR) << "Could not retrieve FloatBlobVector"; + return false; + } + + try { native_ptr->train(solver_index, data_vec, labels); - - return true; + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return false; + } + return true; } /* @@ -321,9 +460,19 @@ JNIEXPORT jboolean JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_train */ JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_getInitIter (JNIEnv *env, jobject object, jint solver_index) { - CaffeNet* native_ptr = (CaffeNet*) GetNativeAddress(env, object); - + if (solver_index < 0) { + LOG(ERROR) << "Solver index invalid"; + return -1; + } + + CaffeNet* native_ptr = NULL; + try { + native_ptr = (CaffeNet*) GetNativeAddress(env, object); return native_ptr->getInitIter(solver_index); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return -1; + } } /* @@ -333,9 +482,18 @@ JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_getInitIter */ JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_getMaxIter (JNIEnv *env, jobject object, jint solver_index) { - CaffeNet* native_ptr = (CaffeNet*) GetNativeAddress(env, object); - + if (solver_index < 0) { + LOG(ERROR) << "Solver index invalid"; + return -1; + } + CaffeNet* native_ptr = NULL; + try { + native_ptr = (CaffeNet*) GetNativeAddress(env, object); return native_ptr->getMaxIter(solver_index); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return -1; + } } /* @@ -345,10 +503,18 @@ JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_getMaxIter */ JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_snapshot (JNIEnv *env, jobject object) { - - CaffeNet* native_ptr = (CaffeNet*) GetNativeAddress(env, object); - + if (object == NULL) { + LOG(ERROR) << "Snapshot object is NULL"; + return -1; + } + CaffeNet* native_ptr = NULL; + try { + native_ptr = (CaffeNet*) GetNativeAddress(env, object); return native_ptr->snapshot(); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return -1; + } } /* @@ -358,7 +524,23 @@ JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_snapshot */ JNIEXPORT jstring JNICALL Java_com_yahoo_ml_jcaffe_CaffeNet_getTestOutputBlobNames (JNIEnv *env, jobject object) { - CaffeNet* native_ptr = (CaffeNet*) GetNativeAddress(env, object); - string blob_names = native_ptr->getTestOutputBlobNames(); - return env->NewStringUTF(blob_names.c_str()); + if (object == NULL) { + LOG(ERROR) << "NULL object for OutputBlobNames"; + return NULL; + } + CaffeNet* native_ptr = NULL; + try { + native_ptr = (CaffeNet*) GetNativeAddress(env, object); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return NULL; + } + string blob_names; + try { + blob_names = native_ptr->getTestOutputBlobNames(); + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return NULL; + } + return env->NewStringUTF(blob_names.c_str()); } diff --git a/caffe-distri/src/main/cpp/jni/JniFloatArray.cpp b/caffe-distri/src/main/cpp/jni/JniFloatArray.cpp index 1a2b25e..dc062d9 100755 --- a/caffe-distri/src/main/cpp/jni/JniFloatArray.cpp +++ b/caffe-distri/src/main/cpp/jni/JniFloatArray.cpp @@ -13,43 +13,64 @@ * Signature: (I)F */ JNIEXPORT jfloat JNICALL Java_com_yahoo_ml_jcaffe_FloatArray_get(JNIEnv *env, jobject object, jint index){ - /* Get a reference to JVM object class */ - jclass claz = env->GetObjectClass(object); - if (claz == NULL) { - LOG(ERROR) << "unable to get object's class"; - return 0; - } - /* Getting the field id in the class */ - jfieldID fieldId = env->GetFieldID(claz, "arrayAddress", "J"); - if (fieldId == NULL) { - LOG(ERROR) << "could not locate field 'arrayAddress'"; - return 0; - } - - jfloat* float_array_ptr = (jfloat*) env->GetLongField(object, fieldId); - return float_array_ptr[index]; + /* Get a reference to JVM object class */ + jclass claz = env->GetObjectClass(object); + if (claz == NULL || env->ExceptionCheck()) { + LOG(ERROR) << "unable to get object's class"; + return 0; + } + /* Getting the field id in the class */ + jfieldID fieldId = env->GetFieldID(claz, "arrayAddress", "J"); + if (fieldId == NULL || env->ExceptionCheck()) { + LOG(ERROR) << "could not locate field 'arrayAddress'"; + return 0; + } + + jfloat* float_array_ptr = (jfloat*) env->GetLongField(object, fieldId); + if (env->ExceptionCheck()) { + LOG(ERROR) << "GetLongField failed"; + return 0; + } + if (index < 0) { + LOG(ERROR) << "negative index"; + return 0; + } + return float_array_ptr[index]; } /* * Class: com_yahoo_ml_jcaffe_FloatArray * Method: set * Signature: (IF)V - */ +*/ JNIEXPORT void JNICALL Java_com_yahoo_ml_jcaffe_FloatArray_set(JNIEnv *env, jobject object, jint index, jfloat data){ /* Get a reference to JVM object class */ - jclass claz = env->GetObjectClass(object); - if (claz == NULL) { - LOG(ERROR) << "unable to get object's class"; - return; - } - /* Getting the field id in the class */ - jfieldID fieldId = env->GetFieldID(claz, "arrayAddress", "J"); - if (fieldId == NULL) { - LOG(ERROR) << "could not locate field 'arrayAddress'"; - return; - } + jclass claz = env->GetObjectClass(object); + if (claz == NULL || env->ExceptionCheck()) { + LOG(ERROR) << "unable to get object's class"; + ThrowCosJavaException((char*)"unable to get object's class", env); + return; + } + /* Getting the field id in the class */ + jfieldID fieldId = env->GetFieldID(claz, "arrayAddress", "J"); + if (fieldId == NULL || env->ExceptionCheck()) { + LOG(ERROR) << "could not locate field 'arrayAddress'"; + ThrowCosJavaException((char*)"could not locate field 'arrayAddress'", env); + return; + } + + jfloat* float_array_ptr = (jfloat*) env->GetLongField(object, fieldId); + if (env->ExceptionCheck()) { + LOG(ERROR) << "GetLongField failed"; + ThrowCosJavaException((char*)"GetLongField failed", env); + return; + } + if (index < 0) { + LOG(ERROR) << "Invalid index for floatarray"; + ThrowCosJavaException((char*)"Invalid index for floatarray" , env); + return; + } - jfloat* float_array_ptr = (jfloat*) env->GetLongField(object, fieldId); - float_array_ptr[index] = data; + float_array_ptr[index] = data; } diff --git a/caffe-distri/src/main/cpp/jni/JniFloatBlob.cpp b/caffe-distri/src/main/cpp/jni/JniFloatBlob.cpp index d153346..985b47c 100755 --- a/caffe-distri/src/main/cpp/jni/JniFloatBlob.cpp +++ b/caffe-distri/src/main/cpp/jni/JniFloatBlob.cpp @@ -12,11 +12,21 @@ * Signature: ()V */ JNIEXPORT jboolean JNICALL Java_com_yahoo_ml_jcaffe_FloatBlob_allocate(JNIEnv *env, jobject object) { - /* create a native FloatBlob object */ - Blob* native_ptr = new Blob(); + /* create a native FloatBlob object */ + Blob* native_ptr = NULL; + try { + native_ptr = new Blob(); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return false; + } - /* associate native object with JVM object */ - return SetNativeAddress(env, object, native_ptr); + if (native_ptr == NULL) { + LOG(ERROR) << "Unable to allocate memory for Blob"; + return false; + } + /* associate native object with JVM object */ + return SetNativeAddress(env, object, native_ptr); } /* @@ -33,9 +43,14 @@ JNIEXPORT void JNICALL Java_com_yahoo_ml_jcaffe_FloatBlob_deallocate1(JNIEnv *en JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_FloatBlob_count (JNIEnv *env, jobject object) { - Blob* native_ptr = (Blob*) GetNativeAddress(env, object); - + Blob* native_ptr = NULL; + try { + native_ptr = (Blob*) GetNativeAddress(env, object); return native_ptr->count(); + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return -1; + } } /* @@ -45,29 +60,48 @@ JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_FloatBlob_count */ JNIEXPORT jboolean JNICALL Java_com_yahoo_ml_jcaffe_FloatBlob_reshape (JNIEnv *env, jobject object, jintArray shape) { - Blob* native_ptr = (Blob*) GetNativeAddress(env, object); - - size_t size = env->GetArrayLength(shape); - jint *vals = env->GetIntArrayElements(shape, NULL); - if (vals == NULL) { - LOG(ERROR) << "vals == NULL"; + Blob* native_ptr = NULL; + try { + native_ptr = (Blob*) GetNativeAddress(env, object); + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return false; + } + + size_t size = env->GetArrayLength(shape); + if (env->ExceptionCheck()) { + LOG(ERROR) << "GetArrayLength failed"; + return false; + } + jint *vals = env->GetIntArrayElements(shape, NULL); + if (vals == NULL || env->ExceptionCheck()) { + LOG(ERROR) << "vals == NULL"; + return false; + } + + vector shap_vec(size); + for (int i=0; i shap_vec(size); - for (int i=0; iReshape(shap_vec); - - //release JNI objects - env->ReleaseIntArrayElements(shape, vals, JNI_ABORT); - if (env->ExceptionOccurred()) { - LOG(ERROR) << "Unable to release Array Elements"; - return false; - } - - return true; + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return false; + } + //release JNI objects + env->ReleaseIntArrayElements(shape, vals, JNI_ABORT); + if (env->ExceptionOccurred()) { + LOG(ERROR) << "Unable to release Array Elements"; + return false; + } + + return true; } /* @@ -76,14 +110,22 @@ JNIEXPORT jboolean JNICALL Java_com_yahoo_ml_jcaffe_FloatBlob_reshape * Signature: (Lcom/yahoo/ml/jcaffe/FloatBlob;)V */ JNIEXPORT jboolean JNICALL Java_com_yahoo_ml_jcaffe_FloatBlob_copyFrom(JNIEnv *env, jobject object, jobject source) { - Blob* native_ptr = (Blob*) GetNativeAddress(env, object); - - Blob* source_ptr = (Blob*) GetNativeAddress(env, source); - + Blob* native_ptr = NULL; + Blob* source_ptr = NULL; + if (source == NULL) { + LOG(ERROR) << "source is NULL"; + return false; + } + try { + native_ptr = (Blob*) GetNativeAddress(env, object); + source_ptr = (Blob*) GetNativeAddress(env, source); //perform operation native_ptr->CopyFrom(*source_ptr); - - return true; + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return false; + } + return true; } /* @@ -92,20 +134,34 @@ JNIEXPORT jboolean JNICALL Java_com_yahoo_ml_jcaffe_FloatBlob_copyFrom(JNIEnv *e * Signature: ()[F */ JNIEXPORT jobject JNICALL Java_com_yahoo_ml_jcaffe_FloatBlob_cpu_1data(JNIEnv *env, jobject object) { - Blob* native_ptr = (Blob*) GetNativeAddress(env, object); - + Blob* native_ptr = NULL; + try { + native_ptr = (Blob*) GetNativeAddress(env, object); + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return NULL; + } + jfloat* cpu_data = NULL; + try { //retrieve cpu_data() - jfloat* cpu_data = (jfloat*) native_ptr->mutable_cpu_data(); - if (cpu_data == NULL) { - LOG(ERROR) << "cpu_data == NULL"; - return NULL; - } - - jclass claz = env->FindClass("com/yahoo/ml/jcaffe/FloatArray"); - jmethodID constructorId = env->GetMethodID( claz, "", "(J)V"); - jobject objectFloatArray = env->NewObject(claz,constructorId,(long)cpu_data); - - return objectFloatArray; + cpu_data = (jfloat*) native_ptr->mutable_cpu_data(); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return NULL; + } + if (cpu_data == NULL) { + LOG(ERROR) << "cpu_data == NULL"; + return NULL; + } + + jclass claz = env->FindClass("com/yahoo/ml/jcaffe/FloatArray"); + jmethodID constructorId = env->GetMethodID( claz, "", "(J)V"); + jobject objectFloatArray = env->NewObject(claz,constructorId,(long)cpu_data); + if (env->ExceptionCheck()) { + LOG(ERROR) << "FloatArray object creation failed"; + return NULL; + } + return objectFloatArray; } /* @@ -115,34 +171,55 @@ JNIEXPORT jobject JNICALL Java_com_yahoo_ml_jcaffe_FloatBlob_cpu_1data(JNIEnv *e */ JNIEXPORT jlong JNICALL Java_com_yahoo_ml_jcaffe_FloatBlob_set_1cpu_1data(JNIEnv *env, jobject object, jfloatArray array, jlong dataaddress) { - Blob* native_ptr = (Blob*) GetNativeAddress(env, object); - - jboolean copied = false; - float* data = env->GetFloatArrayElements(array, &copied); - if (data == NULL) { - LOG(ERROR) << "GetFloatArrayElements() == NULL"; - return 0; + Blob* native_ptr = NULL; + try { + native_ptr = (Blob*) GetNativeAddress(env, object); + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return 0; + } + + jboolean copied = false; + if (array == NULL) { + LOG(ERROR) << "input array is NULL"; + return 0; + } + float* data = env->GetFloatArrayElements(array, &copied); + if (data == NULL || env->ExceptionCheck()) { + LOG(ERROR) << "GetFloatArrayElements() == NULL"; + return 0; + } + + if (!copied) { + size_t len = 0; + try { + len = native_ptr->count(); + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return 0; } - - if (!copied) { - size_t len = native_ptr->count(); - float* new_data = new float[len]; - if (new_data == NULL) { - LOG(ERROR) << "fail to float[] for new data"; - return 0; - } - - memcpy(new_data, data, len * sizeof(float)); - //set new data - data = new_data; + float* new_data = new float[len]; + if (new_data == NULL) { + LOG(ERROR) << "fail to float[] for new data"; + return 0; } - + + memcpy(new_data, data, len * sizeof(float)); + //set new data + data = new_data; + } + + try { native_ptr->set_cpu_data(data); - - if(dataaddress) - delete (jbyte*) dataaddress; - - return (long) data; + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return 0; + } + + if(dataaddress) + delete (jbyte*) dataaddress; + + return (long) data; } /* @@ -151,18 +228,33 @@ JNIEXPORT jlong JNICALL Java_com_yahoo_ml_jcaffe_FloatBlob_set_1cpu_1data(JNIEnv * Signature: ()[F */ JNIEXPORT jobject JNICALL Java_com_yahoo_ml_jcaffe_FloatBlob_gpu_1data(JNIEnv *env, jobject object) { - Blob* native_ptr = (Blob*) GetNativeAddress(env, object); - + Blob* native_ptr = NULL; + try { + native_ptr = (Blob*) GetNativeAddress(env, object); + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return NULL; + } + jfloat* gpu_data = NULL; + try { //retrieve gpu_data() - jfloat* gpu_data = (jfloat*)native_ptr->mutable_gpu_data(); - if (gpu_data == NULL) { - LOG(ERROR) << "gpu_data == NULL"; - return NULL; - } - - jclass claz = env->FindClass("com/yahoo/ml/jcaffe/FloatArray"); - jmethodID constructorId = env->GetMethodID( claz, "", "(J)V"); - jobject objectFloatArray = env->NewObject(claz,constructorId,(long)gpu_data); + gpu_data = (jfloat*)native_ptr->mutable_gpu_data(); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return NULL; + } + if (gpu_data == NULL || env->ExceptionCheck()) { + LOG(ERROR) << "gpu_data == NULL"; + return NULL; + } + + jclass claz = env->FindClass("com/yahoo/ml/jcaffe/FloatArray"); + jmethodID constructorId = env->GetMethodID( claz, "", "(J)V"); + jobject objectFloatArray = env->NewObject(claz,constructorId,(long)gpu_data); + if (env->ExceptionCheck()) { + LOG(ERROR) << "FloatArray object creation failed"; + return NULL; + } - return objectFloatArray; + return objectFloatArray; } diff --git a/caffe-distri/src/main/cpp/jni/JniFloatDataTransformer.cpp b/caffe-distri/src/main/cpp/jni/JniFloatDataTransformer.cpp index a9a1383..d126f92 100755 --- a/caffe-distri/src/main/cpp/jni/JniFloatDataTransformer.cpp +++ b/caffe-distri/src/main/cpp/jni/JniFloatDataTransformer.cpp @@ -17,29 +17,53 @@ JNIEXPORT jboolean JNICALL Java_com_yahoo_ml_jcaffe_FloatDataTransformer_allocate (JNIEnv *env, jobject object, jstring xform_param_str, jboolean isTrain) { - TransformationParameter param; - jboolean isCopy = false; - const char* xform_chars = env->GetStringUTFChars(xform_param_str, &isCopy); + TransformationParameter param; + jboolean isCopy = false; + const char* xform_chars = env->GetStringUTFChars(xform_param_str, &isCopy); + if (env->ExceptionCheck()) { + LOG(ERROR) << "GetStringUTFChars failed"; + return false; + } + try { google::protobuf::TextFormat::ParseFromString(string(xform_chars), ¶m); - - DataTransformer* xformer = NULL; + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return false; + } + + DataTransformer* xformer = NULL; + try { if (isTrain) - xformer = new DataTransformer(param, TRAIN); + xformer = new DataTransformer(param, TRAIN); else - xformer = new DataTransformer(param, TEST); + xformer = new DataTransformer(param, TEST); + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return false; + } + if (xformer == NULL) { + LOG(ERROR) << "unable to create DataTransformer"; + return false; + } + + try { //initialize randomizer xformer->InitRand(); - - if (isCopy){ - env->ReleaseStringUTFChars(xform_param_str, xform_chars); - if (env->ExceptionOccurred()) { - LOG(ERROR) << "Unable to release String"; - return false; - } + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return false; + } + + if (isCopy){ + env->ReleaseStringUTFChars(xform_param_str, xform_chars); + if (env->ExceptionOccurred()) { + LOG(ERROR) << "Unable to release String"; + return false; } - /* associate native object with JVM object */ - return SetNativeAddress(env, object, xformer); + } + /* associate native object with JVM object */ + return SetNativeAddress(env, object, xformer); } /* @@ -49,7 +73,7 @@ JNIEXPORT jboolean JNICALL Java_com_yahoo_ml_jcaffe_FloatDataTransformer_allocat */ JNIEXPORT void JNICALL Java_com_yahoo_ml_jcaffe_FloatDataTransformer_deallocate (JNIEnv *env, jobject object, jlong native_ptr) { - delete (DataTransformer*) native_ptr; + delete (DataTransformer*) native_ptr; } /* @@ -60,11 +84,22 @@ JNIEXPORT void JNICALL Java_com_yahoo_ml_jcaffe_FloatDataTransformer_deallocate JNIEXPORT void JNICALL Java_com_yahoo_ml_jcaffe_FloatDataTransformer_transform (JNIEnv *env, jobject object, jobject matVec, jobject transformed_blob) { - DataTransformer* xformer = (DataTransformer*) GetNativeAddress(env, object); - - vector* mat_vector_ptr = (vector*) GetNativeAddress(env, matVec); - - Blob* blob_ptr = (Blob*) GetNativeAddress(env, transformed_blob); - + DataTransformer* xformer = NULL; + + vector* mat_vector_ptr = NULL; + + Blob* blob_ptr = NULL; + if (matVec == NULL || transformed_blob == NULL) { + ThrowCosJavaException((char*)"NULL object", env); + return; + } + try { + xformer = (DataTransformer*) GetNativeAddress(env, object); + mat_vector_ptr = (vector*) GetNativeAddress(env, matVec); + blob_ptr = (Blob*) GetNativeAddress(env, transformed_blob); xformer->Transform((* mat_vector_ptr), blob_ptr); + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return; + } } diff --git a/caffe-distri/src/main/cpp/jni/JniMat.cpp b/caffe-distri/src/main/cpp/jni/JniMat.cpp index 64b26d2..12cd189 100755 --- a/caffe-distri/src/main/cpp/jni/JniMat.cpp +++ b/caffe-distri/src/main/cpp/jni/JniMat.cpp @@ -16,34 +16,42 @@ JNIEXPORT jlong JNICALL Java_com_yahoo_ml_jcaffe_Mat_allocate (JNIEnv *env, jobject object, jint channels, jint height, jint width, jbyteArray array) { - jboolean isCopy = false; - jbyte *data = env->GetByteArrayElements(array, &isCopy); - - if (data==NULL) { - LOG(ERROR) << "invalid data array"; + jboolean isCopy = false; + jbyte *data = env->GetByteArrayElements(array, &isCopy); + if (data==NULL || env->ExceptionCheck()) { + LOG(ERROR) << "invalid data array"; + return 0; + } + + if (!isCopy) { + jsize len = env->GetArrayLength(array); + if (env->ExceptionCheck()) { + LOG(ERROR) << "GetArrayLength failed"; return 0; } - - if (!isCopy) { - jsize len = env->GetArrayLength(array); - jbyte* new_data = new jbyte[len]; - if (new_data == NULL) { - LOG(ERROR) << "fail to jbyte[] for new data"; - return 0; - } - - memcpy(new_data, data, len * sizeof(jbyte)); - //set new data - data = new_data; + jbyte* new_data = new jbyte[len]; + if (new_data == NULL) { + LOG(ERROR) << "fail to jbyte[] for new data"; + return 0; } + + memcpy(new_data, data, len * sizeof(jbyte)); + //set new data + data = new_data; + } + cv::Mat* native_ptr = NULL; + try { /* create a native Mat object */ - cv::Mat* native_ptr = new cv::Mat(height, width, CV_8UC(channels), data); - - /* associate native object with JVM object */ - SetNativeAddress(env, object, native_ptr); - - return (long) data; + native_ptr = new cv::Mat(height, width, CV_8UC(channels), data); + } catch (std::exception& ex) { + ThrowJavaException(ex, env); + return 0; + } + /* associate native object with JVM object */ + SetNativeAddress(env, object, native_ptr); + + return (long) data; } /* @@ -52,13 +60,13 @@ JNIEXPORT jlong JNICALL Java_com_yahoo_ml_jcaffe_Mat_allocate * Signature: (JZ)V */ JNIEXPORT void JNICALL Java_com_yahoo_ml_jcaffe_Mat_deallocate - (JNIEnv *env, jobject object, jlong native_ptr, jlong dataaddress) { - - //Mat object is only one responsible for cleaning itself and it's data - if(dataaddress){ - delete[] (jbyte*)dataaddress; - } - delete (cv::Mat*) native_ptr; + (JNIEnv *env, jobject object, jlong native_ptr, jlong dataaddress) { + + //Mat object is only one responsible for cleaning itself and it's data + if(dataaddress){ + delete (jbyte*)dataaddress; + } + delete (cv::Mat*) native_ptr; } /* @@ -69,19 +77,24 @@ JNIEXPORT void JNICALL Java_com_yahoo_ml_jcaffe_Mat_deallocate JNIEXPORT void JNICALL Java_com_yahoo_ml_jcaffe_Mat_decode (JNIEnv *env, jobject object, jint flags, jlong dataaddress) { - cv::Mat* native_ptr = (cv::Mat*) GetNativeAddress(env, object); - - cv::imdecode(cv::_InputArray(*native_ptr), flags, native_ptr); - - jclass claz = env->GetObjectClass(object); - if (claz == NULL) { - LOG(ERROR) << "unable to get object's class"; - return; - } - - if (dataaddress){ - delete (jbyte*)dataaddress; - } + try{ + cv::Mat* native_ptr = (cv::Mat*) GetNativeAddress(env, object); + cv::imdecode(cv::_InputArray(*native_ptr), flags, native_ptr); + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return; + } + + jclass claz = env->GetObjectClass(object); + if (claz == NULL || env->ExceptionCheck()) { + LOG(ERROR) << "unable to get object's class"; + ThrowCosJavaException((char*)"unable to get the object's class", env); + return; + } + + if (dataaddress){ + delete (jbyte*)dataaddress; + } } /* * Class: com_yahoo_ml_jcaffe_Mat @@ -90,20 +103,30 @@ JNIEXPORT void JNICALL Java_com_yahoo_ml_jcaffe_Mat_decode */ JNIEXPORT void JNICALL Java_com_yahoo_ml_jcaffe_Mat_resize (JNIEnv *env, jobject object, jint height, jint width, jlong dataaddress) { - cv::Mat* native_ptr = (cv::Mat*) GetNativeAddress(env, object); - - cv::Size size(width, height); - cv::resize(cv::_InputArray(*native_ptr), cv::_OutputArray(*native_ptr), size, 0, 0, cv::INTER_LINEAR); - jclass claz = env->GetObjectClass(object); - if (claz == NULL) { - LOG(ERROR) << "unable to get object's class"; - return; - } - - if (dataaddress){ - delete (jbyte*)dataaddress; - } + if (height < 0 || width < 0) { + ThrowCosJavaException((char*)"invalid dimensions to resize", env); + return; + } + try{ + cv::Mat* native_ptr = (cv::Mat*) GetNativeAddress(env, object); + cv::Size size(width, height); + cv::resize(cv::_InputArray(*native_ptr), cv::_OutputArray(*native_ptr), size, 0, 0, cv::INTER_LINEAR); + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return; + } + + jclass claz = env->GetObjectClass(object); + if (claz == NULL || env->ExceptionCheck()) { + LOG(ERROR) << "unable to get object's class"; + ThrowCosJavaException((char*)"unable to get the object's class", env); + return; + } + + if (dataaddress){ + delete (jbyte*)dataaddress; + } } /* @@ -114,9 +137,14 @@ JNIEXPORT void JNICALL Java_com_yahoo_ml_jcaffe_Mat_resize JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_Mat_height (JNIEnv *env, jobject object) { - cv::Mat* native_ptr = (cv::Mat*) GetNativeAddress(env, object); - + cv::Mat* native_ptr = NULL; + try { + native_ptr = (cv::Mat*) GetNativeAddress(env, object); return native_ptr->rows; + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return -1; + } } /* @@ -126,9 +154,14 @@ JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_Mat_height */ JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_Mat_width (JNIEnv *env, jobject object) { - cv::Mat* native_ptr = (cv::Mat*) GetNativeAddress(env, object); - + cv::Mat* native_ptr = NULL; + try { + native_ptr = (cv::Mat*) GetNativeAddress(env, object); return native_ptr->cols; + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return -1; + } } /* @@ -139,16 +172,27 @@ JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_Mat_width JNIEXPORT jbyteArray JNICALL Java_com_yahoo_ml_jcaffe_Mat_data (JNIEnv *env, jobject object) { - cv::Mat* native_ptr = (cv::Mat*) GetNativeAddress(env, object); - int size = native_ptr->total() * native_ptr->elemSize(); - - jbyteArray dataarray = env->NewByteArray(size); - if(dataarray == NULL){ - LOG(ERROR) << "Out of memory while allocating array for Mat data" ; - return NULL; - } - env->SetByteArrayRegion(dataarray,0, size, (jbyte*)native_ptr->data); - return dataarray; + cv::Mat* native_ptr = NULL; + int size = 0; + try { + native_ptr = (cv::Mat*) GetNativeAddress(env, object); + size = native_ptr->total() * native_ptr->elemSize(); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return NULL; + } + + jbyteArray dataarray = env->NewByteArray(size); + if(dataarray == NULL || env->ExceptionCheck()){ + LOG(ERROR) << "Out of memory while allocating array for Mat data" ; + return NULL; + } + env->SetByteArrayRegion(dataarray,0, size, (jbyte*)native_ptr->data); + if (env->ExceptionCheck()) { + LOG(ERROR) << "SetByteArrayRegion failed"; + return NULL; + } + return dataarray; } /* @@ -159,7 +203,12 @@ JNIEXPORT jbyteArray JNICALL Java_com_yahoo_ml_jcaffe_Mat_data JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_Mat_channels (JNIEnv *env, jobject object) { - cv::Mat* native_ptr = (cv::Mat*) GetNativeAddress(env, object); - - return native_ptr->channels(); + cv::Mat* native_ptr = NULL; + try { + native_ptr = (cv::Mat*) GetNativeAddress(env, object); + return native_ptr->channels(); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return -1; + } } diff --git a/caffe-distri/src/main/cpp/jni/JniMatVector.cpp b/caffe-distri/src/main/cpp/jni/JniMatVector.cpp index a46d9bd..2860ea0 100755 --- a/caffe-distri/src/main/cpp/jni/JniMatVector.cpp +++ b/caffe-distri/src/main/cpp/jni/JniMatVector.cpp @@ -15,11 +15,27 @@ JNIEXPORT jboolean JNICALL Java_com_yahoo_ml_jcaffe_MatVector_allocate (JNIEnv *env, jobject object, jint size) { - /* create a native vector object */ - vector* native_ptr = new vector(size); - - /* associate native object with JVM object */ - return SetNativeAddress(env, object, native_ptr); + /* create a native vector object */ + vector* native_ptr = NULL; + if (size < 0) { + LOG(ERROR) << "Negative MatVector size specified"; + ThrowCosJavaException((char*)"Negative MatVector size specified", env); + return false; + } + + try { + native_ptr = new vector(size); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return false; + } + + if (native_ptr == NULL) { + LOG(ERROR) << "unable to allocate memory for vector of Mats"; + return false; + } + /* associate native object with JVM object */ + return SetNativeAddress(env, object, native_ptr); } /* @@ -29,9 +45,9 @@ JNIEXPORT jboolean JNICALL Java_com_yahoo_ml_jcaffe_MatVector_allocate */ JNIEXPORT void JNICALL Java_com_yahoo_ml_jcaffe_MatVector_deallocateVec(JNIEnv *env, jobject object, jlong address) { - vector* native_ptr = (vector*) address; + vector* native_ptr = (vector*) address; - delete native_ptr; + delete native_ptr; } /* @@ -42,15 +58,36 @@ JNIEXPORT void JNICALL Java_com_yahoo_ml_jcaffe_MatVector_deallocateVec(JNIEnv * JNIEXPORT void JNICALL Java_com_yahoo_ml_jcaffe_MatVector_putnative (JNIEnv *env, jobject object, jint pos, jobject mat) { - vector *native_ptr = (vector*) GetNativeAddress(env, object); - - cv::Mat* mat_ptr = (cv::Mat*) GetNativeAddress(env, mat); - if (mat_ptr==NULL) { - LOG(ERROR) << "invalid native address of Mat"; - return; - } + vector *native_ptr = NULL; + try { + native_ptr = (vector*) GetNativeAddress(env, object); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return; + } - (*native_ptr)[pos] = *mat_ptr; + + cv::Mat* mat_ptr = NULL; + try { + mat_ptr = (cv::Mat*) GetNativeAddress(env, mat); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return; + } + + if (mat_ptr==NULL) { + LOG(ERROR) << "invalid native address of Mat"; + ThrowCosJavaException((char*)"invalid native address of Mat", env); + return; + } + + if (pos < 0 || pos >= native_ptr->size()) { + LOG(ERROR) << "invalid index in MatVector"; + ThrowCosJavaException((char*)"invalid index in MatVector", env); + return; + } + + (*native_ptr)[pos] = *mat_ptr; } /* @@ -61,17 +98,39 @@ JNIEXPORT void JNICALL Java_com_yahoo_ml_jcaffe_MatVector_putnative JNIEXPORT jbyteArray JNICALL Java_com_yahoo_ml_jcaffe_MatVector_data (JNIEnv *env, jobject object, jint pos) { - vector *native_ptr = (vector*) GetNativeAddress(env, object); - cv::Mat mat = (cv::Mat)(*native_ptr)[pos]; - int size = mat.total() * mat.elemSize(); - - jbyteArray dataarray = env->NewByteArray(size); - if(dataarray == NULL){ - LOG(ERROR) << "Out of memory while allocating array for Mat data" ; - return NULL; - } - env->SetByteArrayRegion(dataarray,0, size, (jbyte*)mat.data); - return dataarray; + vector *native_ptr = NULL; + try { + native_ptr = (vector*) GetNativeAddress(env, object); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return NULL; + } + + if (pos < 0 || pos > native_ptr->size()) { + LOG(ERROR) << "Invalid Mat index in MatVector"; + return NULL; + } + + cv::Mat mat = (cv::Mat)(*native_ptr)[pos]; + int size = 0; + try { + size = mat.total() * mat.elemSize(); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return NULL; + } + + jbyteArray dataarray = env->NewByteArray(size); + if(dataarray == NULL || env->ExceptionCheck()){ + LOG(ERROR) << "Out of memory while allocating array for Mat data" ; + return NULL; + } + env->SetByteArrayRegion(dataarray,0, size, (jbyte*)mat.data); + if (env->ExceptionCheck()) { + LOG(ERROR) << "SetByteArrayRegion failed"; + return NULL; + } + return dataarray; } /* @@ -83,9 +142,24 @@ JNIEXPORT jbyteArray JNICALL Java_com_yahoo_ml_jcaffe_MatVector_data JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_MatVector_height (JNIEnv *env, jobject object, jint pos) { - vector *native_ptr = (vector*) GetNativeAddress(env, object); - cv::Mat mat = (cv::Mat)(*native_ptr)[pos]; + vector *native_ptr = NULL; + try { + native_ptr = (vector*) GetNativeAddress(env, object); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return -1; + } + if (pos < 0 || pos > native_ptr->size()) { + LOG(ERROR) << "Invalid Mat index in MatVector"; + return -1; + } + cv::Mat mat = (cv::Mat)(*native_ptr)[pos]; + try { return mat.rows; + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return -1; + } } /* @@ -97,9 +171,24 @@ JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_MatVector_height JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_MatVector_width (JNIEnv *env, jobject object, jint pos) { - vector *native_ptr = (vector*) GetNativeAddress(env, object); - cv::Mat mat = (cv::Mat)(*native_ptr)[pos]; + vector *native_ptr = NULL; + try { + native_ptr = (vector*) GetNativeAddress(env, object); + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return -1; + } + if (pos < 0 || pos > native_ptr->size()) { + LOG(ERROR) << "Invalid Mat index in MatVector"; + return -1; + } + cv::Mat mat = (cv::Mat)(*native_ptr)[pos]; + try { return mat.cols; + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return -1; + } } /* @@ -111,7 +200,22 @@ JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_MatVector_width JNIEXPORT jint JNICALL Java_com_yahoo_ml_jcaffe_MatVector_channels (JNIEnv *env, jobject object, jint pos) { - vector *native_ptr = (vector*) GetNativeAddress(env, object); + vector *native_ptr = NULL; + try { + native_ptr = (vector*) GetNativeAddress(env, object); + } catch (const std::exception& ex) { + ThrowJavaException(ex, env); + return -1; + } + if (pos < 0 || pos > native_ptr->size()) { + LOG(ERROR) << "Invalid Mat index in MatVector"; + return -1; + } cv::Mat mat = (cv::Mat)(*native_ptr)[pos]; - return mat.channels(); + try { + return mat.channels(); + } catch(const std::exception& ex) { + ThrowJavaException(ex, env); + return -1; + } } diff --git a/caffe-distri/src/main/cpp/util/socket.cpp b/caffe-distri/src/main/cpp/util/socket.cpp index 403bc43..786a39f 100644 --- a/caffe-distri/src/main/cpp/util/socket.cpp +++ b/caffe-distri/src/main/cpp/util/socket.cpp @@ -238,20 +238,26 @@ void SocketAdapter::start_sockt_srvr() { } // Connect called by client with inbuilt support for retries -void SocketChannel::Connect(string peer) { +bool SocketChannel::Connect(string peer) { bool retry = true; int attempts = 0; int client_fd = 0; vector name_port; boost::split(name_port, peer, boost::is_any_of(":")); - while (retry && (attempts < 2)) { + int backoff = 1; + while (retry && (attempts < 5)) { retry = false; if (client_fd == 0) { + string peername = name_port.at(0).c_str();; + string portnumber; + if (name_port.size() > 1) + portnumber = name_port.at(1).c_str(); + LOG(INFO) << "Trying to connect with ...[" - << name_port.at(0).c_str() <<":" - << name_port.at(1).c_str()<< "]"; - client_fd = connect_to_peer(name_port.at(0), - name_port.at(1)); + << peername <<":" + << portnumber << "]"; + client_fd = connect_to_peer(peername, + portnumber); if (!client_fd) { retry = true; } else { @@ -264,8 +270,13 @@ void SocketChannel::Connect(string peer) { } attempts++; // Retry after 10 secs - usleep(10000000); + usleep(backoff*1000000); + backoff = backoff * 2; } + if (retry) + return false; + + return true; } // Real connect call without retries (called by connect above) diff --git a/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/CaffeNetTest.java b/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/CaffeNetTest.java index 5b63438..37815d7 100755 --- a/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/CaffeNetTest.java +++ b/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/CaffeNetTest.java @@ -20,7 +20,7 @@ public class CaffeNetTest { String rootPath, solver_config_path, imagePath; - CaffeNet net, test_net; + CaffeNet net, test_net, socket_net; SolverParameter solver_param; int index = 0; List file_list; @@ -58,6 +58,16 @@ public void setUp() throws Exception { -1); assertTrue(test_net != null); + socket_net = new CaffeNet(solver_config_path, + "", + "", + 1, //num_local_devices, + 2, //cluster_size, + 0, //node_rank, + false, //isTraining, + CaffeNet.SOCKET, //NONE + -1); + assertTrue(socket_net != null); solver_param = Utils.GetSolverParam(solver_config_path); assertEquals(solver_param.getSolverMode(), SolverParameter.SolverMode.CPU); @@ -70,6 +80,49 @@ public void tearDown() throws Exception { net.deallocate(); } + @Test + public void initinvalid() { + assertFalse(net.init(-1)); + } + + @Test + public void deviceIDinvalid() { + assertEquals(net.deviceID(-1), -1); + } + + @Test + public void inititerinvalid() { + assertEquals(net.getInitIter(-1), -1); + } + + @Test + public void maxiterinvalid() { + assertEquals(net.getMaxIter(-1), -1); + } + + @Test + public void snapshotfilenameinvalid() { + assertNull(net.snapshotFilename(-1,false)); + } + + @Test + public void connectnull(){ + String[] addrs = null; + assertTrue(net.connect(addrs)); + } + + @Test + public void connectbogus(){ + String[] addrs = {"0x222", "0x333"}; + boolean pass = true; + try { + pass = socket_net.connect(addrs); + } catch(Exception e) { + pass = false; + } + assertFalse(pass); + } + @Test public void testBasic() { String[] addrs = net.localAddresses(); @@ -138,6 +191,79 @@ private void nextBatch(MatVector matVec, FloatBlob labels) throws Exception { } } + @Test + public void trainnull() throws Exception { + SolverParameter solver_param = Utils.GetSolverParam(rootPath + "caffe-distri/src/test/resources/caffenet_solver.prototxt"); + + String net_proto_file = solver_param.getNet(); + NetParameter net_param = Utils.GetNetParam(rootPath + "caffe-distri/" + net_proto_file); + + //blob + MatVector matVec = new MatVector(batch_size); + FloatBlob[] dataBlobs = new FloatBlob[1]; + FloatBlob data_blob = new FloatBlob(); + data_blob.reshape(batch_size, channels, height, width); + dataBlobs[0] = data_blob; + + FloatBlob labelblob = new FloatBlob(); + labelblob.reshape(batch_size, 1, 1, 1); + + //transformer + LayerParameter train_layer_param = net_param.getLayer(0); + TransformationParameter param = train_layer_param.getTransformParam(); + FloatDataTransformer xform = new FloatDataTransformer(param, true); + + nextBatch(matVec, labelblob); + xform.transform(matVec, data_blob); + boolean fail = false; + try { + net.train(0, null, labelblob.cpu_data()); + } catch(Exception e) { + fail = true; + } + assertTrue(fail); + xform.deallocate(); + data_blob.deallocate(); + matVec.deallocate(); + } + + @Test + public void predictnull() throws Exception { + SolverParameter solver_param = Utils.GetSolverParam(rootPath + "caffe-distri/src/test/resources/caffenet_solver.prototxt"); + + String net_proto_file = solver_param.getNet(); + NetParameter net_param = Utils.GetNetParam(rootPath + "caffe-distri/" + net_proto_file); + + //blob + MatVector matVec = new MatVector(batch_size); + FloatBlob[] dataBlobs = new FloatBlob[1]; + FloatBlob data_blob = new FloatBlob(); + data_blob.reshape(batch_size, channels, height, width); + dataBlobs[0] = data_blob; + + FloatBlob labelblob = new FloatBlob(); + labelblob.reshape(batch_size, 1, 1, 1); + + //transformer + LayerParameter train_layer_param = net_param.getLayer(0); + TransformationParameter param = train_layer_param.getTransformParam(); + FloatDataTransformer xform = new FloatDataTransformer(param, true); + + nextBatch(matVec, labelblob); + xform.transform(matVec, data_blob); + boolean fail = false; + String[] test_features = {"loss"}; + try { + FloatBlob[] top_blobs_vec = net.predict(0, null, labelblob.cpu_data(), test_features); + } catch(Exception e) { + fail = true; + } + assertTrue(fail); + xform.deallocate(); + data_blob.deallocate(); + matVec.deallocate(); + } + @Test public void testTrain() throws Exception { SolverParameter solver_param = Utils.GetSolverParam(rootPath + "caffe-distri/src/test/resources/caffenet_solver.prototxt"); diff --git a/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/FloatArrayTest.java b/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/FloatArrayTest.java new file mode 100644 index 0000000..1a868ec --- /dev/null +++ b/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/FloatArrayTest.java @@ -0,0 +1,43 @@ +// Copyright 2016 Yahoo Inc. +// Licensed under the terms of the Apache 2.0 license. +// Please see LICENSE file in the project root for terms. +package com.yahoo.ml.jcaffe; + +import org.testng.annotations.Test; +import static org.testng.Assert.*; +import java.util.Random; + +public class FloatArrayTest { + @Test + public void floatarraygetnegative(){ + FloatBlob data_blob = new FloatBlob(); + data_blob.reshape(5, 1, 1, 1); + FloatArray fa = null; + fa = data_blob.cpu_data(); + boolean fail = false; + if (fa.get(-1) == 0) + fail = true; + + assertTrue(fail); + } + + @Test + public void floatarraysetinvalid(){ + FloatBlob data_blob = new FloatBlob(); + data_blob.reshape(5, 1, 1, 1); + FloatArray fa = null; + boolean fail = false; + fa = data_blob.cpu_data(); + if (fa == null) { + System.out.println("cpu_data returned null"); + return; + } + try { + fa.set(-1,-1); + } catch(Exception e) { + fail = true; + } + assertTrue(fail); + data_blob.deallocate(); + } +} diff --git a/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/FloatBlobTest.java b/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/FloatBlobTest.java index 509b309..e957cf7 100755 --- a/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/FloatBlobTest.java +++ b/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/FloatBlobTest.java @@ -14,36 +14,74 @@ public void testAllocate() { blob.deallocate(); } + @Test + public void setcpudatanull() { + FloatBlob blob = new FloatBlob(); + boolean res = blob.set_cpu_data(null); + assertFalse(res); + } + + @Test + public void copyfromnull() { + FloatBlob blob2 = new FloatBlob(); + blob2.reshape(1, 1, 2, 2); + assertFalse(blob2.copyFrom(null)); + } + + @Test + public void floatblobnull(){ + boolean fail = false; + try { + FloatBlob blob1 = new FloatBlob(-1,false); + } catch(Exception e) { + fail = true; + } + assertFalse(fail); + } + + @Test + public void reshapeinvalid(){ + boolean fail = false; + FloatBlob blob = new FloatBlob(); + try { + fail = blob.reshape(0,0,0,0); + } catch(Exception e){ + fail = true; + } + assertFalse(fail); + } + + @Test public void testBasic() { - FloatBlob blob = new FloatBlob(); - boolean res = blob.reshape(1, 1, 2, 2); - assertTrue(res); - - float[] input = { 1.0f, 2.0f, 3.0f, 4.0f }; - res = blob.set_cpu_data(input); - assertTrue(res); - - FloatArray output = blob.cpu_data(); - for (int i=0; i file_list; - - @BeforeMethod - public void setUp() throws Exception { - String fullPath = getClass().getClassLoader().getResource("log4j.properties").getPath(); - rootPath = fullPath.substring(0, fullPath.indexOf("caffe-distri/")); - imagePath = rootPath + "data/images"; - - file_list = Files.readAllLines(Paths.get(imagePath + "/labels.txt"), StandardCharsets.UTF_8); + String rootPath, imagePath; + final int batchs = 1; + final int batch_size = 2; + final int channels = 3; + final int height = 227; + final int width = 227; + + int index = 0; + List file_list; + + @BeforeMethod + public void setUp() throws Exception { + String fullPath = getClass().getClassLoader().getResource("log4j.properties").getPath(); + rootPath = fullPath.substring(0, fullPath.indexOf("caffe-distri/")); + imagePath = rootPath + "data/images"; + + file_list = Files.readAllLines(Paths.get(imagePath + "/labels.txt"), StandardCharsets.UTF_8); + } + + @Test + private void matNull() throws Exception { + boolean fail = false; + try { + Mat mat = new Mat(null); + } catch(Exception e) { + fail = true; } - - @Test - private void basicTest() throws Exception { - MatVector matVec = new MatVector(1); - byte[] buf = new byte[1024 * 1024]; - int width = 227; - int height = 227; - String line = file_list.get(index++); - if (index >= file_list.size()) index = 0; - - String[] line_splits = line.split(" "); - String filename = line_splits[0]; - int label = Integer.parseInt(line_splits[1]); - - ByteArrayOutputStream out = new ByteArrayOutputStream(); - DataInputStream in = new DataInputStream(new FileInputStream(imagePath + "/" + filename)); - int len = in.read(buf, 0, buf.length); - while (len > 0) { - out.write(buf, 0, len); - len = in.read(buf, 0, buf.length); - } - in.close(); - - byte[] data = out.toByteArray(); + assertTrue(fail); + } + + @Test + private void basicTest() throws Exception { + MatVector matVec = new MatVector(1); + byte[] buf = new byte[1024 * 1024]; + int width = 227; + int height = 227; + String line = file_list.get(index++); + if (index >= file_list.size()) index = 0; - Mat mat = new Mat(data); - // mat.decode(Mat.CV_LOAD_IMAGE_COLOR); - //mat.resize(width, height); - Mat oldmat = matVec.put(0, mat); - if (oldmat != null) - oldmat.deallocate(); - assertEquals(matVec.width(0), mat.width()); - assertEquals(matVec.height(0), mat.height()); - //GC doesn't have any affect on mat with value 227 - width++; - height++; - - //reuse matVec for new mat with value 228 and clean old mat with 227 properly - mat = new Mat(data); - mat.decode(Mat.CV_LOAD_IMAGE_COLOR); - mat.resize(width, height); - oldmat = matVec.put(0, mat); - oldmat.deallocate(); - assertEquals(matVec.width(0), 228); - assertEquals(matVec.height(0), 228); - //GC to deallocate mat with value 227. Currently matVec has mat with 228 - //Irrespective of GC, mat with 228 shouldn't get deallocated before matVec deallocate - mat = null; - assertEquals(matVec.width(0), 228); - assertEquals(matVec.height(0), 228); - matVec.deallocate(); - out.close(); + String[] line_splits = line.split(" "); + String filename = line_splits[0]; + int label = Integer.parseInt(line_splits[1]); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + DataInputStream in = new DataInputStream(new FileInputStream(imagePath + "/" + filename)); + int len = in.read(buf, 0, buf.length); + while (len > 0) { + out.write(buf, 0, len); + len = in.read(buf, 0, buf.length); } + in.close(); - private byte[] getDataFromFile(int pos) throws Exception{ - byte[] buf = new byte[1024 * 1024]; - String line = file_list.get(pos); - String[] line_splits = line.split(" "); - String filename = line_splits[0]; - int label = Integer.parseInt(line_splits[1]); + byte[] data = out.toByteArray(); - ByteArrayOutputStream out = new ByteArrayOutputStream(); - DataInputStream in = new DataInputStream(new FileInputStream(imagePath + "/" + filename)); - int len = in.read(buf, 0, buf.length); - while (len > 0) { - out.write(buf, 0, len); - len = in.read(buf, 0, buf.length); - } - in.close(); - byte[] b = out.toByteArray(); - out.close(); - return b; + Mat mat = new Mat(data); + // mat.decode(Mat.CV_LOAD_IMAGE_COLOR); + //mat.resize(width, height); + Mat oldmat = matVec.put(0, mat); + if (oldmat != null) + oldmat.deallocate(); + assertEquals(matVec.width(0), mat.width()); + assertEquals(matVec.height(0), mat.height()); + //GC doesn't have any affect on mat with value 227 + width++; + height++; + + //reuse matVec for new mat with value 228 and clean old mat with 227 properly + mat = new Mat(data); + mat.decode(Mat.CV_LOAD_IMAGE_COLOR); + mat.resize(width, height); + oldmat = matVec.put(0, mat); + oldmat.deallocate(); + assertEquals(matVec.width(0), 228); + assertEquals(matVec.height(0), 228); + //GC to deallocate mat with value 227. Currently matVec has mat with 228 + //Irrespective of GC, mat with 228 shouldn't get deallocated before matVec deallocate + mat = null; + assertEquals(matVec.width(0), 228); + assertEquals(matVec.height(0), 228); + matVec.deallocate(); + out.close(); + } + + private byte[] getDataFromFile(int pos) throws Exception{ + byte[] buf = new byte[1024 * 1024]; + String line = file_list.get(pos); + String[] line_splits = line.split(" "); + String filename = line_splits[0]; + int label = Integer.parseInt(line_splits[1]); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + DataInputStream in = new DataInputStream(new FileInputStream(imagePath + "/" + filename)); + int len = in.read(buf, 0, buf.length); + while (len > 0) { + out.write(buf, 0, len); + len = in.read(buf, 0, buf.length); } - - - @Test - private void matResizeDecodeTest() throws Exception { - byte[] data0 = getDataFromFile(0); - Mat m = new Mat(data0); - m.resize(227,227); - m.decode(Mat.CV_LOAD_IMAGE_COLOR); - m.deallocate(); + in.close(); + byte[] b = out.toByteArray(); + out.close(); + return b; + } + + + @Test + private void matResizeDecodeTest() throws Exception { + byte[] data0 = getDataFromFile(0); + Mat m = new Mat(data0); + m.resize(227,227); + m.decode(Mat.CV_LOAD_IMAGE_COLOR); + m.deallocate(); + } + + @Test + private void matChannelsTest() throws Exception { + byte[] data0 = getDataFromFile(0); + Mat m = new Mat(3, 9, 9, data0); + assertEquals(m.channels(), 3); + m.deallocate(); + } + + @Test + private void basicDataTest() throws Exception { + MatVector matVec = new MatVector(1); + byte[] data0 = getDataFromFile(0); + Mat mat = new Mat(data0); + matVec.put(0, mat); + byte[] resultData0 = matVec.data(0); + //What we wrote is what we get + assertTrue(Arrays.equals(data0, resultData0)); + + //Now replace matVec 0th mat object with a new one and make sure it is the new one + byte[] data1 = getDataFromFile(1); + mat = new Mat(data1); + resultData0 = matVec.data(0); + assertTrue(Arrays.equals(data0, resultData0)); + Mat oldmat = matVec.put(0, mat); + byte[] resultData1 = matVec.data(0); + assertTrue(Arrays.equals(data1, resultData1)); + matVec.deallocate(); + assertTrue(Arrays.equals(data0, oldmat.data())); + oldmat.deallocate(); + } + + @Test + private void getMatDecodeWithInvalidFlag() throws Exception { + byte[] data0 = getDataFromFile(0); + Mat mat = new Mat(data0); + boolean fail = false; + try { + mat.decode(-1); + } catch(Exception e) { + fail = true; } - - @Test - private void matChannelsTest() throws Exception { - byte[] data0 = getDataFromFile(0); - Mat m = new Mat(3, 9, 9, data0); - assertEquals(m.channels(), 3); - m.deallocate(); + assertFalse(fail); + } + + @Test + private void matResizeInvalid() throws Exception { + byte[] data0 = getDataFromFile(0); + Mat m = new Mat(data0); + boolean fail = false; + try { + m.resize(-1,-1); + } catch(Exception e) { + fail = true; } + assertTrue(fail); + } - @Test - private void basicDataTest() throws Exception { - MatVector matVec = new MatVector(1); - byte[] data0 = getDataFromFile(0); - Mat mat = new Mat(data0); - matVec.put(0, mat); - byte[] resultData0 = matVec.data(0); - //What we wrote is what we get - assertTrue(Arrays.equals(data0, resultData0)); - - //Now replace matVec 0th mat object with a new one and make sure it is the new one - byte[] data1 = getDataFromFile(1); - mat = new Mat(data1); - resultData0 = matVec.data(0); - assertTrue(Arrays.equals(data0, resultData0)); - Mat oldmat = matVec.put(0, mat); - byte[] resultData1 = matVec.data(0); - assertTrue(Arrays.equals(data1, resultData1)); - matVec.deallocate(); - assertTrue(Arrays.equals(data0, oldmat.data())); - oldmat.deallocate(); - } } diff --git a/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/MatVectorTest.java b/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/MatVectorTest.java new file mode 100644 index 0000000..5b74ffe --- /dev/null +++ b/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/MatVectorTest.java @@ -0,0 +1,138 @@ +// Copyright 2016 Yahoo Inc. +// Licensed under the terms of the Apache 2.0 license. +// Please see LICENSE file in the project root for terms. +package com.yahoo.ml.jcaffe; + +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import caffe.Caffe.*; +import com.google.protobuf.TextFormat; + +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.FileInputStream; +import java.io.FileReader; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.List; +import java.util.Arrays; +import static org.testng.Assert.*; + +public class MatVectorTest { + String rootPath, imagePath; + final int batchs = 1; + final int batch_size = 2; + final int channels = 3; + final int height = 227; + final int width = 227; + + int index = 0; + List file_list; + + @BeforeMethod + public void setUp() throws Exception { + String fullPath = getClass().getClassLoader().getResource("log4j.properties").getPath(); + rootPath = fullPath.substring(0, fullPath.indexOf("caffe-distri/")); + imagePath = rootPath + "data/images"; + + file_list = Files.readAllLines(Paths.get(imagePath + "/labels.txt"), StandardCharsets.UTF_8); + } + + @Test + private void matVecNegativeIndex() throws Exception { + boolean fail = false; + try { + MatVector matVec = new MatVector(-1); + } catch(Exception e) { + fail = true; + } + assertTrue(fail); + } + + @Test + private void matNullInMatVec() throws Exception { + MatVector matVector = new MatVector(1); + boolean fail = false; + try { + matVector.put(0,null); + } catch(Exception e) { + fail = true; + } + assertTrue(fail); + } + + @Test + private void matOnWrongMatVecIndex() throws Exception { + MatVector matVector = new MatVector(1); + byte[] data = getDataFromFile(0); + Mat mat = new Mat(data); + boolean fail = false; + try { + matVector.put(1,mat); + } catch(Exception e) { + fail = true; + } + assertTrue(fail); + } + + private byte[] getDataFromFile(int pos) throws Exception{ + byte[] buf = new byte[1024 * 1024]; + String line = file_list.get(pos); + String[] line_splits = line.split(" "); + String filename = line_splits[0]; + int label = Integer.parseInt(line_splits[1]); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + DataInputStream in = new DataInputStream(new FileInputStream(imagePath + "/" + filename)); + int len = in.read(buf, 0, buf.length); + while (len > 0) { + out.write(buf, 0, len); + len = in.read(buf, 0, buf.length); + } + in.close(); + byte[] b = out.toByteArray(); + out.close(); + return b; + } + + @Test + private void getMatVecDatafromInvalidIndex() throws Exception { + MatVector matVec = new MatVector(1); + byte[] data0 = getDataFromFile(0); + Mat mat = new Mat(data0); + matVec.put(0, mat); + byte[] resultData0 = matVec.data(-1); + assertEquals(resultData0, null); + } + + @Test + private void invalidheight() throws Exception { + MatVector matVec = new MatVector(1); + byte[] data0 = getDataFromFile(0); + Mat mat = new Mat(data0); + matVec.put(0, mat); + assertEquals(matVec.height(-1),-1); + } + + @Test + private void invalidwidth() throws Exception { + MatVector matVec = new MatVector(1); + byte[] data0 = getDataFromFile(0); + Mat mat = new Mat(data0); + matVec.put(0, mat); + assertEquals(matVec.width(-1),-1); + } + + @Test + private void invalidchannel() throws Exception { + MatVector matVec = new MatVector(1); + byte[] data0 = getDataFromFile(0); + Mat mat = new Mat(data0); + matVec.put(0, mat); + assertEquals(matVec.channels(-1),-1); + } + +} diff --git a/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/TransformTest.java b/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/TransformTest.java index 852291c..0683066 100755 --- a/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/TransformTest.java +++ b/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/TransformTest.java @@ -22,113 +22,159 @@ import static org.testng.Assert.*; public class TransformTest { - String rootPath, imagePath; - MatVector matVec; - final int batchs = 1000; - final int batch_size = 4; - final int channels = 3; - final int height = 227; - final int width = 227; - - int index = 0; - List file_list; - - @BeforeMethod - public void setUp() throws Exception { - String fullPath = getClass().getClassLoader().getResource("log4j.properties").getPath(); - rootPath = fullPath.substring(0, fullPath.indexOf("caffe-distri/")); - imagePath = rootPath + "data/images"; - - file_list = Files.readAllLines(Paths.get(imagePath + "/labels.txt"), StandardCharsets.UTF_8); - matVec = new MatVector(batch_size); + String rootPath, imagePath; + MatVector matVec; + final int batchs = 1000; + final int batch_size = 4; + final int channels = 3; + final int height = 227; + final int width = 227; + + int index = 0; + List file_list; + + @BeforeMethod + public void setUp() throws Exception { + String fullPath = getClass().getClassLoader().getResource("log4j.properties").getPath(); + rootPath = fullPath.substring(0, fullPath.indexOf("caffe-distri/")); + imagePath = rootPath + "data/images"; + + file_list = Files.readAllLines(Paths.get(imagePath + "/labels.txt"), StandardCharsets.UTF_8); + matVec = new MatVector(batch_size); + } + + @AfterMethod + public void tearDown() throws Exception { + matVec.deallocate(); + } + + private void nextBatch() throws Exception { + byte[] buf = new byte[1024 * 1024]; + + for (int idx=0; idx= file_list.size()) index = 0; + + String[] line_splits = line.split(" "); + String filename = line_splits[0]; + int label = Integer.parseInt(line_splits[1]); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + DataInputStream in = new DataInputStream(new FileInputStream(imagePath + "/" + filename)); + int len = in.read(buf, 0, buf.length); + while (len > 0) { + out.write(buf, 0, len); + len = in.read(buf, 0, buf.length); + } + in.close(); + + byte[] data = out.toByteArray(); + + Mat mat = new Mat(data); + mat.decode(Mat.CV_LOAD_IMAGE_COLOR); + mat.resize(227, 227); + + Mat oldmat = matVec.put(idx, mat); + if(oldmat != null) + oldmat.deallocate(); + + out.close(); } - - @AfterMethod - public void tearDown() throws Exception { - matVec.deallocate(); + } + + @Test + public void floatdatatransformernull(){ + boolean fail = false; + try { + FloatDataTransformer test_xform = new FloatDataTransformer(null, false); + } catch(Exception e) { + fail = true; } - - private void nextBatch() throws Exception { - byte[] buf = new byte[1024 * 1024]; - - for (int idx=0; idx= file_list.size()) index = 0; - - String[] line_splits = line.split(" "); - String filename = line_splits[0]; - int label = Integer.parseInt(line_splits[1]); - - ByteArrayOutputStream out = new ByteArrayOutputStream(); - DataInputStream in = new DataInputStream(new FileInputStream(imagePath + "/" + filename)); - int len = in.read(buf, 0, buf.length); - while (len > 0) { - out.write(buf, 0, len); - len = in.read(buf, 0, buf.length); - } - in.close(); - - byte[] data = out.toByteArray(); - - Mat mat = new Mat(data); - mat.decode(Mat.CV_LOAD_IMAGE_COLOR); - mat.resize(227, 227); - - Mat oldmat = matVec.put(idx, mat); - if(oldmat != null) - oldmat.deallocate(); - - out.close(); - } + assertTrue(fail); + } + + @Test + public void transformnull() throws Exception { + SolverParameter.Builder solver_builder = SolverParameter.newBuilder(); + FileReader reader = new FileReader(rootPath + "caffe-distri/src/test/resources/caffenet_solver.prototxt"); + TextFormat.merge(reader, solver_builder); + reader.close(); + + SolverParameter solver_param = solver_builder.build(); + String net_proto_file = solver_param.getNet(); + + NetParameter.Builder net_builder = NetParameter.newBuilder(); + reader = new FileReader(rootPath + "caffe-distri/" + net_proto_file); + TextFormat.merge(reader, net_builder); + reader.close(); + + NetParameter net_param = net_builder.build(); + + //blob + FloatBlob blob = new FloatBlob(); + blob.reshape(batch_size, channels, height, width); + + //train + LayerParameter train_layer_param = net_param.getLayer(0); + TransformationParameter param = train_layer_param.getTransformParam(); + FloatDataTransformer trans_xform = new FloatDataTransformer(param, true); + nextBatch(); + boolean fail=false; + try { + trans_xform.transform(null, null); + } catch(Exception e) { + fail = true; } - - @Test - public void testTransform() throws Exception { - SolverParameter.Builder solver_builder = SolverParameter.newBuilder(); - FileReader reader = new FileReader(rootPath + "caffe-distri/src/test/resources/caffenet_solver.prototxt"); - TextFormat.merge(reader, solver_builder); - reader.close(); - - SolverParameter solver_param = solver_builder.build(); - String net_proto_file = solver_param.getNet(); - - NetParameter.Builder net_builder = NetParameter.newBuilder(); - reader = new FileReader(rootPath + "caffe-distri/" + net_proto_file); - TextFormat.merge(reader, net_builder); - reader.close(); - - NetParameter net_param = net_builder.build(); - - //blob - FloatBlob blob = new FloatBlob(); - blob.reshape(batch_size, channels, height, width); - - //train - LayerParameter train_layer_param = net_param.getLayer(0); - TransformationParameter param = train_layer_param.getTransformParam(); - FloatDataTransformer trans_xform = new FloatDataTransformer(param, true); - System.out.print("TRAINING:"); - for (int i=0; i 1000); - } - System.out.println(); - - //test - FloatDataTransformer test_xform = new FloatDataTransformer(param, false); - System.out.print("TEST:"); - for (int i =0; i 1000); - } - System.out.println(); - - //release C++ resource - trans_xform.deallocate(); - test_xform.deallocate(); - blob.deallocate(); + assertTrue(fail); + } + + @Test + public void testTransform() throws Exception { + SolverParameter.Builder solver_builder = SolverParameter.newBuilder(); + FileReader reader = new FileReader(rootPath + "caffe-distri/src/test/resources/caffenet_solver.prototxt"); + TextFormat.merge(reader, solver_builder); + reader.close(); + + SolverParameter solver_param = solver_builder.build(); + String net_proto_file = solver_param.getNet(); + + NetParameter.Builder net_builder = NetParameter.newBuilder(); + reader = new FileReader(rootPath + "caffe-distri/" + net_proto_file); + TextFormat.merge(reader, net_builder); + reader.close(); + + NetParameter net_param = net_builder.build(); + + //blob + FloatBlob blob = new FloatBlob(); + blob.reshape(batch_size, channels, height, width); + + //train + LayerParameter train_layer_param = net_param.getLayer(0); + TransformationParameter param = train_layer_param.getTransformParam(); + FloatDataTransformer trans_xform = new FloatDataTransformer(param, true); + System.out.print("TRAINING:"); + for (int i=0; i 1000); + } + System.out.println(); + + //test + FloatDataTransformer test_xform = new FloatDataTransformer(param, false); + System.out.print("TEST:"); + for (int i =0; i 1000); } + System.out.println(); + + //release C++ resource + trans_xform.deallocate(); + test_xform.deallocate(); + blob.deallocate(); + } } From 1dd7c1ca66a4b55728e97b52570cb01525a2a80a Mon Sep 17 00:00:00 2001 From: Mridul Jain Date: Tue, 19 Jul 2016 21:39:51 -0700 Subject: [PATCH 2/3] Deallocation to free mem --- .../src/test/java/com/yahoo/ml/jcaffe/CaffeNetTest.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/CaffeNetTest.java b/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/CaffeNetTest.java index 37815d7..9a88c06 100755 --- a/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/CaffeNetTest.java +++ b/caffe-distri/src/test/java/com/yahoo/ml/jcaffe/CaffeNetTest.java @@ -78,6 +78,8 @@ public void setUp() throws Exception { @AfterMethod public void tearDown() throws Exception { net.deallocate(); + socket_net.deallocate(); + test_net.deallocate(); } @Test From ede15c59c137851e57aabf1ab51de6e63825ca4b Mon Sep 17 00:00:00 2001 From: Mridul Jain Date: Wed, 20 Jul 2016 11:27:17 -0700 Subject: [PATCH 3/3] JNI addition --- caffe-distri/src/main/cpp/jni/JniMat.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/caffe-distri/src/main/cpp/jni/JniMat.cpp b/caffe-distri/src/main/cpp/jni/JniMat.cpp index 12cd189..ed7c1e7 100755 --- a/caffe-distri/src/main/cpp/jni/JniMat.cpp +++ b/caffe-distri/src/main/cpp/jni/JniMat.cpp @@ -64,7 +64,7 @@ JNIEXPORT void JNICALL Java_com_yahoo_ml_jcaffe_Mat_deallocate //Mat object is only one responsible for cleaning itself and it's data if(dataaddress){ - delete (jbyte*)dataaddress; + delete[] (jbyte*)dataaddress; } delete (cv::Mat*) native_ptr; }