Skip to content

Commit dd6373b

Browse files
committed
Refactor the rest of the JNI code into Java with JavaCPP
1 parent d4f2a7d commit dd6373b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+1774
-3916
lines changed

tensorflow-core/tensorflow-core-api/pom.xml

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -134,20 +134,6 @@
134134
<preloadPath>${project.basedir}/bazel-tensorflow-core-api/external/mkl_darwin/lib/</preloadPath>
135135
<preloadPath>${project.basedir}/bazel-tensorflow-core-api/external/mkl_windows/lib/</preloadPath>
136136
</preloadPaths>
137-
<compilerOptions>
138-
<compilerOption>${project.basedir}/src/main/native/eager_operation_builder_jni.cc</compilerOption>
139-
<compilerOption>${project.basedir}/src/main/native/eager_operation_jni.cc</compilerOption>
140-
<compilerOption>${project.basedir}/src/main/native/eager_session_jni.cc</compilerOption>
141-
<compilerOption>${project.basedir}/src/main/native/exception_jni.cc</compilerOption>
142-
<compilerOption>${project.basedir}/src/main/native/graph_jni.cc</compilerOption>
143-
<compilerOption>${project.basedir}/src/main/native/graph_operation_builder_jni.cc</compilerOption>
144-
<compilerOption>${project.basedir}/src/main/native/graph_operation_jni.cc</compilerOption>
145-
<compilerOption>${project.basedir}/src/main/native/server_jni.cc</compilerOption>
146-
<compilerOption>${project.basedir}/src/main/native/session_jni.cc</compilerOption>
147-
<compilerOption>${project.basedir}/src/main/native/tensorflow_jni.cc</compilerOption>
148-
<compilerOption>${project.basedir}/src/main/native/tensor_jni.cc</compilerOption>
149-
<compilerOption>${project.basedir}/src/main/native/utils_jni.cc</compilerOption>
150-
</compilerOptions>
151137
</configuration>
152138
<executions>
153139
<execution>

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_Context.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
//
1818
// TODO(ashankar): Merge with TF_Session?
1919
@Opaque @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
20-
public class TFE_Context extends Pointer {
20+
public class TFE_Context extends org.tensorflow.internal.c_api.AbstractTFE_Context {
2121
/** Empty constructor. Calls {@code super((Pointer)null)}. */
2222
public TFE_Context() { super((Pointer)null); }
2323
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_ContextOptions.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
// #endif
1212

1313
@Opaque @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
14-
public class TFE_ContextOptions extends Pointer {
14+
public class TFE_ContextOptions extends org.tensorflow.internal.c_api.AbstractTFE_ContextOptions {
1515
/** Empty constructor. Calls {@code super((Pointer)null)}. */
1616
public TFE_ContextOptions() { super((Pointer)null); }
1717
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
package org.tensorflow;
1717

18+
import org.bytedeco.javacpp.Pointer;
1819
import org.tensorflow.tools.Shape;
1920
import org.tensorflow.types.family.TType;
2021

@@ -59,7 +60,7 @@ public String toString() {
5960
* @param outputIdx index of the output in this operation
6061
* @return a native handle, see method description for more details
6162
*/
62-
abstract long getUnsafeNativeHandle(int outputIdx);
63+
abstract Pointer getUnsafeNativeHandle(int outputIdx);
6364

6465
/**
6566
* Returns the shape of the tensor of the {@code outputIdx}th output of this operation.

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ int nativeCode() {
8080
* @return data structure of elements of this type
8181
*/
8282
T map(Tensor<T> tensor) {
83-
return tensorMapper.apply(tensor.getNative(), tensor.shape());
83+
return tensorMapper.apply(tensor.getNativeHandle(), tensor.shape());
8484
}
8585

8686
private final int nativeCode;

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java

Lines changed: 98 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,21 @@
1515

1616
package org.tensorflow;
1717

18+
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_DeleteOp;
19+
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_DeleteTensorHandle;
20+
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpGetInputLength;
21+
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpGetOutputLength;
22+
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_TensorHandleDataType;
23+
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_TensorHandleDim;
24+
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_TensorHandleNumDims;
25+
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_TensorHandleResolve;
26+
1827
import java.util.concurrent.atomic.AtomicReferenceArray;
28+
import org.bytedeco.javacpp.PointerScope;
29+
import org.tensorflow.internal.c_api.TFE_Op;
30+
import org.tensorflow.internal.c_api.TFE_TensorHandle;
31+
import org.tensorflow.internal.c_api.TF_Status;
32+
import org.tensorflow.internal.c_api.TF_Tensor;
1933
import org.tensorflow.tools.Shape;
2034

2135
/**
@@ -31,8 +45,8 @@ class EagerOperation extends AbstractOperation {
3145

3246
EagerOperation(
3347
EagerSession session,
34-
long opNativeHandle,
35-
long[] outputNativeHandles,
48+
TFE_Op opNativeHandle,
49+
TFE_TensorHandle[] outputNativeHandles,
3650
String type,
3751
String name) {
3852
this.session = session;
@@ -68,7 +82,7 @@ public int inputListLength(final String name) {
6882
}
6983

7084
@Override
71-
public long getUnsafeNativeHandle(int outputIndex) {
85+
public TFE_TensorHandle getUnsafeNativeHandle(int outputIndex) {
7286
return nativeRef.outputHandles[outputIndex];
7387
}
7488

@@ -80,7 +94,7 @@ public Shape shape(int outputIndex) {
8094
if (tensor != null) {
8195
return tensor.shape();
8296
}
83-
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
97+
TFE_TensorHandle outputNativeHandle = getUnsafeNativeHandle(outputIndex);
8498
long[] shape = new long[numDims(outputNativeHandle)];
8599
for (int i = 0; i < shape.length; ++i) {
86100
shape[i] = dim(outputNativeHandle, i);
@@ -96,7 +110,7 @@ public DataType<?> dtype(int outputIndex) {
96110
if (tensor != null) {
97111
return tensor.dataType();
98112
}
99-
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
113+
TFE_TensorHandle outputNativeHandle = getUnsafeNativeHandle(outputIndex);
100114
return DataTypes.fromNativeCode(dataType(outputNativeHandle));
101115
}
102116

@@ -119,7 +133,7 @@ private Tensor<?> resolveTensor(int outputIndex) {
119133
// Take an optimistic approach, where we attempt to resolve the output tensor without locking.
120134
// If another thread has resolved it meanwhile, release our copy and reuse the existing one
121135
// instead.
122-
long tensorNativeHandle = resolveTensorHandle(getUnsafeNativeHandle(outputIndex));
136+
TF_Tensor tensorNativeHandle = resolveTensorHandle(getUnsafeNativeHandle(outputIndex));
123137
Tensor<?> tensor = Tensor.fromHandle(tensorNativeHandle, session);
124138
if (!outputTensors.compareAndSet(outputIndex, null, tensor)) {
125139
tensor.close();
@@ -131,43 +145,104 @@ private Tensor<?> resolveTensor(int outputIndex) {
131145
private static class NativeReference extends EagerSession.NativeReference {
132146

133147
NativeReference(
134-
EagerSession session, EagerOperation operation, long opHandle, long[] outputHandles) {
148+
EagerSession session, EagerOperation operation, TFE_Op opHandle, TFE_TensorHandle[] outputHandles) {
135149
super(session, operation);
136150
this.opHandle = opHandle;
137151
this.outputHandles = outputHandles;
138152
}
139153

140154
@Override
141155
void delete() {
142-
if (opHandle != 0L) {
156+
if (opHandle != null && !opHandle.isNull()) {
143157
for (int i = 0; i < outputHandles.length; ++i) {
144-
if (outputHandles[i] != 0L) {
158+
if (outputHandles[i] != null && !outputHandles[i].isNull()) {
145159
EagerOperation.deleteTensorHandle(outputHandles[i]);
146-
outputHandles[i] = 0L;
160+
outputHandles[i] = null;
147161
}
148162
}
149163
EagerOperation.delete(opHandle);
150-
opHandle = 0L;
164+
opHandle = null;
151165
}
152166
}
153167

154-
private long opHandle;
155-
private final long[] outputHandles;
168+
private TFE_Op opHandle;
169+
private final TFE_TensorHandle[] outputHandles;
156170
}
157-
158-
private static native void delete(long handle);
159171

160-
private static native void deleteTensorHandle(long handle);
172+
private static void requireOp(TFE_Op handle) {
173+
if (handle == null || handle.isNull()) {
174+
throw new IllegalStateException("Eager session has been closed");
175+
}
176+
}
161177

162-
private static native long resolveTensorHandle(long handle);
178+
private static void requireTensorHandle(TFE_TensorHandle handle) {
179+
if (handle == null || handle.isNull()) {
180+
throw new IllegalStateException("EagerSession has been closed");
181+
}
182+
}
163183

164-
private static native int outputListLength(long handle, String name);
184+
private static void delete(TFE_Op handle) {
185+
if (handle == null || handle.isNull()) return;
186+
TFE_DeleteOp(handle);
187+
}
165188

166-
private static native int inputListLength(long handle, String name);
189+
private static void deleteTensorHandle(TFE_TensorHandle handle) {
190+
if (handle == null || handle.isNull()) return;
191+
TFE_DeleteTensorHandle(handle);
192+
}
167193

168-
private static native int dataType(long handle);
194+
private static TF_Tensor resolveTensorHandle(TFE_TensorHandle handle) {
195+
requireTensorHandle(handle);
196+
try (PointerScope scope = new PointerScope()) {
197+
TF_Status status = TF_Status.newStatus();
198+
TF_Tensor tensor = TFE_TensorHandleResolve(handle, status);
199+
status.throwExceptionIfNotOK();
200+
return tensor;
201+
}
202+
}
169203

170-
private static native int numDims(long handle);
204+
private static int outputListLength(TFE_Op handle, String name) {
205+
requireOp(handle);
206+
try (PointerScope scope = new PointerScope()) {
207+
TF_Status status = TF_Status.newStatus();
208+
int length = TFE_OpGetOutputLength(handle, name, status);
209+
status.throwExceptionIfNotOK();
210+
return length;
211+
}
212+
}
171213

172-
private static native long dim(long handle, int index);
173-
}
214+
private static int inputListLength(TFE_Op handle, String name) {
215+
requireOp(handle);
216+
try (PointerScope scope = new PointerScope()) {
217+
TF_Status status = TF_Status.newStatus();
218+
int length = TFE_OpGetInputLength(handle, name, status);
219+
status.throwExceptionIfNotOK();
220+
return length;
221+
}
222+
}
223+
224+
private static int dataType(TFE_TensorHandle handle) {
225+
requireTensorHandle(handle);
226+
return TFE_TensorHandleDataType(handle);
227+
}
228+
229+
private static int numDims(TFE_TensorHandle handle) {
230+
requireTensorHandle(handle);
231+
try (PointerScope scope = new PointerScope()) {
232+
TF_Status status = TF_Status.newStatus();
233+
int numDims = TFE_TensorHandleNumDims(handle, status);
234+
status.throwExceptionIfNotOK();
235+
return numDims;
236+
}
237+
}
238+
239+
private static long dim(TFE_TensorHandle handle, int index) {
240+
requireTensorHandle(handle);
241+
try (PointerScope scope = new PointerScope()) {
242+
TF_Status status = TF_Status.newStatus();
243+
long dim = TFE_TensorHandleDim(handle, index, status);
244+
status.throwExceptionIfNotOK();
245+
return dim;
246+
}
247+
}
248+
}

0 commit comments

Comments
 (0)