Skip to content

Commit 85ae893

Browse files
Use InterpreterBuilder::AddDelegate (attempt #2).
Use the non-experimental method InterpreterBuilder::AddDelegate to add delegates before constructing the interpreter, rather than using the experimental method Interpreter::ModifyGraphWithDelegate to set delegates after the interpreter is already constructed. A significant drawback of this refactoring is that if any delegates are enabled, we end up having to construct the interpreter twice: once without delegates, in order to determine whether the model contains any unresolved flex ops in order to decide whether to enable the flex delegate, and then again with all the delegates. PiperOrigin-RevId: 401003838 Change-Id: I8fdf8a4728a7a98677c5f8cea895447df234310b
1 parent 8452c0c commit 85ae893

File tree

2 files changed

+118
-50
lines changed

2 files changed

+118
-50
lines changed

tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -72,25 +72,30 @@ private void init(long errorHandle, long modelHandle, InterpreterImpl.Options op
7272
}
7373
this.errorHandle = errorHandle;
7474
this.modelHandle = modelHandle;
75-
// First create the interpreter without delegates.
76-
// We need this in order to figure out whether the model contains any unresolved flex ops.
75+
// First create the interpreter without delegates. We need an interpreter in order to figure
76+
// out whether the model contains any unresolved flex ops, and creating the interpreter with
77+
// delegates might fail if there are any unresolved flex ops.
7778
// (Alternatively, we could determine this without needing to recreate the interpreter
7879
// by passing the tflite::Model in to here, and then traversing that?)
79-
this.interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads);
80+
ArrayList<Long> delegateHandles = new ArrayList<Long>();
81+
this.interpreterHandle =
82+
createInterpreter(modelHandle, errorHandle, options.numThreads, delegateHandles);
8083
this.originalGraphHasUnresolvedFlexOp = hasUnresolvedFlexOp(interpreterHandle);
8184
addDelegates(options);
82-
// TODO(b/187920750): uncomment this when createInterpreter is modified to use
83-
// InterpreterBuilder::AddDelegate.
84-
// if (!delegates.isEmpty()) {
85-
// // If there are any delegates enabled, recreate the interpreter with those delegates.
86-
// delete(/* errorHandle= */ 0, /* modelHandle= */ 0, this.interpreterHandle);
87-
// this.interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads);
88-
// }
85+
delegateHandles.ensureCapacity(delegates.size());
86+
for (Delegate delegate : delegates) {
87+
delegateHandles.add(Long.valueOf(delegate.getNativeHandle()));
88+
}
89+
if (!delegateHandles.isEmpty()) {
90+
// If there are any delegates enabled, recreate the interpreter with those delegates.
91+
delete(/* errorHandle= */ 0, /* modelHandle= */ 0, this.interpreterHandle);
92+
this.interpreterHandle =
93+
createInterpreter(modelHandle, errorHandle, options.numThreads, delegateHandles);
94+
}
8995
if (options.allowFp16PrecisionForFp32 != null) {
9096
allowFp16PrecisionForFp32(
9197
interpreterHandle, options.allowFp16PrecisionForFp32.booleanValue());
9298
}
93-
applyDelegates(options);
9499
if (options.allowBufferHandleOutput != null) {
95100
allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue());
96101
}
@@ -524,13 +529,6 @@ private void maybeAddXnnpackDelegate(InterpreterImpl.Options options) {
524529
}
525530
}
526531

527-
// Apply all the delegates specified in this.delegates.
528-
private void applyDelegates(InterpreterImpl.Options options) {
529-
for (Delegate delegate : delegates) {
530-
applyDelegate(interpreterHandle, errorHandle, delegate.getNativeHandle());
531-
}
532-
}
533-
534532
private NativeSignatureRunnerWrapper getSignatureRunnerWrapper(String signatureKey) {
535533
if (signatureRunnerMap == null) {
536534
signatureRunnerMap = new HashMap<>();
@@ -643,10 +641,8 @@ private static native XnnpackDelegate createXNNPACKDelegate(
643641

644642
private static native long createModelWithBuffer(ByteBuffer modelBuffer, long errorHandle);
645643

646-
private static native long createInterpreter(long modelHandle, long errorHandle, int numThreads);
647-
648-
private static native void applyDelegate(
649-
long interpreterHandle, long errorHandle, long delegateHandle);
644+
private static native long createInterpreter(
645+
long modelHandle, long errorHandle, int numThreads, List<Long> delegateHandles);
650646

651647
private static native void resetVariableTensors(long interpreterHandle, long errorHandle);
652648

tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc

Lines changed: 100 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -494,25 +494,121 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer(
494494
JNIEXPORT jlong JNICALL
495495
Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
496496
JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle,
497-
jint num_threads) {
497+
jint num_threads, jobject delegate_handle_list) {
498498
if (!tflite::jni::CheckJniInitializedOrThrow(env)) return 0;
499499

500+
static jclass list_class = env->FindClass("java/util/List");
501+
if (list_class == nullptr) {
502+
if (!env->ExceptionCheck()) {
503+
ThrowException(env, tflite::jni::kUnsupportedOperationException,
504+
"Internal error: Can't find java.util.List class.");
505+
}
506+
return 0;
507+
}
508+
static jmethodID list_size_method =
509+
env->GetMethodID(list_class, "size", "()I");
510+
if (list_size_method == nullptr) {
511+
if (!env->ExceptionCheck()) {
512+
ThrowException(env, tflite::jni::kUnsupportedOperationException,
513+
"Internal error: Can't find java.util.List.size method.");
514+
}
515+
return 0;
516+
}
517+
static jmethodID list_get_method =
518+
env->GetMethodID(list_class, "get", "(I)Ljava/lang/Object;");
519+
if (list_get_method == nullptr) {
520+
if (!env->ExceptionCheck()) {
521+
ThrowException(env, tflite::jni::kUnsupportedOperationException,
522+
"Internal error: Can't find java.util.List.get method.");
523+
}
524+
return 0;
525+
}
526+
static jclass long_class = env->FindClass("java/lang/Long");
527+
if (long_class == nullptr) {
528+
if (!env->ExceptionCheck()) {
529+
ThrowException(env, tflite::jni::kUnsupportedOperationException,
530+
"Internal error: "
531+
"Can't find java.lang.Long class.");
532+
}
533+
return 0;
534+
}
535+
static jmethodID long_value_method =
536+
env->GetMethodID(long_class, "longValue", "()J");
537+
if (long_value_method == nullptr) {
538+
if (!env->ExceptionCheck()) {
539+
ThrowException(env, tflite::jni::kUnsupportedOperationException,
540+
"Internal error: "
541+
"Can't find java.lang.Long longValue method.");
542+
}
543+
return 0;
544+
}
545+
500546
FlatBufferModel* model = convertLongToModel(env, model_handle);
501547
if (model == nullptr) return 0;
548+
502549
BufferErrorReporter* error_reporter =
503550
convertLongToErrorReporter(env, error_handle);
504551
if (error_reporter == nullptr) return 0;
552+
505553
std::unique_ptr<OpResolver> resolver = tflite_shims::CreateOpResolver();
554+
506555
InterpreterBuilder interpreter_builder(*model, *resolver);
507556
interpreter_builder.SetNumThreads(static_cast<int>(num_threads));
557+
558+
// Add delegate_list to interpreter_builder.
559+
560+
// Java: int size = delegate_list.size();
561+
jint size = env->CallIntMethod(delegate_handle_list, list_size_method);
562+
for (jint i = 0; i < size; ++i) {
563+
// Java: Long jdelegate_handle = delegate_handle_list->get(i);
564+
jobject jdelegate_handle =
565+
env->CallObjectMethod(delegate_handle_list, list_get_method, i);
566+
if (jdelegate_handle == nullptr) {
567+
ThrowException(env, tflite::jni::kIllegalArgumentException,
568+
"Internal error: null Delegate handle");
569+
return 0;
570+
}
571+
// Java: long delegate_handle = jdelegate_handle.longValue();
572+
jlong delegate_handle =
573+
env->CallLongMethod(jdelegate_handle, long_value_method);
574+
if (delegate_handle == 0) {
575+
ThrowException(env, tflite::jni::kIllegalArgumentException,
576+
"Internal error: Found invalid handle");
577+
return 0;
578+
}
579+
auto delegate = reinterpret_cast<TfLiteOpaqueDelegate*>(delegate_handle);
580+
interpreter_builder.AddDelegate(delegate);
581+
}
582+
583+
// Create the Interpreter.
508584
std::unique_ptr<Interpreter> interpreter;
509585
TfLiteStatus status = interpreter_builder(&interpreter);
510586
if (status != kTfLiteOk) {
511-
ThrowException(env, tflite::jni::kIllegalArgumentException,
512-
"Internal error: Cannot create interpreter: %s",
513-
error_reporter->CachedErrorMessage());
587+
if (status == kTfLiteDelegateError) {
588+
ThrowException(env, tflite::jni::kIllegalArgumentException,
589+
"Internal error: Failed to apply delegate: %s",
590+
error_reporter->CachedErrorMessage());
591+
} else if (status == kTfLiteApplicationError) {
592+
ThrowException(env, tflite::jni::kIllegalArgumentException,
593+
"Internal error: Error applying delegate: %s",
594+
error_reporter->CachedErrorMessage());
595+
} else {
596+
const char* error_message = error_reporter->CachedErrorMessage();
597+
if (std::strcmp(
598+
error_message,
599+
"Restored original execution plan after delegate application "
600+
"failure.") == 0) {
601+
ThrowException(env, tflite::jni::kIllegalArgumentException,
602+
"Internal error: Failed to apply delegate.");
603+
} else {
604+
ThrowException(env, tflite::jni::kIllegalArgumentException,
605+
"Internal error: Cannot create interpreter: %s",
606+
error_message);
607+
}
608+
}
514609
return 0;
515610
}
611+
516612
// Note that tensor allocation is performed explicitly by the owning Java
517613
// NativeInterpreterWrapper instance.
518614
return reinterpret_cast<jlong>(interpreter.release());
@@ -598,30 +694,6 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput(
598694
return is_changed ? JNI_TRUE : JNI_FALSE;
599695
}
600696

601-
JNIEXPORT void JNICALL
602-
Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate(
603-
JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
604-
jlong delegate_handle) {
605-
if (!tflite::jni::CheckJniInitializedOrThrow(env)) return;
606-
607-
Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
608-
if (interpreter == nullptr) return;
609-
610-
BufferErrorReporter* error_reporter =
611-
convertLongToErrorReporter(env, error_handle);
612-
if (error_reporter == nullptr) return;
613-
614-
TfLiteOpaqueDelegate* delegate = convertLongToDelegate(env, delegate_handle);
615-
if (delegate == nullptr) return;
616-
617-
TfLiteStatus status = interpreter->ModifyGraphWithDelegate(delegate);
618-
if (status != kTfLiteOk) {
619-
ThrowException(env, tflite::jni::kIllegalArgumentException,
620-
"Internal error: Failed to apply delegate: %s",
621-
error_reporter->CachedErrorMessage());
622-
}
623-
}
624-
625697
JNIEXPORT jlong JNICALL
626698
Java_org_tensorflow_lite_NativeInterpreterWrapper_createCancellationFlag(
627699
JNIEnv* env, jclass clazz, jlong interpreter_handle) {

0 commit comments

Comments
 (0)