Skip to content

Commit a7903cd

Browse files
committed
rebase fixes
Signed-off-by: Ryan Nett <rnett@calpoly.edu>
1 parent 7fde8d7 commit a7903cd

File tree

8 files changed

+44
-22
lines changed

8 files changed

+44
-22
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,18 @@ public RawTensor asRawTensor() {
6565

6666
@Override
6767
public void close() {
68+
if(closed) {
69+
throw new IllegalStateException("Tensor has already been closed");
70+
}
6871
tensorScope.close();
72+
closed = true;
73+
}
74+
75+
/**
76+
* @return {@code true} if this tensor has been closed;
77+
*/
78+
public boolean isClosed() {
79+
return closed;
6980
}
7081

7182
/**
@@ -212,6 +223,7 @@ private static long[] shape(TF_Tensor handle) {
212223
}
213224

214225
private PointerScope tensorScope;
226+
private boolean closed = false;
215227
private TF_Tensor tensorHandle;
216228
private final TensorTypeInfo<? extends TType> typeInfo;
217229
private final Shape shape;

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,10 @@ public void close() {
150150
* <p>
151151
* Closing a {@code Result} object will close all of the tensors contained by it.
152152
*/
153-
public final class Result implements AutoCloseable, Iterable<Tensor<?>>{
154-
private final List<Tensor<?>> results;
153+
public final class Result implements AutoCloseable, Iterable<Tensor>{
154+
private final List<Tensor> results;
155155
private final List<Output<?>> fetches;
156-
private final LinkedHashMap<Output<?>, Tensor<?>> outputMap;
156+
private final LinkedHashMap<Output<?>, Tensor> outputMap;
157157

158158
/**
159159
* Metadata about the run.
@@ -166,7 +166,7 @@ public final class Result implements AutoCloseable, Iterable<Tensor<?>>{
166166

167167
private boolean closed = false;
168168

169-
private Result(List<Tensor<?>> results, List<Output<?>> fetches, RunMetadata metadata) {
169+
private Result(List<Tensor> results, List<Output<?>> fetches, RunMetadata metadata) {
170170

171171
if(results.size() != fetches.size()){
172172
throw new IllegalArgumentException("Expected the same number of fetches and values, got " + fetches.size()
@@ -191,7 +191,7 @@ private void requireOpen(){
191191
/**
192192
* Get the result tensors.
193193
*/
194-
public List<Tensor<?>> getResults() {
194+
public List<Tensor> getResults() {
195195
requireOpen();
196196
return Collections.unmodifiableList(results);
197197
}
@@ -206,7 +206,7 @@ public List<Output<?>> getFetches() {
206206
/**
207207
* Get a map of the fetched outputs to their results.
208208
*/
209-
public Map<Output<?>, Tensor<?>> getOutputMap(){
209+
public Map<Output<?>, Tensor> getOutputMap(){
210210
return Collections.unmodifiableMap(outputMap);
211211
}
212212

@@ -227,7 +227,7 @@ public boolean isClosed() {
227227
/**
228228
* Get the result at {@code index}.
229229
*/
230-
public Tensor<?> get(int index){
230+
public Tensor get(int index){
231231
requireOpen();
232232
return results.get(index);
233233
}
@@ -236,25 +236,25 @@ public Tensor<?> get(int index){
236236
* Get the result for {@code output} or throw an {@code IllegalArgumentException} if it wasn't fetched.
237237
*/
238238
@SuppressWarnings("unchecked")
239-
public <T extends TType> Tensor<T> get(Output<T> output){
239+
public <T extends TType> T get(Output<T> output){
240240
requireOpen();
241241
if(!outputMap.containsKey(output))
242242
throw new IllegalArgumentException("Did not fetch an output for " + output);
243-
return (Tensor<T>) outputMap.get(output);
243+
return (T) outputMap.get(output);
244244
}
245245

246246
/**
247247
* Get the result for {@code operand} or throw an {@code IllegalArgumentException} if it wasn't fetched.
248248
*/
249-
public <T extends TType> Tensor<T> get(Operand<T> operand){
249+
public <T extends TType> T get(Operand<T> operand){
250250
requireOpen();
251251
return get(operand.asOutput());
252252
}
253253

254254
/**
255255
* Get the result for the {@code index}-th output of {@code operation} or throw an {@code IllegalArgumentException} if it wasn't fetched.
256256
*/
257-
public Tensor<?> get(String operation, int index){
257+
public Tensor get(String operation, int index){
258258
requireOpen();
259259
return get(graph.getOutput(operation, index));
260260
}
@@ -263,7 +263,7 @@ public Tensor<?> get(String operation, int index){
263263
/**
264264
* Get the result for the output specified by {@code output} or throw an {@code IllegalArgumentException} if it wasn't fetched.
265265
*/
266-
public Tensor<?> get(String output){
266+
public Tensor get(String output){
267267
requireOpen();
268268
return get(graph.getOutput(output));
269269
}
@@ -307,7 +307,7 @@ public boolean contains(String output){
307307
@Override
308308
public void close() {
309309
requireOpen();
310-
for(Tensor<?> t : this){
310+
for(Tensor t : this){
311311
if(!t.isClosed()) {
312312
t.close();
313313
}
@@ -316,19 +316,19 @@ public void close() {
316316
}
317317

318318
@Override
319-
public Iterator<Tensor<?>> iterator() {
319+
public Iterator<Tensor> iterator() {
320320
requireOpen();
321321
return results.iterator();
322322
}
323323

324324
@Override
325-
public void forEach(Consumer<? super Tensor<?>> action) {
325+
public void forEach(Consumer<? super Tensor> action) {
326326
requireOpen();
327327
results.forEach(action);
328328
}
329329

330330
@Override
331-
public Spliterator<Tensor<?>> spliterator() {
331+
public Spliterator<Tensor> spliterator() {
332332
requireOpen();
333333
return results.spliterator();
334334
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,4 +193,9 @@ static <T extends TType> T of(Class<T> type, Shape shape, ByteDataBuffer rawData
193193
*/
194194
@Override
195195
void close();
196+
197+
/**
198+
* @return {@code true} if this tensor has been closed.
199+
*/
200+
boolean isClosed();
196201
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,9 @@ default long numBytes() {
7777
default void close() {
7878
asRawTensor().close();
7979
}
80+
81+
@Override
82+
default boolean isClosed(){
83+
return asRawTensor().isClosed();
84+
}
8085
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ public void addGradientsToGraph() {
160160
assertEquals(DataType.DT_FLOAT, grads1[0].dataType());
161161
assertEquals(DataType.DT_FLOAT, grads1[1].dataType());
162162

163-
try (Tensor<TFloat32> c1 = TFloat32.scalarOf(3.0f);
164-
Tensor<TFloat32> c2 = TFloat32.scalarOf(2.0f);
163+
try (TFloat32 c1 = TFloat32.scalarOf(3.0f);
164+
TFloat32 c2 = TFloat32.scalarOf(2.0f);
165165
Session.Result outputs = s.runner()
166166
.feed(x1, c1)
167167
.feed(x2, c2)

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ public void runWithMetadata() {
113113
.runAndFetchMetadata();
114114
// Sanity check on outputs.
115115
assertEquals(1, result.size());
116-
assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0));
116+
assertEquals(31, ((TInt32)result.get(0)).getInt(0, 0));
117117
// Sanity check on metadata
118118
assertNotNull(result.getMetadata());
119119
assertTrue(result.getMetadata().hasStepStats(), result.getMetadata().toString());

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ public void createGradientsWithSum() {
7272
assertNotNull(grads.dy());
7373
assertEquals(1, grads.dy().size());
7474

75-
try (Tensor<TFloat32> c = TFloat32.scalarOf(3.0f);
75+
try (TFloat32 c = TFloat32.scalarOf(3.0f);
7676
Session.Result outputs = sess.runner().feed(x, c).fetch(grads.dy(0)).run()) {
7777

7878
assertEquals(114.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f);
@@ -97,7 +97,7 @@ public void createGradientsWithInitialValues() {
9797
assertNotNull(grads1.dy());
9898
assertEquals(1, grads1.dy().size());
9999

100-
try (Tensor<TFloat32> c = TFloat32.scalarOf(3.0f);
100+
try (TFloat32 c = TFloat32.scalarOf(3.0f);
101101
Session.Result outputs =
102102
sess.runner().feed(x, c).fetch(grads1.dy(0)).run()) {
103103

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ public void operationsComposingZerosAreCorrectlyNamed() {
131131
Session sess = new Session(g)) {
132132
Scope scope = new Scope(g);
133133
long[] shape = {2, 2};
134-
Zeros<TFloat32> zeros = Zeros.create(scope.withSubScope("test"), Constant.vectorOf(scope, shape), TFloat32.DTYPE);
134+
Zeros<TFloat32> zeros = Zeros.create(scope.withSubScope("test"), Constant.vectorOf(scope, shape), TFloat32.class);
135135
Session.Result results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run();
136136
}
137137
}

0 commit comments

Comments
 (0)