15
15
16
16
package org .tensorflow ;
17
17
18
+ import static org .tensorflow .internal .c_api .global .tensorflow .TFE_DeleteOp ;
19
+ import static org .tensorflow .internal .c_api .global .tensorflow .TFE_DeleteTensorHandle ;
20
+ import static org .tensorflow .internal .c_api .global .tensorflow .TFE_OpGetInputLength ;
21
+ import static org .tensorflow .internal .c_api .global .tensorflow .TFE_OpGetOutputLength ;
22
+ import static org .tensorflow .internal .c_api .global .tensorflow .TFE_TensorHandleDataType ;
23
+ import static org .tensorflow .internal .c_api .global .tensorflow .TFE_TensorHandleDim ;
24
+ import static org .tensorflow .internal .c_api .global .tensorflow .TFE_TensorHandleNumDims ;
25
+ import static org .tensorflow .internal .c_api .global .tensorflow .TFE_TensorHandleResolve ;
26
+
18
27
import java .util .concurrent .atomic .AtomicReferenceArray ;
28
+ import org .bytedeco .javacpp .PointerScope ;
29
+ import org .tensorflow .internal .c_api .TFE_Op ;
30
+ import org .tensorflow .internal .c_api .TFE_TensorHandle ;
31
+ import org .tensorflow .internal .c_api .TF_Status ;
32
+ import org .tensorflow .internal .c_api .TF_Tensor ;
19
33
import org .tensorflow .tools .Shape ;
20
34
21
35
/**
@@ -31,8 +45,8 @@ class EagerOperation extends AbstractOperation {
31
45
32
46
EagerOperation (
33
47
EagerSession session ,
34
- long opNativeHandle ,
35
- long [] outputNativeHandles ,
48
+ TFE_Op opNativeHandle ,
49
+ TFE_TensorHandle [] outputNativeHandles ,
36
50
String type ,
37
51
String name ) {
38
52
this .session = session ;
@@ -68,7 +82,7 @@ public int inputListLength(final String name) {
68
82
}
69
83
70
84
@ Override
71
- public long getUnsafeNativeHandle (int outputIndex ) {
85
+ public TFE_TensorHandle getUnsafeNativeHandle (int outputIndex ) {
72
86
return nativeRef .outputHandles [outputIndex ];
73
87
}
74
88
@@ -80,7 +94,7 @@ public Shape shape(int outputIndex) {
80
94
if (tensor != null ) {
81
95
return tensor .shape ();
82
96
}
83
- long outputNativeHandle = getUnsafeNativeHandle (outputIndex );
97
+ TFE_TensorHandle outputNativeHandle = getUnsafeNativeHandle (outputIndex );
84
98
long [] shape = new long [numDims (outputNativeHandle )];
85
99
for (int i = 0 ; i < shape .length ; ++i ) {
86
100
shape [i ] = dim (outputNativeHandle , i );
@@ -96,7 +110,7 @@ public DataType<?> dtype(int outputIndex) {
96
110
if (tensor != null ) {
97
111
return tensor .dataType ();
98
112
}
99
- long outputNativeHandle = getUnsafeNativeHandle (outputIndex );
113
+ TFE_TensorHandle outputNativeHandle = getUnsafeNativeHandle (outputIndex );
100
114
return DataTypes .fromNativeCode (dataType (outputNativeHandle ));
101
115
}
102
116
@@ -119,7 +133,7 @@ private Tensor<?> resolveTensor(int outputIndex) {
119
133
// Take an optimistic approach, where we attempt to resolve the output tensor without locking.
120
134
// If another thread has resolved it meanwhile, release our copy and reuse the existing one
121
135
// instead.
122
- long tensorNativeHandle = resolveTensorHandle (getUnsafeNativeHandle (outputIndex ));
136
+ TF_Tensor tensorNativeHandle = resolveTensorHandle (getUnsafeNativeHandle (outputIndex ));
123
137
Tensor <?> tensor = Tensor .fromHandle (tensorNativeHandle , session );
124
138
if (!outputTensors .compareAndSet (outputIndex , null , tensor )) {
125
139
tensor .close ();
@@ -131,43 +145,104 @@ private Tensor<?> resolveTensor(int outputIndex) {
131
145
private static class NativeReference extends EagerSession .NativeReference {
132
146
133
147
NativeReference (
134
- EagerSession session , EagerOperation operation , long opHandle , long [] outputHandles ) {
148
+ EagerSession session , EagerOperation operation , TFE_Op opHandle , TFE_TensorHandle [] outputHandles ) {
135
149
super (session , operation );
136
150
this .opHandle = opHandle ;
137
151
this .outputHandles = outputHandles ;
138
152
}
139
153
140
154
@ Override
141
155
void delete () {
142
- if (opHandle != 0L ) {
156
+ if (opHandle != null && ! opHandle . isNull () ) {
143
157
for (int i = 0 ; i < outputHandles .length ; ++i ) {
144
- if (outputHandles [i ] != 0L ) {
158
+ if (outputHandles [i ] != null && ! outputHandles [ i ]. isNull () ) {
145
159
EagerOperation .deleteTensorHandle (outputHandles [i ]);
146
- outputHandles [i ] = 0L ;
160
+ outputHandles [i ] = null ;
147
161
}
148
162
}
149
163
EagerOperation .delete (opHandle );
150
- opHandle = 0L ;
164
+ opHandle = null ;
151
165
}
152
166
}
153
167
154
- private long opHandle ;
155
- private final long [] outputHandles ;
168
+ private TFE_Op opHandle ;
169
+ private final TFE_TensorHandle [] outputHandles ;
156
170
}
157
-
158
- private static native void delete (long handle );
159
171
160
- private static native void deleteTensorHandle (long handle );
172
+ private static void requireOp (TFE_Op handle ) {
173
+ if (handle == null || handle .isNull ()) {
174
+ throw new IllegalStateException ("Eager session has been closed" );
175
+ }
176
+ }
161
177
162
- private static native long resolveTensorHandle (long handle );
178
+ private static void requireTensorHandle (TFE_TensorHandle handle ) {
179
+ if (handle == null || handle .isNull ()) {
180
+ throw new IllegalStateException ("EagerSession has been closed" );
181
+ }
182
+ }
163
183
164
- private static native int outputListLength (long handle , String name );
184
+ private static void delete (TFE_Op handle ) {
185
+ if (handle == null || handle .isNull ()) return ;
186
+ TFE_DeleteOp (handle );
187
+ }
165
188
166
- private static native int inputListLength (long handle , String name );
189
+ private static void deleteTensorHandle (TFE_TensorHandle handle ) {
190
+ if (handle == null || handle .isNull ()) return ;
191
+ TFE_DeleteTensorHandle (handle );
192
+ }
167
193
168
- private static native int dataType (long handle );
194
+ private static TF_Tensor resolveTensorHandle (TFE_TensorHandle handle ) {
195
+ requireTensorHandle (handle );
196
+ try (PointerScope scope = new PointerScope ()) {
197
+ TF_Status status = TF_Status .newStatus ();
198
+ TF_Tensor tensor = TFE_TensorHandleResolve (handle , status );
199
+ status .throwExceptionIfNotOK ();
200
+ return tensor ;
201
+ }
202
+ }
169
203
170
- private static native int numDims (long handle );
204
+ private static int outputListLength (TFE_Op handle , String name ) {
205
+ requireOp (handle );
206
+ try (PointerScope scope = new PointerScope ()) {
207
+ TF_Status status = TF_Status .newStatus ();
208
+ int length = TFE_OpGetOutputLength (handle , name , status );
209
+ status .throwExceptionIfNotOK ();
210
+ return length ;
211
+ }
212
+ }
171
213
172
- private static native long dim (long handle , int index );
173
- }
214
+ private static int inputListLength (TFE_Op handle , String name ) {
215
+ requireOp (handle );
216
+ try (PointerScope scope = new PointerScope ()) {
217
+ TF_Status status = TF_Status .newStatus ();
218
+ int length = TFE_OpGetInputLength (handle , name , status );
219
+ status .throwExceptionIfNotOK ();
220
+ return length ;
221
+ }
222
+ }
223
+
224
+ private static int dataType (TFE_TensorHandle handle ) {
225
+ requireTensorHandle (handle );
226
+ return TFE_TensorHandleDataType (handle );
227
+ }
228
+
229
+ private static int numDims (TFE_TensorHandle handle ) {
230
+ requireTensorHandle (handle );
231
+ try (PointerScope scope = new PointerScope ()) {
232
+ TF_Status status = TF_Status .newStatus ();
233
+ int numDims = TFE_TensorHandleNumDims (handle , status );
234
+ status .throwExceptionIfNotOK ();
235
+ return numDims ;
236
+ }
237
+ }
238
+
239
+ private static long dim (TFE_TensorHandle handle , int index ) {
240
+ requireTensorHandle (handle );
241
+ try (PointerScope scope = new PointerScope ()) {
242
+ TF_Status status = TF_Status .newStatus ();
243
+ long dim = TFE_TensorHandleDim (handle , index , status );
244
+ status .throwExceptionIfNotOK ();
245
+ return dim ;
246
+ }
247
+ }
248
+ }
0 commit comments