Commit 7915e63

Fixed from Unit test results

1 parent ebefc2e commit 7915e63

10 files changed: +363 -609 lines changed

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java

Lines changed: 2 additions & 3 deletions
@@ -1863,14 +1863,13 @@ public <T extends TNumber> Softmax<T> softmax(Operand<T> logits) {
    * @param logits Per-label activations, typically a linear output. These activation energies are
    * interpreted as unnormalized log probabilities.
    * @param axis The class dimension. -1 is the last dimension.
-   * @param <U> the data type of the <code>logits</code>
    * @param <T> the number type of the operands
    * @return the softmax cross entropy loss. Its type is the same as <code>logits</code> and its
    * shape is the same as <code>labels</code> except that it does not have the last dimension of
    * <code>labels</code>.
    */
-  public <U extends TType, T extends TNumber> Operand<T> softmaxCrossEntropyWithLogits(
-      Operand<T> labels, Operand<U> logits, int axis) {
+  public <T extends TNumber, U extends TNumber> Operand<T> softmaxCrossEntropyWithLogits(
+      Operand<U> labels, Operand<T> logits, int axis) {
     return SoftmaxCrossEntropyWithLogits.softmaxCrossEntropyWithLogits(scope, labels, logits, axis);
   }
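The signature change relaxes the labels parameter from TType to TNumber and makes the return type follow the logits. A minimal usage sketch under the new signature, with illustrative values and hypothetical variable names (assumes a Graph is in scope):

    Ops tf = Ops.create(graph);
    // One-hot labels and raw logits for a single 3-class example (illustrative values).
    Operand<TFloat32> labels = tf.constant(new float[] {1f, 0f, 0f});
    Operand<TFloat32> logits = tf.constant(new float[] {2.0f, 0.5f, 0.1f});
    // axis -1 selects the last (class) dimension; the loss takes the type of the logits.
    Operand<TFloat32> loss = tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);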

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java

Lines changed: 8 additions & 0 deletions
@@ -71,6 +71,14 @@ protected Optimizer(Graph graph, String name) {
     this.globals = new ArrayList<>();
   }
 
+  /**
+   * Gets the Optimizer's Ops instance
+   * @return the Optimizer's Ops instance
+   */
+  public final Ops getTF() {
+    return tf;
+  }
+
   /**
    * Creates a name by combining a variable name and a slot name
    *
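With this accessor, a test can build its graph ops through the optimizer's own Ops instance rather than a separate one obtained from the session; the updated tests below all follow that pattern. A minimal sketch of the pattern (the motivation is our reading, not stated in the commit):

    Graph graph = session.getGraph();
    float learningRate = 3.0F;
    AdaGrad instance = new AdaGrad(graph, learningRate);
    Ops tf = instance.getTF(); // reuse the optimizer's Ops instead of session.getTF()
    Variable<TFloat32> var0 = tf.withName("var0").variable(Shape.of(2), TFloat32.DTYPE);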

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizers.java

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ public enum Optimizers {
   NADAM(Nadam::new),
   RMSPROP(RMSProp::new),
   MOMENTUM(Momentum::new),
-  GRADIENT_DESCENT(Momentum::new);
+  GRADIENT_DESCENT(GradientDescent::new);
 
   private final Function<Graph, Optimizer> creator;
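The bug here: GRADIENT_DESCENT previously constructed a Momentum optimizer. Each enum constant wraps a Function<Graph, Optimizer>, so the fix simply retargets the constructor reference. A small sketch of the mapping, assuming (as the functional type implies) that GradientDescent exposes a Graph-only constructor:

    Function<Graph, Optimizer> creator = GradientDescent::new; // was Momentum::new
    Optimizer optimizer = creator.apply(graph); // now actually a GradientDescent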

tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java

Lines changed: 6 additions & 5 deletions
@@ -55,9 +55,13 @@ public void testBasic() {
     float[] grads0Init = {0.1F, 0.2F};
     float[] grads1Init = {0.01F, 0.02F};
     try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Ops tf = session.getTF();
       Graph graph = session.getGraph();
 
+      float learningRate = 3.0F;
+
+      AdaGradDA instance = new AdaGradDA(graph, learningRate);
+      Ops tf = instance.getTF();
+
       Shape shape0 = Shape.of(var0Init.length);
       Shape shape1 = Shape.of(var1Init.length);
       Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
@@ -70,13 +74,10 @@ public void testBasic() {
       Constant<TFloat32> grads1 = tf.constant(grads1Init);
 
       /* initialize the local variables */
-      /* initialize the local variables */
+
       session.run(var0Initializer);
       session.run(var1Initializer);
 
-      float learningRate = 3.0F;
-
-      AdaGrad instance = new AdaGrad(graph, learningRate);
 
       /* build the GradsAnvVars */
       List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();

tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java

Lines changed: 8 additions & 11 deletions
@@ -71,9 +71,12 @@ public void testBasic() {
     FloatNdArray accum1Np = NdArrays.vectorOf(accum1);
 
     try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Ops tf = session.getTF();
       Graph graph = session.getGraph();
 
+      float learningRate = 3.0F;
+      AdaGrad instance = new AdaGrad(graph, learningRate, 0.1f);
+      Ops tf = instance.getTF();
+
       Shape shape0 = Shape.of(var0Init.length);
       Shape shape1 = Shape.of(var1Init.length);
       Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
@@ -85,9 +88,7 @@ public void testBasic() {
       Constant<TFloat32> grads0 = tf.constant(grads0Init);
       Constant<TFloat32> grads1 = tf.constant(grads1Init);
 
-      float learningRate = 3.0F;
 
-      AdaGrad instance = new AdaGrad(graph, learningRate);
 
       /* build the GradsAnvVars */
       List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
@@ -116,15 +117,11 @@ public void testBasic() {
 
       for (int step = 0; step < numSteps; step++) {
         session.run(adaUpdate);
-
-        accum0Np = caclulateAccum(accum0Np, grads0Np);
-        var0Np = calculate(var0Np, accum0Np, grads0Np, learningRate);
-        session.evaluate(var0Np, var0);
-
-        accum1Np = caclulateAccum(accum1Np, grads1Np);
-        var1Np = calculate(var1Np, accum1Np, grads1Np, learningRate);
-        session.evaluate(var1Np, var1);
       }
+      float[] expected0 = {-1.6026098728179932f, -0.6026098728179932f};
+      session.evaluate(expected0, var0);
+      float[] expected1 = {2.715679168701172f, 3.715679168701172f};
+      session.evaluate(expected1, var1);
     }
   }

tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java

Lines changed: 10 additions & 7 deletions
@@ -69,10 +69,15 @@ public void testBasic() {
     float epsilon1 = 1e-3F;
 
     try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Ops tf = session.getTF();
+      float learningRate = 0.001F;
+      float beta1 = 0.9F;
+      float beta2 = 0.999F;
       Graph graph = session.getGraph();
+
       session.setEpsilon(epsilon1);
 
+      Adam instance = new Adam(graph, learningRate);
+      Ops tf = instance.getTF();
       Shape shape0 = Shape.of(var0Init.length);
       Shape shape1 = Shape.of(var1Init.length);
       Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
@@ -88,16 +93,14 @@ public void testBasic() {
       session.run(var0Initializer);
       session.run(var1Initializer);
 
-      float learningRate = 0.001F;
-      float beta1 = 0.9F;
-      float beta2 = 0.999F;
+
 
       /* build the GradsAnvVars */
       List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
       gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
       gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
 
-      Adam instance = new Adam(graph, learningRate);
+
 
       Op update = instance.applyGradients(gradsAndVars, "AdamTest");
 
@@ -139,7 +142,7 @@ public void testBasic() {
       session
           .getGraphSession()
           .runner()
-          .fetch("beta1Power")
+          .fetch("beta1_power")
           .run()
           .get(0)
           .expect(TFloat32.DTYPE)) {
@@ -149,7 +152,7 @@ public void testBasic() {
       session
           .getGraphSession()
          .runner()
-          .fetch("beta2Power")
+          .fetch("beta2_power")
          .run()
          .get(0)
          .expect(TFloat32.DTYPE)) {
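The renamed fetch targets suggest Adam registers its decay-power state under snake_case node names. A hypothetical sketch of how a node ends up fetchable under such a string (illustrative only, not the actual Adam internals):

    // Explicitly named node; a test can later fetch it by that string.
    Variable<TFloat32> beta1Power =
        tf.withName("beta1_power").variable(Shape.of(), TFloat32.DTYPE);
    // later: session.getGraphSession().runner().fetch("beta1_power").run();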

tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java

Lines changed: 5 additions & 3 deletions
@@ -94,9 +94,11 @@ public void testBasic() {
     float epsilon1 = 1e-3F;
 
     try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Ops tf = session.getTF();
       Graph graph = session.getGraph();
 
+      Adamax instance = new Adamax(graph);
+      Ops tf = instance.getTF();
+
       Shape shape0 = Shape.of(var0Init.length);
       Shape shape1 = Shape.of(var1Init.length);
       Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
@@ -112,7 +114,7 @@ public void testBasic() {
       session.run(var0Initializer);
       session.run(var1Initializer);
 
-      Adamax instance = new Adamax(graph);
+
       /* build the GradsAnvVars */
       List<GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
       gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput()));
@@ -151,7 +153,7 @@ public void testBasic() {
       session
          .getGraphSession()
          .runner()
-          .fetch("beta1Power")
+          .fetch("beta1_power")
          .run()
          .get(0)
          .expect(TFloat32.DTYPE)) {
