Skip to content

Refactor JNI code in C++ into Java code with JavaCPP #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions tensorflow-core/tensorflow-core-api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -140,21 +140,6 @@
<preloadPath>${project.basedir}/bazel-${project.artifactId}/external/mkl_darwin/lib/</preloadPath>
<preloadPath>${project.basedir}/bazel-${project.artifactId}/external/mkl_windows/lib/</preloadPath>
</preloadPaths>
<compilerOptions>
<compilerOption>${project.basedir}/src/main/native/eager_operation_builder_jni.cc</compilerOption>
<compilerOption>${project.basedir}/src/main/native/eager_operation_jni.cc</compilerOption>
<compilerOption>${project.basedir}/src/main/native/eager_session_jni.cc</compilerOption>
<compilerOption>${project.basedir}/src/main/native/exception_jni.cc</compilerOption>
<compilerOption>${project.basedir}/src/main/native/graph_jni.cc</compilerOption>
<compilerOption>${project.basedir}/src/main/native/graph_operation_builder_jni.cc</compilerOption>
<compilerOption>${project.basedir}/src/main/native/graph_operation_jni.cc</compilerOption>
<compilerOption>${project.basedir}/src/main/native/saved_model_bundle_jni.cc</compilerOption>
<compilerOption>${project.basedir}/src/main/native/server_jni.cc</compilerOption>
<compilerOption>${project.basedir}/src/main/native/session_jni.cc</compilerOption>
<compilerOption>${project.basedir}/src/main/native/tensorflow_jni.cc</compilerOption>
<compilerOption>${project.basedir}/src/main/native/tensor_jni.cc</compilerOption>
<compilerOption>${project.basedir}/src/main/native/utils_jni.cc</compilerOption>
</compilerOptions>
</configuration>
<executions>
<execution>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
//
// TODO(ashankar): Merge with TF_Session?
@Opaque @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
public class TFE_Context extends Pointer {
public class TFE_Context extends org.tensorflow.internal.c_api.AbstractTFE_Context {
/** Empty constructor. Calls {@code super((Pointer)null)}. */
public TFE_Context() { super((Pointer)null); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
// #endif

@Opaque @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
public class TFE_ContextOptions extends Pointer {
public class TFE_ContextOptions extends org.tensorflow.internal.c_api.AbstractTFE_ContextOptions {
/** Empty constructor. Calls {@code super((Pointer)null)}. */
public TFE_ContextOptions() { super((Pointer)null); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package org.tensorflow;

import org.bytedeco.javacpp.Pointer;
import org.tensorflow.tools.Shape;
import org.tensorflow.types.family.TType;

Expand Down Expand Up @@ -59,7 +60,7 @@ public String toString() {
* @param outputIdx index of the output in this operation
* @return a native handle, see method description for more details
*/
abstract long getUnsafeNativeHandle(int outputIdx);
abstract Pointer getUnsafeNativeHandle(int outputIdx);

/**
* Returns the shape of the tensor of the {@code outputIdx}th output of this operation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ int nativeCode() {
* @return data structure of elements of this type
*/
T map(Tensor<T> tensor) {
return tensorMapper.apply(tensor.getNative(), tensor.shape());
return tensorMapper.apply(tensor.getNativeHandle(), tensor.shape());
}

private final int nativeCode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,21 @@

package org.tensorflow;

import static org.tensorflow.internal.c_api.global.tensorflow.TFE_DeleteOp;
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_DeleteTensorHandle;
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpGetInputLength;
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpGetOutputLength;
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_TensorHandleDataType;
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_TensorHandleDim;
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_TensorHandleNumDims;
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_TensorHandleResolve;

import java.util.concurrent.atomic.AtomicReferenceArray;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.internal.c_api.TFE_Op;
import org.tensorflow.internal.c_api.TFE_TensorHandle;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.TF_Tensor;
import org.tensorflow.tools.Shape;

/**
Expand All @@ -31,8 +45,8 @@ class EagerOperation extends AbstractOperation {

EagerOperation(
EagerSession session,
long opNativeHandle,
long[] outputNativeHandles,
TFE_Op opNativeHandle,
TFE_TensorHandle[] outputNativeHandles,
String type,
String name) {
this.session = session;
Expand Down Expand Up @@ -68,7 +82,7 @@ public int inputListLength(final String name) {
}

@Override
public long getUnsafeNativeHandle(int outputIndex) {
public TFE_TensorHandle getUnsafeNativeHandle(int outputIndex) {
return nativeRef.outputHandles[outputIndex];
}

Expand All @@ -80,7 +94,7 @@ public Shape shape(int outputIndex) {
if (tensor != null) {
return tensor.shape();
}
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
TFE_TensorHandle outputNativeHandle = getUnsafeNativeHandle(outputIndex);
long[] shape = new long[numDims(outputNativeHandle)];
for (int i = 0; i < shape.length; ++i) {
shape[i] = dim(outputNativeHandle, i);
Expand All @@ -96,7 +110,7 @@ public DataType<?> dtype(int outputIndex) {
if (tensor != null) {
return tensor.dataType();
}
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
TFE_TensorHandle outputNativeHandle = getUnsafeNativeHandle(outputIndex);
return DataTypes.fromNativeCode(dataType(outputNativeHandle));
}

Expand All @@ -119,7 +133,7 @@ private Tensor<?> resolveTensor(int outputIndex) {
// Take an optimistic approach, where we attempt to resolve the output tensor without locking.
// If another thread has resolved it meanwhile, release our copy and reuse the existing one
// instead.
long tensorNativeHandle = resolveTensorHandle(getUnsafeNativeHandle(outputIndex));
TF_Tensor tensorNativeHandle = resolveTensorHandle(getUnsafeNativeHandle(outputIndex));
Tensor<?> tensor = Tensor.fromHandle(tensorNativeHandle, session);
if (!outputTensors.compareAndSet(outputIndex, null, tensor)) {
tensor.close();
Expand All @@ -131,43 +145,104 @@ private Tensor<?> resolveTensor(int outputIndex) {
private static class NativeReference extends EagerSession.NativeReference {

NativeReference(
EagerSession session, EagerOperation operation, long opHandle, long[] outputHandles) {
EagerSession session, EagerOperation operation, TFE_Op opHandle, TFE_TensorHandle[] outputHandles) {
super(session, operation);
this.opHandle = opHandle;
this.outputHandles = outputHandles;
}

@Override
void delete() {
if (opHandle != 0L) {
if (opHandle != null && !opHandle.isNull()) {
for (int i = 0; i < outputHandles.length; ++i) {
if (outputHandles[i] != 0L) {
if (outputHandles[i] != null && !outputHandles[i].isNull()) {
EagerOperation.deleteTensorHandle(outputHandles[i]);
outputHandles[i] = 0L;
outputHandles[i] = null;
}
}
EagerOperation.delete(opHandle);
opHandle = 0L;
opHandle = null;
}
}

private long opHandle;
private final long[] outputHandles;
private TFE_Op opHandle;
private final TFE_TensorHandle[] outputHandles;
}

private static native void delete(long handle);

private static native void deleteTensorHandle(long handle);
/**
 * Ensures the given eager-op handle is still usable.
 *
 * @param handle native handle to validate
 * @throws IllegalStateException if the handle is null or points to freed memory
 */
private static void requireOp(TFE_Op handle) {
  boolean released = (handle == null) || handle.isNull();
  if (released) {
    throw new IllegalStateException("Eager session has been closed");
  }
}

private static native long resolveTensorHandle(long handle);
/**
 * Ensures the given tensor handle is still usable.
 *
 * @param handle native handle to validate
 * @throws IllegalStateException if the handle is null or points to freed memory
 */
private static void requireTensorHandle(TFE_TensorHandle handle) {
  boolean released = (handle == null) || handle.isNull();
  if (released) {
    throw new IllegalStateException("Eager session has been closed");
  }
}

private static native int outputListLength(long handle, String name);
/**
 * Releases the native eager op, if it is still allocated.
 *
 * @param handle native op handle; null or already-freed handles are ignored
 */
private static void delete(TFE_Op handle) {
  if (handle != null && !handle.isNull()) {
    TFE_DeleteOp(handle);
  }
}

private static native int inputListLength(long handle, String name);
/**
 * Releases the native tensor handle, if it is still allocated.
 *
 * @param handle native tensor handle; null or already-freed handles are ignored
 */
private static void deleteTensorHandle(TFE_TensorHandle handle) {
  if (handle != null && !handle.isNull()) {
    TFE_DeleteTensorHandle(handle);
  }
}

private static native int dataType(long handle);
/**
 * Materializes the tensor behind an eager tensor handle into a concrete {@code TF_Tensor}.
 *
 * @param handle eager tensor handle to resolve; must be live
 * @return the resolved native tensor
 * @throws IllegalStateException if the handle is null or already freed
 */
private static TF_Tensor resolveTensorHandle(TFE_TensorHandle handle) {
  requireTensorHandle(handle);
  // Scope bounds the lifetime of temporaries (e.g. the status) created in this call.
  // NOTE(review): the returned TF_Tensor is created inside the scope but escapes it —
  // presumably presets keep it alive past scope close; confirm ownership semantics.
  try (PointerScope scope = new PointerScope()) {
    TF_Status status = TF_Status.newStatus();
    TF_Tensor tensor = TFE_TensorHandleResolve(handle, status);
    // Surfaces any native error (e.g. handle belonging to a closed context) as an exception.
    status.throwExceptionIfNotOK();
    return tensor;
  }
}

private static native int numDims(long handle);
private static int outputListLength(TFE_Op handle, String name) {
requireOp(handle);
try (PointerScope scope = new PointerScope()) {
TF_Status status = TF_Status.newStatus();
int length = TFE_OpGetOutputLength(handle, name, status);
status.throwExceptionIfNotOK();
return length;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of starting new scopes in all of these methods, could it be simpler and more efficient to just create the status in a try-with-resources block when there are no other resources allocated?

try (TF_Status status = TF_Status.newStatus()) { 
    ... 
}

Copy link
Contributor Author

@saudet saudet Jan 29, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it would be more efficient, but it would also make it more error-prone when we start creating other objects in there that may start doing temporary allocations.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah... I would still prefer we go with the most efficient approach but it's up to you if you want to make the changes or not, we can merge it like this too.

}

private static native long dim(long handle, int index);
}
/**
 * Returns the number of inputs in the named input list of the given eager op.
 *
 * @param handle native op handle; must be live
 * @param name name of the input list
 * @return number of inputs in the list
 * @throws IllegalStateException if the handle is null or already freed
 */
private static int inputListLength(TFE_Op handle, String name) {
  requireOp(handle);
  // Scope releases the status (and any other temporaries) when this call returns.
  try (PointerScope scope = new PointerScope()) {
    TF_Status status = TF_Status.newStatus();
    int result = TFE_OpGetInputLength(handle, name, status);
    status.throwExceptionIfNotOK();
    return result;
  }
}

/**
 * Returns the native data-type code of the tensor behind the given handle.
 *
 * @param handle native tensor handle; must be live
 * @return native TF_DataType code
 * @throws IllegalStateException if the handle is null or already freed
 */
private static int dataType(TFE_TensorHandle handle) {
  requireTensorHandle(handle);
  int code = TFE_TensorHandleDataType(handle);
  return code;
}

/**
 * Returns the rank (number of dimensions) of the tensor behind the given handle.
 *
 * @param handle native tensor handle; must be live
 * @return number of dimensions
 * @throws IllegalStateException if the handle is null or already freed
 */
private static int numDims(TFE_TensorHandle handle) {
  requireTensorHandle(handle);
  // Scope releases the status (and any other temporaries) when this call returns.
  try (PointerScope scope = new PointerScope()) {
    TF_Status status = TF_Status.newStatus();
    int rank = TFE_TensorHandleNumDims(handle, status);
    status.throwExceptionIfNotOK();
    return rank;
  }
}

/**
 * Returns the size of one dimension of the tensor behind the given handle.
 *
 * @param handle native tensor handle; must be live
 * @param index zero-based dimension index
 * @return size of that dimension
 * @throws IllegalStateException if the handle is null or already freed
 */
private static long dim(TFE_TensorHandle handle, int index) {
  requireTensorHandle(handle);
  // Scope releases the status (and any other temporaries) when this call returns.
  try (PointerScope scope = new PointerScope()) {
    TF_Status status = TF_Status.newStatus();
    long size = TFE_TensorHandleDim(handle, index, status);
    status.throwExceptionIfNotOK();
    return size;
  }
}
}
Loading