Refactor JNI code in C++ into Java code with JavaCPP #18
Changes from all commits: 534e9d4, e69874c, b8ff850, 55b7b5b, 47125ef, 1e22569
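The overall pattern of this refactor: each hand-written `native` method on EagerOperation, previously backed by a C++ JNI stub, becomes a plain Java method that calls the TensorFlow eager C API through the JavaCPP-generated bindings, and raw `long` handles become typed pointers such as TFE_Op and TFE_TensorHandle. A minimal sketch of the new style, based on the numDims helper in the diff below (the remark about what the deleted C++ stub did is an assumption, since that side is not shown here):

    import static org.tensorflow.internal.c_api.global.tensorflow.TFE_TensorHandleNumDims;

    import org.bytedeco.javacpp.PointerScope;
    import org.tensorflow.internal.c_api.TFE_TensorHandle;
    import org.tensorflow.internal.c_api.TF_Status;

    // Before (Java side): the work happened in a hand-written C++ JNI stub, roughly
    //   private static native int numDims(long handle);
    // After: the eager C API is called straight from Java through the JavaCPP bindings.
    static int numDims(TFE_TensorHandle handle) {
      try (PointerScope scope = new PointerScope()) { // frees pointers created inside this scope
        TF_Status status = TF_Status.newStatus();     // released when the scope closes
        int numDims = TFE_TensorHandleNumDims(handle, status);
        status.throwExceptionIfNotOK();               // surfaces C API errors as Java exceptions
        return numDims;
      }
    }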
@@ -15,7 +15,21 @@
 package org.tensorflow;
 
+import static org.tensorflow.internal.c_api.global.tensorflow.TFE_DeleteOp;
+import static org.tensorflow.internal.c_api.global.tensorflow.TFE_DeleteTensorHandle;
+import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpGetInputLength;
+import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpGetOutputLength;
+import static org.tensorflow.internal.c_api.global.tensorflow.TFE_TensorHandleDataType;
+import static org.tensorflow.internal.c_api.global.tensorflow.TFE_TensorHandleDim;
+import static org.tensorflow.internal.c_api.global.tensorflow.TFE_TensorHandleNumDims;
+import static org.tensorflow.internal.c_api.global.tensorflow.TFE_TensorHandleResolve;
+
 import java.util.concurrent.atomic.AtomicReferenceArray;
+import org.bytedeco.javacpp.PointerScope;
+import org.tensorflow.internal.c_api.TFE_Op;
+import org.tensorflow.internal.c_api.TFE_TensorHandle;
+import org.tensorflow.internal.c_api.TF_Status;
+import org.tensorflow.internal.c_api.TF_Tensor;
 import org.tensorflow.tools.Shape;
 
 /**
@@ -31,8 +45,8 @@ class EagerOperation extends AbstractOperation {
   EagerOperation(
       EagerSession session,
-      long opNativeHandle,
-      long[] outputNativeHandles,
+      TFE_Op opNativeHandle,
+      TFE_TensorHandle[] outputNativeHandles,
       String type,
       String name) {
     this.session = session;
@@ -68,7 +82,7 @@ public int inputListLength(final String name) {
   }
 
   @Override
-  public long getUnsafeNativeHandle(int outputIndex) {
+  public TFE_TensorHandle getUnsafeNativeHandle(int outputIndex) {
     return nativeRef.outputHandles[outputIndex];
   }
 
@@ -80,7 +94,7 @@ public Shape shape(int outputIndex) {
     if (tensor != null) {
       return tensor.shape();
     }
-    long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
+    TFE_TensorHandle outputNativeHandle = getUnsafeNativeHandle(outputIndex);
     long[] shape = new long[numDims(outputNativeHandle)];
     for (int i = 0; i < shape.length; ++i) {
       shape[i] = dim(outputNativeHandle, i);
@@ -96,7 +110,7 @@ public DataType<?> dtype(int outputIndex) {
     if (tensor != null) {
       return tensor.dataType();
     }
-    long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
+    TFE_TensorHandle outputNativeHandle = getUnsafeNativeHandle(outputIndex);
     return DataTypes.fromNativeCode(dataType(outputNativeHandle));
   }
 
@@ -119,7 +133,7 @@ private Tensor<?> resolveTensor(int outputIndex) {
     // Take an optimistic approach, where we attempt to resolve the output tensor without locking.
     // If another thread has resolved it meanwhile, release our copy and reuse the existing one
     // instead.
-    long tensorNativeHandle = resolveTensorHandle(getUnsafeNativeHandle(outputIndex));
+    TF_Tensor tensorNativeHandle = resolveTensorHandle(getUnsafeNativeHandle(outputIndex));
     Tensor<?> tensor = Tensor.fromHandle(tensorNativeHandle, session);
     if (!outputTensors.compareAndSet(outputIndex, null, tensor)) {
       tensor.close();
@@ -131,43 +145,104 @@ private Tensor<?> resolveTensor(int outputIndex) {
   private static class NativeReference extends EagerSession.NativeReference {
 
     NativeReference(
-        EagerSession session, EagerOperation operation, long opHandle, long[] outputHandles) {
+        EagerSession session, EagerOperation operation, TFE_Op opHandle, TFE_TensorHandle[] outputHandles) {
       super(session, operation);
       this.opHandle = opHandle;
       this.outputHandles = outputHandles;
     }
 
     @Override
     void delete() {
-      if (opHandle != 0L) {
+      if (opHandle != null && !opHandle.isNull()) {
         for (int i = 0; i < outputHandles.length; ++i) {
-          if (outputHandles[i] != 0L) {
+          if (outputHandles[i] != null && !outputHandles[i].isNull()) {
             EagerOperation.deleteTensorHandle(outputHandles[i]);
-            outputHandles[i] = 0L;
+            outputHandles[i] = null;
           }
         }
         EagerOperation.delete(opHandle);
-        opHandle = 0L;
+        opHandle = null;
       }
     }
 
-    private long opHandle;
-    private final long[] outputHandles;
+    private TFE_Op opHandle;
+    private final TFE_TensorHandle[] outputHandles;
   }
 
-  private static native void delete(long handle);
-
-  private static native void deleteTensorHandle(long handle);
+  private static void requireOp(TFE_Op handle) {
+    if (handle == null || handle.isNull()) {
+      throw new IllegalStateException("Eager session has been closed");
+    }
+  }
 
-  private static native long resolveTensorHandle(long handle);
+  private static void requireTensorHandle(TFE_TensorHandle handle) {
+    if (handle == null || handle.isNull()) {
+      throw new IllegalStateException("Eager session has been closed");
+    }
+  }
 
-  private static native int outputListLength(long handle, String name);
+  private static void delete(TFE_Op handle) {
+    if (handle == null || handle.isNull()) return;
+    TFE_DeleteOp(handle);
+  }
 
-  private static native int inputListLength(long handle, String name);
+  private static void deleteTensorHandle(TFE_TensorHandle handle) {
+    if (handle == null || handle.isNull()) return;
+    TFE_DeleteTensorHandle(handle);
+  }
 
-  private static native int dataType(long handle);
+  private static TF_Tensor resolveTensorHandle(TFE_TensorHandle handle) {
+    requireTensorHandle(handle);
+    try (PointerScope scope = new PointerScope()) {
+      TF_Status status = TF_Status.newStatus();
+      TF_Tensor tensor = TFE_TensorHandleResolve(handle, status);
+      status.throwExceptionIfNotOK();
+      return tensor;
+    }
+  }
 
-  private static native int numDims(long handle);
+  private static int outputListLength(TFE_Op handle, String name) {
+    requireOp(handle);
+    try (PointerScope scope = new PointerScope()) {
+      TF_Status status = TF_Status.newStatus();
+      int length = TFE_OpGetOutputLength(handle, name, status);
+      status.throwExceptionIfNotOK();
+      return length;
+    }
Review thread on this hunk:

Comment: Instead of starting new scopes in all of these methods, could it be simpler and more efficient to just create the status in a

    try (TF_Status status = TF_Status.newStatus()) {
      ...
    }

Reply: Yes, it would be more efficient, but it would also make it more error-prone when we start creating other objects in there that may start doing temporary allocations.

Reply: Yeah... I would still prefer we go with the most efficient approach but it's up to you if you want to make the changes or not, we can merge it like this too. (A sketch of the suggested alternative appears after the diff.)
+  }
 
-  private static native long dim(long handle, int index);
-}
+  private static int inputListLength(TFE_Op handle, String name) {
+    requireOp(handle);
+    try (PointerScope scope = new PointerScope()) {
+      TF_Status status = TF_Status.newStatus();
+      int length = TFE_OpGetInputLength(handle, name, status);
+      status.throwExceptionIfNotOK();
+      return length;
+    }
+  }
 
+  private static int dataType(TFE_TensorHandle handle) {
+    requireTensorHandle(handle);
+    return TFE_TensorHandleDataType(handle);
+  }
 
+  private static int numDims(TFE_TensorHandle handle) {
+    requireTensorHandle(handle);
+    try (PointerScope scope = new PointerScope()) {
+      TF_Status status = TF_Status.newStatus();
+      int numDims = TFE_TensorHandleNumDims(handle, status);
+      status.throwExceptionIfNotOK();
+      return numDims;
+    }
+  }
 
+  private static long dim(TFE_TensorHandle handle, int index) {
+    requireTensorHandle(handle);
+    try (PointerScope scope = new PointerScope()) {
+      TF_Status status = TF_Status.newStatus();
+      long dim = TFE_TensorHandleDim(handle, index, status);
+      status.throwExceptionIfNotOK();
+      return dim;
+    }
+  }
+}
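For reference, the alternative raised in the review thread would look roughly like the sketch below. This is only an illustration, not code from the PR: it assumes that TF_Status, like other JavaCPP pointers, is AutoCloseable, so try-with-resources releases just the status object, whereas the PointerScope used in the merged code also reclaims any other pointers allocated while the scope is open. It reuses the requireOp helper and imports already shown in the diff above.

    // Sketch of the reviewer's suggestion (hypothetical): manage only the TF_Status
    // with try-with-resources instead of opening a PointerScope per call.
    private static int outputListLength(TFE_Op handle, String name) {
      requireOp(handle);
      try (TF_Status status = TF_Status.newStatus()) {  // status freed when the block exits
        int length = TFE_OpGetOutputLength(handle, name, status);
        status.throwExceptionIfNotOK();                 // maps a bad TF_Status to a Java exception
        return length;
      }
    }

As noted in the replies, the merged code keeps the PointerScope form, trading a small amount of per-call overhead for protection against leaks from temporary allocations made inside these methods.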