
Commit d09e5a9

Adapt to latest op changes
See tensorflow/java#36
1 parent: 7fb3015


3 files changed (+37, -52 lines)


tensorflow-examples/pom.xml

Lines changed: 1 addition & 12 deletions
@@ -16,25 +16,14 @@
   <dependencies>
     <dependency>
       <groupId>org.tensorflow</groupId>
-      <artifactId>tensorflow-core-api</artifactId>
+      <artifactId>tensorflow-core-platform</artifactId>
       <version>0.1.0-SNAPSHOT</version>
     </dependency>
-    <dependency>
-      <groupId>org.tensorflow</groupId>
-      <artifactId>tensorflow-core-api</artifactId>
-      <version>0.1.0-SNAPSHOT</version>
-      <classifier>macosx-x86_64</classifier>
-    </dependency>
     <dependency>
       <groupId>org.tensorflow</groupId>
       <artifactId>tensorflow-training</artifactId>
       <version>0.1.0-SNAPSHOT</version>
     </dependency>
-    <dependency>
-      <groupId>org.tensorflow</groupId>
-      <artifactId>proto</artifactId>
-      <version>1.15.0</version>
-    </dependency>
   </dependencies>
   <build>
     <plugins>
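For context, tensorflow-core-platform is the aggregating artifact that depends on tensorflow-core-api plus the native binaries for every supported platform, which is why the explicit macosx-x86_64 classifier dependency can go away; the standalone org.tensorflow:proto artifact likewise appears to be superseded by the protos bundled with the new core artifacts. A small probe like the following (the VersionProbe class name is illustrative only) should confirm that the native runtime resolves from the single platform dependency:

import org.tensorflow.TensorFlow;

public class VersionProbe {
  public static void main(String[] args) {
    // Triggers loading of the native TensorFlow runtime pulled in by tensorflow-core-platform
    System.out.println("TensorFlow " + TensorFlow.version());
  }
}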

tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/CnnMnist.java

Lines changed: 15 additions & 15 deletions
@@ -89,20 +89,20 @@ public static Graph build(String optimizerName) {
     Placeholder<TUint8> labels = tf.withName(TARGET).placeholder(TUint8.DTYPE);
 
     // Scaling the features
-    Constant<TFloat32> centeringFactor = tf.val(PIXEL_DEPTH / 2.0f);
-    Constant<TFloat32> scalingFactor = tf.val((float) PIXEL_DEPTH);
+    Constant<TFloat32> centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f);
+    Constant<TFloat32> scalingFactor = tf.constant((float) PIXEL_DEPTH);
     Operand<TFloat32> scaledInput = tf.math
         .div(tf.math.sub(tf.dtypes.cast(input_reshaped, TFloat32.DTYPE), centeringFactor),
             scalingFactor);
 
     // First conv layer
     Variable<TFloat32> conv1Weights = tf.variable(tf.math.mul(tf.random
         .truncatedNormal(tf.array(5, 5, NUM_CHANNELS, 32), TFloat32.DTYPE,
-            TruncatedNormal.seed(SEED)), tf.val(0.1f)));
+            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
     Conv2d<TFloat32> conv1 = tf.nn
         .conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE);
     Variable<TFloat32> conv1Biases = tf
-        .variable(tf.fill(tf.array(new int[]{32}), tf.val(0.0f)));
+        .variable(tf.fill(tf.array(new int[]{32}), tf.constant(0.0f)));
     Relu<TFloat32> relu1 = tf.nn.relu(tf.nn.biasAdd(conv1, conv1Biases));
 
     // First pooling layer
@@ -113,11 +113,11 @@ public static Graph build(String optimizerName) {
     // Second conv layer
     Variable<TFloat32> conv2Weights = tf.variable(tf.math.mul(tf.random
         .truncatedNormal(tf.array(5, 5, 32, 64), TFloat32.DTYPE,
-            TruncatedNormal.seed(SEED)), tf.val(0.1f)));
+            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
     Conv2d<TFloat32> conv2 = tf.nn
         .conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE);
     Variable<TFloat32> conv2Biases = tf
-        .variable(tf.fill(tf.array(new int[]{64}), tf.val(0.1f)));
+        .variable(tf.fill(tf.array(new int[]{64}), tf.constant(0.1f)));
     Relu<TFloat32> relu2 = tf.nn.relu(tf.nn.biasAdd(conv2, conv2Biases));
 
     // Second pooling layer
@@ -128,23 +128,23 @@ public static Graph build(String optimizerName) {
     // Flatten inputs
     Reshape<TFloat32> flatten = tf.reshape(pool2, tf.concat(Arrays
         .asList(tf.slice(tf.shape(pool2), tf.array(new int[]{0}), tf.array(new int[]{1})),
-            tf.array(new int[]{-1})), tf.val(0)));
+            tf.array(new int[]{-1})), tf.constant(0)));
 
     // Fully connected layer
     Variable<TFloat32> fc1Weights = tf.variable(tf.math.mul(tf.random
         .truncatedNormal(tf.array(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat32.DTYPE,
-            TruncatedNormal.seed(SEED)), tf.val(0.1f)));
+            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
     Variable<TFloat32> fc1Biases = tf
-        .variable(tf.fill(tf.array(new int[]{512}), tf.val(0.1f)));
+        .variable(tf.fill(tf.array(new int[]{512}), tf.constant(0.1f)));
     Relu<TFloat32> relu3 = tf.nn
         .relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases));
 
     // Softmax layer
     Variable<TFloat32> fc2Weights = tf.variable(tf.math.mul(tf.random
         .truncatedNormal(tf.array(512, NUM_LABELS), TFloat32.DTYPE,
-            TruncatedNormal.seed(SEED)), tf.val(0.1f)));
+            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
     Variable<TFloat32> fc2Biases = tf
-        .variable(tf.fill(tf.array(new int[]{NUM_LABELS}), tf.val(0.1f)));
+        .variable(tf.fill(tf.array(new int[]{NUM_LABELS}), tf.constant(0.1f)));
 
     Add<TFloat32> logits = tf.math.add(tf.linalg.matMul(relu3, fc2Weights), fc2Biases);
 
@@ -153,15 +153,15 @@ public static Graph build(String optimizerName) {
 
     // Loss function & regularization
     OneHot<TFloat32> oneHot = tf
-        .oneHot(labels, tf.val(10), tf.val(1.0f), tf.val(0.0f));
+        .oneHot(labels, tf.constant(10), tf.constant(1.0f), tf.constant(0.0f));
     SoftmaxCrossEntropyWithLogits<TFloat32> batchLoss = tf.nn
         .softmaxCrossEntropyWithLogits(logits, oneHot);
-    Mean<TFloat32> labelLoss = tf.math.mean(batchLoss.loss(), tf.val(0));
+    Mean<TFloat32> labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0));
     Add<TFloat32> regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math
         .add(tf.nn.l2Loss(fc1Biases),
             tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases))));
     Add<TFloat32> loss = tf.withName(TRAINING_LOSS).math
-        .add(labelLoss, tf.math.mul(regularizers, tf.val(5e-4f)));
+        .add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f)));
 
     String lcOptimizerName = optimizerName.toLowerCase();
     // Optimizer
@@ -194,7 +194,7 @@ public static Graph build(String optimizerName) {
     logger.info("Optimizer = " + optimizer.toString());
     Op minimize = optimizer.minimize(loss, TRAIN);
 
-    Op init = graph.variablesInitializer();
+    tf.init();
 
     return graph;
   }
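The CnnMnist edits above are mechanical renames against the updated Ops API: every tf.val(...) factory becomes tf.constant(...), and the explicit Op init = graph.variablesInitializer() target is replaced by tf.init(). A minimal sketch of the same pattern, assuming the 0.1.0-SNAPSHOT API this commit targets (the OpRenameSketch class name, the 0.1f value and the 32-element fill are illustrative only):

import org.tensorflow.Graph;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Variable;
import org.tensorflow.types.TFloat32;

public class OpRenameSketch {
  public static void main(String[] args) {
    try (Graph graph = new Graph()) {
      Ops tf = Ops.create(graph);
      // tf.val(0.1f) in the previous API becomes tf.constant(0.1f)
      Constant<TFloat32> scale = tf.constant(0.1f);
      // tf.variable(Operand) builds the variable with its initial value, as in CnnMnist above
      Variable<TFloat32> bias = tf.variable(tf.fill(tf.array(32), scale));
      // Single graph init target, replacing graph.variablesInitializer()
      tf.init();
    }
  }
}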

tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/SimpleMnist.java

Lines changed: 21 additions & 25 deletions
@@ -1,23 +1,22 @@
 package org.tensorflow.model.examples.mnist;
 
-import java.util.Arrays;
 import org.tensorflow.Graph;
 import org.tensorflow.Operand;
 import org.tensorflow.Session;
 import org.tensorflow.Tensor;
 import org.tensorflow.model.examples.mnist.data.ImageBatch;
 import org.tensorflow.model.examples.mnist.data.MnistDataset;
+import org.tensorflow.op.Op;
 import org.tensorflow.op.Ops;
-import org.tensorflow.op.core.Assign;
-import org.tensorflow.op.core.Constant;
-import org.tensorflow.op.core.Gradients;
+import org.tensorflow.op.RawOp;
 import org.tensorflow.op.core.Placeholder;
 import org.tensorflow.op.core.Variable;
 import org.tensorflow.op.math.Mean;
 import org.tensorflow.op.nn.Softmax;
-import org.tensorflow.op.train.ApplyGradientDescent;
 import org.tensorflow.tools.Shape;
 import org.tensorflow.tools.ndarray.ByteNdArray;
+import org.tensorflow.training.optimizers.GradientDescent;
+import org.tensorflow.training.optimizers.Optimizer;
 import org.tensorflow.types.TFloat32;
 import org.tensorflow.types.TInt64;
 
@@ -42,12 +41,15 @@ public void run() {
     // Create weights with an initial value of 0
     Shape weightShape = Shape.of(dataset.imageSize(), MnistDataset.NUM_CLASSES);
     Variable<TFloat32> weights = tf.variable(weightShape, TFloat32.DTYPE);
-    Assign<TFloat32> weightsInit = tf.assign(weights, tf.zerosLike(weights));
+    tf.initAdd(tf.assign(weights, tf.zerosLike(weights)));
 
     // Create biases with an initial value of 0
     Shape biasShape = Shape.of(MnistDataset.NUM_CLASSES);
     Variable<TFloat32> biases = tf.variable(biasShape, TFloat32.DTYPE);
-    Assign<TFloat32> biasesInit = tf.assign(biases, tf.zerosLike(biases));
+    tf.initAdd(tf.assign(biases, tf.zerosLike(biases)));
+
+    // Register all variable initializers for single execution
+    tf.init();
 
     // Predict the class of each image in the batch and compute the loss
     Softmax<TFloat32> softmax =
@@ -69,32 +71,26 @@ public void run() {
     );
 
     // Back-propagate gradients to variables for training
-    Gradients gradients = tf.gradients(crossEntropy, Arrays.asList(weights, biases));
-    Constant<TFloat32> alpha = tf.val(LEARNING_RATE);
-    ApplyGradientDescent<TFloat32> weightGradientDescent = tf.train.applyGradientDescent(weights, alpha, gradients.dy(0));
-    ApplyGradientDescent<TFloat32> biasGradientDescent = tf.train.applyGradientDescent(biases, alpha, gradients.dy(1));
+    Optimizer optimizer = new GradientDescent(graph, LEARNING_RATE);
+    Op minimize = optimizer.minimize(crossEntropy);
 
     // Compute the accuracy of the model
-    Operand<TInt64> predicted = tf.math.argMax(softmax, tf.val(1));
-    Operand<TInt64> expected = tf.math.argMax(labels, tf.val(1));
+    Operand<TInt64> predicted = tf.math.argMax(softmax, tf.constant(1));
+    Operand<TInt64> expected = tf.math.argMax(labels, tf.constant(1));
     Operand<TFloat32> accuracy = tf.math.mean(tf.dtypes.cast(tf.math.equal(predicted, expected), TFloat32.DTYPE), tf.array(0));
 
     // Run the graph
     try (Session session = new Session(graph)) {
 
       // Initialize variables
-      session.runner()
-          .addTarget(weightsInit)
-          .addTarget(biasesInit)
-          .run();
+      session.runInit();
 
       // Train the model
       for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) {
         try (Tensor<TFloat32> batchImages = preprocessImages(trainingBatch.images());
             Tensor<TFloat32> batchLabels = preprocessLabels(trainingBatch.labels())) {
           session.runner()
-              .addTarget(weightGradientDescent)
-              .addTarget(biasGradientDescent)
+              .addTarget(minimize)
               .feed(images.asOutput(), batchImages)
               .feed(labels.asOutput(), batchLabels)
               .run();
@@ -128,10 +124,10 @@ private static Tensor<TFloat32> preprocessImages(ByteNdArray rawImages) {
     long imageSize = rawImages.get(0).shape().size();
     return tf.math.div(
         tf.reshape(
-            tf.dtypes.cast(tf.val(rawImages), TFloat32.DTYPE),
+            tf.dtypes.cast(tf.constant(rawImages), TFloat32.DTYPE),
             tf.array(-1L, imageSize)
         ),
-        tf.val(255.0f)
+        tf.constant(255.0f)
     ).asTensor();
   }
 
@@ -140,10 +136,10 @@ private static Tensor<TFloat32> preprocessLabels(ByteNdArray rawLabels) {
 
     // Map labels to one hot vectors where only the expected predictions as a value of 1.0
     return tf.oneHot(
-        tf.val(rawLabels),
-        tf.val(MnistDataset.NUM_CLASSES),
-        tf.val(1.0f),
-        tf.val(0.0f)
+        tf.constant(rawLabels),
+        tf.constant(MnistDataset.NUM_CLASSES),
+        tf.constant(1.0f),
+        tf.constant(0.0f)
     ).asTensor();
   }
 
0 commit comments
