Skip to content
This repository has been archived by the owner on Nov 16, 2019. It is now read-only.

Exception handling JNI #116

Merged
merged 3 commits into from
Jul 21, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions caffe-distri/include/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ using namespace caffe;
#ifdef __cplusplus
extern "C" {
#endif

bool SetNativeAddress(JNIEnv *env, jobject object, void* address);
void* GetNativeAddress(JNIEnv *env, jobject object);

bool GetStringVector(vector<const char*>& vec, JNIEnv *env, jobjectArray array, int length);
bool GetFloatBlobVector(vector< Blob<float>* >& vec, JNIEnv *env, jobjectArray array, int length);


bool SetNativeAddress(JNIEnv *env, jobject object, void* address);
void* GetNativeAddress(JNIEnv *env, jobject object);

bool GetStringVector(vector<const char*>& vec, JNIEnv *env, jobjectArray array, int length);
bool GetFloatBlobVector(vector< Blob<float>* >& vec, JNIEnv *env, jobjectArray array, int length);
void ThrowJavaException(const std::exception& ex, JNIEnv* env);
void ThrowCosJavaException(char* message, JNIEnv* env);
#ifdef __cplusplus
}
#endif
Expand Down
2 changes: 1 addition & 1 deletion caffe-distri/include/util/socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class SocketChannel {
public:
SocketChannel();
~SocketChannel();
void Connect(string peer);
bool Connect(string peer);
int client_fd;
caffe::BlockingQueue<QueuedMessage*> receive_queue;
int serving_fd;
Expand Down
8 changes: 4 additions & 4 deletions caffe-distri/src/main/cpp/CaffeNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ bool RDMACaffeNet<Dtype>::connect(vector<const char*>& peer_addresses) {
rdma_channels_,
this->node_rank_));
// Pair devices for map-reduce synchronization
this->syncs_[0]->prepare(this->local_devices_,
this->syncs_[0]->Prepare(this->local_devices_,
&this->syncs_);

return true;
Expand All @@ -382,7 +382,8 @@ bool SocketCaffeNet<Dtype>::connect(vector<const char*>& peer_addresses) {
if (i != this->node_rank_) {
const char* addr = peer_addresses[i];
string addr_str(addr, strlen(addr));
sockt_channels_[i]->Connect(addr_str);
if(!sockt_channels_[i]->Connect(addr_str))
return false;
}

#ifndef CPU_ONLY
Expand Down Expand Up @@ -559,9 +560,8 @@ void CaffeNet<Dtype>::predict(int solver_index,
input_adapter_[solver_index]->feed(input_data, input_labels);

//invoke network's Forward operation
const vector<Blob<Dtype>*> dummy_bottom_vec;
CHECK(nets_[solver_index]);
nets_[solver_index]->Forward(dummy_bottom_vec);
nets_[solver_index]->Forward();

//grab the output blobs via names
int num_features = output_blob_names.size();
Expand Down
68 changes: 46 additions & 22 deletions caffe-distri/src/main/cpp/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ bool SetNativeAddress(JNIEnv *env, jobject object, void* address) {
}
/* Get a reference to JVM object class */
jclass claz = env->GetObjectClass(object);
if (claz == NULL) {
if (claz == NULL || env->ExceptionCheck()) {
LOG(ERROR) << "unable to get object's class";
return false;
}
/* Locate init(long) method */
jmethodID methodId = env->GetMethodID(claz, "init", "(J)V");
if (methodId == NULL) {
if (methodId == NULL || env->ExceptionCheck()) {
LOG(ERROR) << "could not locate init() method";
return false;
}
Expand All @@ -40,13 +40,13 @@ void* GetNativeAddress(JNIEnv *env, jobject object) {
}
/* Get a reference to JVM object class */
jclass claz = env->GetObjectClass(object);
if (claz == NULL) {
if (claz == NULL || env->ExceptionCheck()) {
LOG(ERROR) << "unable to get object's class";
return 0;
}
/* Getting the field id in the class */
jfieldID fieldId = env->GetFieldID(claz, "address", "J");
if (fieldId == NULL) {
if (fieldId == NULL || env->ExceptionCheck()) {
LOG(ERROR) << "could not locate field 'address'";
return 0;
}
Expand All @@ -57,11 +57,15 @@ void* GetNativeAddress(JNIEnv *env, jobject object) {
bool GetStringVector(vector<const char*>& vec, JNIEnv *env, jobjectArray array, int length) {
for (int i = 0; i < length; i++) {
jstring addr = (jstring)env->GetObjectArrayElement(array, i);
if (addr == NULL) {
if (addr == NULL || env->ExceptionCheck()) {
LOG(INFO) << i << "-th string is NULL";
vec[i] = NULL;
} else {
const char *cStr = env->GetStringUTFChars(addr, NULL);
if (env->ExceptionCheck()) {
LOG(ERROR) << "GetStringUTFChars failed";
return false;
}
vec[i] = cStr;
//Too many local refs could get created due to the loop, so delete them
//CHECKME:cStr is also a local ref called in loop, but it's not clear if deleting it via DeleteLocalRef deletes the memory pointed by it too
Expand All @@ -73,25 +77,45 @@ bool GetStringVector(vector<const char*>& vec, JNIEnv *env, jobjectArray array,
}

bool GetFloatBlobVector(vector< Blob<float>* >& vec, JNIEnv *env, jobjectArray array, int length) {
if (array == NULL) {
LOG(ERROR) << "array is NULL";
return false;
}

if (array == NULL) {
LOG(ERROR) << "array is NULL";
return false;
}
for (int i = 0; i < length; i++) {
//get i-th FloatBlob object (JVM)
jobject object = env->GetObjectArrayElement(array, i);
if (object == NULL) {
LOG(WARNING) << i << "-th FloatBlob is NULL";
vec[i] = NULL;
} else {
//find the native Blob<float> object
vec[i] = (Blob<float>*) GetNativeAddress(env, object);
}
//Too many local refs could get created due to the loop, so delete them
env->DeleteLocalRef(object);
//get i-th FloatBlob object (JVM)
jobject object = env->GetObjectArrayElement(array, i);
if (object == NULL || env->ExceptionCheck()) {
LOG(WARNING) << i << "-th FloatBlob is NULL";
vec[i] = NULL;
} else {
try{
//find the native Blob<float> object
vec[i] = (Blob<float>*) GetNativeAddress(env, object);
} catch (const std::exception& ex) {
ThrowJavaException(ex, env);
return false;
}
}
//Too many local refs could get created due to the loop, so delete them
env->DeleteLocalRef(object);
if (env->ExceptionCheck()) {
LOG(ERROR) << "DeleteLocalRef failed";
return false;
}
}

return true;
}

return true;
void ThrowJavaException(const std::exception& ex, JNIEnv *env) {
char exMsg[sizeof(typeid(ex).name())+sizeof(ex.what())+3];
sprintf(exMsg, "%s : %s", typeid(ex).name(), ex.what());
jclass exClass = env->FindClass("java/lang/Exception");
env->ThrowNew(exClass, exMsg);
}

void ThrowCosJavaException(char* message, JNIEnv *env) {
jclass exClass = env->FindClass("java/lang/Exception");
env->ThrowNew(exClass, message);
}
Loading