24
24
import org .tensorflow .Operand ;
25
25
import org .tensorflow .Session ;
26
26
import org .tensorflow .Tensor ;
27
+ import org .tensorflow .TensorScope ;
27
28
import org .tensorflow .ndarray .Shape ;
28
29
import org .tensorflow .op .Scope ;
29
30
import org .tensorflow .types .TBool ;
@@ -34,7 +35,8 @@ public class BooleanMaskUpdateTest {
34
35
@ Test
35
36
public void testBooleanMaskUpdateSlice () {
36
37
try (Graph g = new Graph ();
37
- Session sess = new Session (g )) {
38
+ Session sess = new Session (g );
39
+ TensorScope tensorScope = new TensorScope ()) {
38
40
Scope scope = new Scope (g );
39
41
40
42
Operand <TInt32 > input = Constant .tensorOf (scope , new int [][]{{0 , 0 , 0 }, {1 , 1 , 1 }, {2 , 2 , 2 }});
@@ -47,31 +49,31 @@ public void testBooleanMaskUpdateSlice() {
47
49
48
50
Operand <TInt32 > bcastOutput = BooleanMaskUpdate .create (scope , input , mask , Constant .scalarOf (scope , -1 ));
49
51
50
- List <Tensor > results = sess .runner ().fetch (output ).fetch (bcastOutput ).run ();
51
- try ( TInt32 result = (TInt32 ) results .get (0 );
52
- TInt32 bcastResult = (TInt32 ) results .get (1 )) {
52
+ List <Tensor > results = sess .runner ().fetch (output ).fetch (bcastOutput ).run (tensorScope );
53
+ TInt32 result = (TInt32 ) results .get (0 );
54
+ TInt32 bcastResult = (TInt32 ) results .get (1 );
53
55
54
- assertEquals (Shape .of (3 , 3 ), result .shape ());
56
+ assertEquals (Shape .of (3 , 3 ), result .shape ());
55
57
56
- assertEquals (-1 , result .getInt (0 , 0 ));
57
- assertEquals (-1 , result .getInt (0 , 1 ));
58
- assertEquals (-1 , result .getInt (0 , 2 ));
59
- assertEquals (1 , result .getInt (1 , 0 ));
60
- assertEquals (1 , result .getInt (1 , 1 ));
61
- assertEquals (1 , result .getInt (1 , 2 ));
62
- assertEquals (2 , result .getInt (2 , 0 ));
63
- assertEquals (2 , result .getInt (2 , 1 ));
64
- assertEquals (2 , result .getInt (2 , 2 ));
58
+ assertEquals (-1 , result .getInt (0 , 0 ));
59
+ assertEquals (-1 , result .getInt (0 , 1 ));
60
+ assertEquals (-1 , result .getInt (0 , 2 ));
61
+ assertEquals (1 , result .getInt (1 , 0 ));
62
+ assertEquals (1 , result .getInt (1 , 1 ));
63
+ assertEquals (1 , result .getInt (1 , 2 ));
64
+ assertEquals (2 , result .getInt (2 , 0 ));
65
+ assertEquals (2 , result .getInt (2 , 1 ));
66
+ assertEquals (2 , result .getInt (2 , 2 ));
65
67
66
- assertEquals (result , bcastResult );
67
- }
68
+ assertEquals (result , bcastResult );
68
69
}
69
70
}
70
71
71
72
@ Test
72
73
public void testBooleanMaskUpdateSliceWithBroadcast () {
73
74
try (Graph g = new Graph ();
74
- Session sess = new Session (g )) {
75
+ Session sess = new Session (g );
76
+ TensorScope tensorScope = new TensorScope ()) {
75
77
Scope scope = new Scope (g );
76
78
77
79
Operand <TInt32 > input = Constant .tensorOf (scope , new int [][]{{0 , 0 , 0 }, {1 , 1 , 1 }, {2 , 2 , 2 }});
@@ -84,31 +86,31 @@ public void testBooleanMaskUpdateSliceWithBroadcast() {
84
86
85
87
Operand <TInt32 > bcastOutput = BooleanMaskUpdate .create (scope , input , mask , Constant .scalarOf (scope , -1 ));
86
88
87
- List <Tensor > results = sess .runner ().fetch (output ).fetch (bcastOutput ).run ();
88
- try ( TInt32 result = (TInt32 ) results .get (0 );
89
- TInt32 bcastResult = (TInt32 ) results .get (1 )) {
89
+ List <Tensor > results = sess .runner ().fetch (output ).fetch (bcastOutput ).run (tensorScope );
90
+ TInt32 result = (TInt32 ) results .get (0 );
91
+ TInt32 bcastResult = (TInt32 ) results .get (1 );
90
92
91
- assertEquals (Shape .of (3 , 3 ), result .shape ());
93
+ assertEquals (Shape .of (3 , 3 ), result .shape ());
92
94
93
- assertEquals (-1 , result .getInt (0 , 0 ));
94
- assertEquals (-1 , result .getInt (0 , 1 ));
95
- assertEquals (-1 , result .getInt (0 , 2 ));
96
- assertEquals (1 , result .getInt (1 , 0 ));
97
- assertEquals (1 , result .getInt (1 , 1 ));
98
- assertEquals (1 , result .getInt (1 , 2 ));
99
- assertEquals (2 , result .getInt (2 , 0 ));
100
- assertEquals (2 , result .getInt (2 , 1 ));
101
- assertEquals (2 , result .getInt (2 , 2 ));
95
+ assertEquals (-1 , result .getInt (0 , 0 ));
96
+ assertEquals (-1 , result .getInt (0 , 1 ));
97
+ assertEquals (-1 , result .getInt (0 , 2 ));
98
+ assertEquals (1 , result .getInt (1 , 0 ));
99
+ assertEquals (1 , result .getInt (1 , 1 ));
100
+ assertEquals (1 , result .getInt (1 , 2 ));
101
+ assertEquals (2 , result .getInt (2 , 0 ));
102
+ assertEquals (2 , result .getInt (2 , 1 ));
103
+ assertEquals (2 , result .getInt (2 , 2 ));
102
104
103
- assertEquals (result , bcastResult );
104
- }
105
+ assertEquals (result , bcastResult );
105
106
}
106
107
}
107
108
108
109
@ Test
109
110
public void testBooleanMaskUpdateAxis () {
110
111
try (Graph g = new Graph ();
111
- Session sess = new Session (g )) {
112
+ Session sess = new Session (g );
113
+ TensorScope tensorScope = new TensorScope ()) {
112
114
Scope scope = new Scope (g );
113
115
114
116
Operand <TInt32 > input = Constant .tensorOf (scope , new int [][][]{{{0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 }}});
@@ -122,25 +124,24 @@ public void testBooleanMaskUpdateAxis() {
122
124
Operand <TInt32 > bcastOutput = BooleanMaskUpdate
123
125
.create (scope , input , mask , Constant .scalarOf (scope , -1 ), BooleanMaskUpdate .axis (2 ));
124
126
125
- List <Tensor > results = sess .runner ().fetch (output ).fetch (bcastOutput ).run ();
126
- try (TInt32 result = (TInt32 ) results .get (0 );
127
- TInt32 bcastResult = (TInt32 ) results .get (1 )) {
128
-
129
- assertEquals (Shape .of (1 , 1 , 10 ), result .shape ());
130
-
131
- assertEquals (-1 , result .getInt (0 , 0 , 0 ));
132
- assertEquals (-1 , result .getInt (0 , 0 , 1 ));
133
- assertEquals (2 , result .getInt (0 , 0 , 2 ));
134
- assertEquals (3 , result .getInt (0 , 0 , 3 ));
135
- assertEquals (-1 , result .getInt (0 , 0 , 4 ));
136
- assertEquals (-1 , result .getInt (0 , 0 , 5 ));
137
- assertEquals (-1 , result .getInt (0 , 0 , 6 ));
138
- assertEquals (7 , result .getInt (0 , 0 , 7 ));
139
- assertEquals (8 , result .getInt (0 , 0 , 8 ));
140
- assertEquals (9 , result .getInt (0 , 0 , 9 ));
141
-
142
- assertEquals (result , bcastResult );
143
- }
127
+ List <Tensor > results = sess .runner ().fetch (output ).fetch (bcastOutput ).run (tensorScope );
128
+ TInt32 result = (TInt32 ) results .get (0 );
129
+ TInt32 bcastResult = (TInt32 ) results .get (1 );
130
+
131
+ assertEquals (Shape .of (1 , 1 , 10 ), result .shape ());
132
+
133
+ assertEquals (-1 , result .getInt (0 , 0 , 0 ));
134
+ assertEquals (-1 , result .getInt (0 , 0 , 1 ));
135
+ assertEquals (2 , result .getInt (0 , 0 , 2 ));
136
+ assertEquals (3 , result .getInt (0 , 0 , 3 ));
137
+ assertEquals (-1 , result .getInt (0 , 0 , 4 ));
138
+ assertEquals (-1 , result .getInt (0 , 0 , 5 ));
139
+ assertEquals (-1 , result .getInt (0 , 0 , 6 ));
140
+ assertEquals (7 , result .getInt (0 , 0 , 7 ));
141
+ assertEquals (8 , result .getInt (0 , 0 , 8 ));
142
+ assertEquals (9 , result .getInt (0 , 0 , 9 ));
143
+
144
+ assertEquals (result , bcastResult );
144
145
}
145
146
}
146
147
}
0 commit comments