diff --git a/tensorflow-examples/pom.xml b/tensorflow-examples/pom.xml
index f03d7a9..276a10f 100644
--- a/tensorflow-examples/pom.xml
+++ b/tensorflow-examples/pom.xml
@@ -52,7 +52,7 @@
- org.tensorflow.model.examples.mnist.SimpleMnist
+ org.tensorflow.model.examples.dense.SimpleMnist
diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/CnnMnist.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java
similarity index 88%
rename from tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/CnnMnist.java
rename to tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java
index 0f4f9b5..a590d30 100644
--- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/CnnMnist.java
+++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java
@@ -1,19 +1,20 @@
/*
- * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =======================================================================
*/
-package org.tensorflow.model.examples.mnist;
+package org.tensorflow.model.examples.cnn.lenet;
import java.util.Arrays;
import java.util.logging.Level;
@@ -22,8 +23,8 @@
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
-import org.tensorflow.model.examples.mnist.data.ImageBatch;
-import org.tensorflow.model.examples.mnist.data.MnistDataset;
+import org.tensorflow.model.examples.datasets.ImageBatch;
+import org.tensorflow.model.examples.datasets.mnist.MnistDataset;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
@@ -76,6 +77,11 @@ public class CnnMnist {
public static final String TRAINING_LOSS = "training_loss";
public static final String INIT = "init";
+ private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz";
+ private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz";
+ private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz";
+ private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz";
+
public static Graph build(String optimizerName) {
Graph graph = new Graph();
@@ -294,7 +300,8 @@ public static void main(String[] args) {
logger.info(
"Usage: MNISTTest ");
- MnistDataset dataset = MnistDataset.create(0);
+ MnistDataset dataset = MnistDataset.create(0, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE,
+ TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);
logger.info("Loaded data.");
diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGG11OnFashionMNIST.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGG11OnFashionMNIST.java
new file mode 100644
index 0000000..b8c5c26
--- /dev/null
+++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGG11OnFashionMNIST.java
@@ -0,0 +1,51 @@
+/*
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =======================================================================
+ */
+package org.tensorflow.model.examples.cnn.vgg;
+
+import org.tensorflow.model.examples.datasets.mnist.MnistDataset;
+
+import java.util.logging.Logger;
+
+/**
+ * Trains and evaluates VGG'11 model on FashionMNIST dataset.
+ */
+public class VGG11OnFashionMNIST {
+ // Hyper-parameters
+ public static final int EPOCHS = 1;
+ public static final int BATCH_SIZE = 500;
+
+ // Fashion MNIST dataset paths
+ public static final String TRAINING_IMAGES_ARCHIVE = "fashionmnist/train-images-idx3-ubyte.gz";
+ public static final String TRAINING_LABELS_ARCHIVE = "fashionmnist/train-labels-idx1-ubyte.gz";
+ public static final String TEST_IMAGES_ARCHIVE = "fashionmnist/t10k-images-idx3-ubyte.gz";
+ public static final String TEST_LABELS_ARCHIVE = "fashionmnist/t10k-labels-idx1-ubyte.gz";
+
+ private static final Logger logger = Logger.getLogger(VGG11OnFashionMNIST.class.getName());
+
+ public static void main(String[] args) {
+ logger.info("Data loading.");
+ MnistDataset dataset = MnistDataset.create(0, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE, TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);
+
+ try (VGGModel vggModel = new VGGModel()) {
+ logger.info("Model training.");
+ vggModel.train(dataset, EPOCHS, BATCH_SIZE);
+
+ logger.info("Model evaluation.");
+ vggModel.test(dataset, BATCH_SIZE);
+ }
+ }
+}
diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java
new file mode 100644
index 0000000..3519517
--- /dev/null
+++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java
@@ -0,0 +1,291 @@
+/*
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =======================================================================
+ */
+package org.tensorflow.model.examples.cnn.vgg;
+
+import org.tensorflow.Graph;
+import org.tensorflow.Operand;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+import org.tensorflow.framework.optimizers.Adam;
+import org.tensorflow.framework.optimizers.Optimizer;
+import org.tensorflow.model.examples.datasets.ImageBatch;
+import org.tensorflow.model.examples.datasets.mnist.MnistDataset;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.Constant;
+import org.tensorflow.op.core.OneHot;
+import org.tensorflow.op.core.Placeholder;
+import org.tensorflow.op.core.Reshape;
+import org.tensorflow.op.core.Variable;
+import org.tensorflow.op.math.Add;
+import org.tensorflow.op.math.Mean;
+import org.tensorflow.op.nn.Conv2d;
+import org.tensorflow.op.nn.MaxPool;
+import org.tensorflow.op.nn.Relu;
+import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits;
+import org.tensorflow.op.random.TruncatedNormal;
+import org.tensorflow.tools.Shape;
+import org.tensorflow.tools.ndarray.ByteNdArray;
+import org.tensorflow.tools.ndarray.FloatNdArray;
+import org.tensorflow.tools.ndarray.index.Indices;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TUint8;
+
+import java.util.Arrays;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * Describes the VGGModel.
+ */
+public class VGGModel implements AutoCloseable {
+ private static final int PIXEL_DEPTH = 255;
+ private static final int NUM_CHANNELS = 1;
+ private static final int IMAGE_SIZE = 28;
+ private static final int NUM_LABELS = MnistDataset.NUM_CLASSES;
+ private static final long SEED = 123456789L;
+
+ private static final String PADDING_TYPE = "SAME";
+ public static final String INPUT_NAME = "input";
+ public static final String OUTPUT_NAME = "output";
+ public static final String TARGET = "target";
+ public static final String TRAIN = "train";
+ public static final String TRAINING_LOSS = "training_loss";
+ public static final String INIT = "init";
+
+ private static final Logger logger = Logger.getLogger(VGGModel.class.getName());
+
+ private final Graph graph;
+
+ private final Session session;
+
+ public VGGModel() {
+ graph = compile();
+ session = new Session(graph);
+ }
+
+ public static Graph compile() {
+ Graph graph = new Graph();
+
+ Ops tf = Ops.create(graph);
+
+ // Inputs
+ Placeholder input = tf.withName(INPUT_NAME).placeholder(TUint8.DTYPE,
+ Placeholder.shape(Shape.of(-1, IMAGE_SIZE, IMAGE_SIZE)));
+ Reshape input_reshaped = tf
+ .reshape(input, tf.array(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS));
+ Placeholder labels = tf.withName(TARGET).placeholder(TUint8.DTYPE);
+
+ // Scaling the features
+ Constant centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f);
+ Constant scalingFactor = tf.constant((float) PIXEL_DEPTH);
+ Operand scaledInput = tf.math
+ .div(tf.math.sub(tf.dtypes.cast(input_reshaped, TFloat32.DTYPE), centeringFactor),
+ scalingFactor);
+
+ Relu relu1 = vggConv2DLayer("1", tf, scaledInput, new int[]{3, 3, NUM_CHANNELS, 32}, 32);
+
+ MaxPool pool1 = vggMaxPool(tf, relu1);
+
+ Relu relu2 = vggConv2DLayer("2", tf, pool1, new int[]{3, 3, 32, 64}, 64);
+
+ MaxPool pool2 = vggMaxPool(tf, relu2);
+
+ Relu relu3 = vggConv2DLayer("3", tf, pool2, new int[]{3, 3, 64, 128}, 128);
+ Relu relu4 = vggConv2DLayer("4", tf, relu3, new int[]{3, 3, 128, 128}, 128);
+
+ MaxPool pool3 = vggMaxPool(tf, relu4);
+
+ Relu relu5 = vggConv2DLayer("5", tf, pool3, new int[]{3, 3, 128, 256}, 256);
+ Relu relu6 = vggConv2DLayer("6", tf, relu5, new int[]{3, 3, 256, 256}, 256);
+
+ MaxPool pool4 = vggMaxPool(tf, relu6);
+
+ Relu relu7 = vggConv2DLayer("7", tf, pool4, new int[]{3, 3, 256, 256}, 256);
+ Relu relu8 = vggConv2DLayer("8", tf, relu7, new int[]{3, 3, 256, 256}, 256);
+
+ MaxPool pool5 = vggMaxPool(tf, relu8);
+
+ Reshape flatten = vggFlatten(tf, pool5);
+
+ Add loss = buildFCLayersAndRegularization(tf, labels, flatten);
+
+ Optimizer optimizer = new Adam(graph, 0.001f, 0.9f, 0.999f, 1e-8f);
+
+ optimizer.minimize(loss, TRAIN);
+
+ tf.init();
+
+ return graph;
+ }
+
+ public static Add buildFCLayersAndRegularization(Ops tf, Placeholder labels, Reshape flatten) {
+ int fcBiasShape = 100;
+ int[] fcWeightShape = {256, fcBiasShape};
+
+ Variable fc1Weights = tf.variable(tf.math.mul(tf.random
+ .truncatedNormal(tf.array(fcWeightShape), TFloat32.DTYPE,
+ TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
+ Variable fc1Biases = tf
+ .variable(tf.fill(tf.array(new int[]{fcBiasShape}), tf.constant(0.1f)));
+ Relu fcRelu = tf.nn
+ .relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases));
+
+ // Softmax layer
+ Variable fc2Weights = tf.variable(tf.math.mul(tf.random
+ .truncatedNormal(tf.array(fcBiasShape, NUM_LABELS), TFloat32.DTYPE,
+ TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
+ Variable fc2Biases = tf
+ .variable(tf.fill(tf.array(new int[]{NUM_LABELS}), tf.constant(0.1f)));
+
+ Add logits = tf.math.add(tf.linalg.matMul(fcRelu, fc2Weights), fc2Biases);
+
+ // Predicted outputs
+ tf.withName(OUTPUT_NAME).nn.softmax(logits);
+
+ // Loss function & regularization
+ OneHot oneHot = tf
+ .oneHot(labels, tf.constant(10), tf.constant(1.0f), tf.constant(0.0f));
+ SoftmaxCrossEntropyWithLogits batchLoss = tf.nn
+ .softmaxCrossEntropyWithLogits(logits, oneHot);
+ Mean labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0));
+ Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math
+ .add(tf.nn.l2Loss(fc1Biases),
+ tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases))));
+ return tf.withName(TRAINING_LOSS).math
+ .add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f)));
+ }
+
+ public static Reshape vggFlatten(Ops tf, MaxPool pool2) {
+ return tf.reshape(pool2, tf.concat(Arrays
+ .asList(tf.slice(tf.shape(pool2), tf.array(new int[]{0}), tf.array(new int[]{1})),
+ tf.array(new int[]{-1})), tf.constant(0)));
+ }
+
+ public static MaxPool vggMaxPool(Ops tf, Relu relu1) {
+ return tf.nn
+ .maxPool(relu1, tf.array(1, 2, 2, 1), tf.array(1, 2, 2, 1),
+ PADDING_TYPE);
+ }
+
+ public static Relu vggConv2DLayer(String layerName, Ops tf, Operand scaledInput, int[] convWeightsL1Shape, int convBiasL1Shape) {
+ Variable conv1Weights = tf.withName("conv2d_" + layerName).variable(tf.math.mul(tf.random
+ .truncatedNormal(tf.array(convWeightsL1Shape), TFloat32.DTYPE,
+ TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
+ Conv2d conv = tf.nn
+ .conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE);
+ Variable convBias = tf
+ .withName("bias2d_" + layerName).variable(tf.fill(tf.array(new int[]{convBiasL1Shape}), tf.constant(0.0f)));
+ return tf.nn.relu(tf.withName("biasAdd_" + layerName).nn.biasAdd(conv, convBias));
+ }
+
+ public void train(MnistDataset dataset, int epochs, int minibatchSize) {
+ // Initialises the parameters.
+ session.runner().addTarget(INIT).run();
+ logger.info("Initialised the model parameters");
+
+ int interval = 0;
+ // Train the model
+ for (int i = 0; i < epochs; i++) {
+ for (ImageBatch trainingBatch : dataset.trainingBatches(minibatchSize)) {
+ try (Tensor batchImages = TUint8.tensorOf(trainingBatch.images());
+ Tensor batchLabels = TUint8.tensorOf(trainingBatch.labels());
+ Tensor loss = session.runner()
+ .feed(TARGET, batchLabels)
+ .feed(INPUT_NAME, batchImages)
+ .addTarget(TRAIN)
+ .fetch(TRAINING_LOSS)
+ .run().get(0).expect(TFloat32.DTYPE)) {
+
+ logger.log(Level.INFO,
+ "Iteration = " + interval + ", training loss = " + loss.data().getFloat());
+
+ }
+ interval++;
+ }
+ }
+ }
+
+ public void test(MnistDataset dataset, int minibatchSize) {
+ int correctCount = 0;
+ int[][] confusionMatrix = new int[10][10];
+
+ for (ImageBatch trainingBatch : dataset.testBatches(minibatchSize)) {
+ try (Tensor transformedInput = TUint8.tensorOf(trainingBatch.images());
+ Tensor outputTensor = session.runner()
+ .feed(INPUT_NAME, transformedInput)
+ .fetch(OUTPUT_NAME).run().get(0).expect(TFloat32.DTYPE)) {
+
+ ByteNdArray labelBatch = trainingBatch.labels();
+ for (int k = 0; k < labelBatch.shape().size(0); k++) {
+ byte trueLabel = labelBatch.getByte(k);
+ int predLabel;
+
+ predLabel = argmax(outputTensor.data().slice(Indices.at(k), Indices.all()));
+ if (predLabel == trueLabel) {
+ correctCount++;
+ }
+
+ confusionMatrix[trueLabel][predLabel]++;
+ }
+ }
+ }
+
+ logger.info("Final accuracy = " + ((float) correctCount) / dataset.numTestingExamples());
+
+ StringBuilder sb = new StringBuilder();
+ sb.append("Label");
+ for (int i = 0; i < confusionMatrix.length; i++) {
+ sb.append(String.format("%1$5s", "" + i));
+ }
+ sb.append("\n");
+
+ for (int i = 0; i < confusionMatrix.length; i++) {
+ sb.append(String.format("%1$5s", "" + i));
+ for (int j = 0; j < confusionMatrix[i].length; j++) {
+ sb.append(String.format("%1$5s", "" + confusionMatrix[i][j]));
+ }
+ sb.append("\n");
+ }
+
+ System.out.println(sb.toString());
+ }
+
+ /**
+ * Find the maximum probability and return it's index.
+ *
+ * @param probabilities The probabilites.
+ * @return The index of the max.
+ */
+ public static int argmax(FloatNdArray probabilities) {
+ float maxVal = Float.NEGATIVE_INFINITY;
+ int idx = 0;
+ for (int i = 0; i < probabilities.shape().size(0); i++) {
+ float curVal = probabilities.getFloat(i);
+ if (curVal > maxVal) {
+ maxVal = curVal;
+ idx = i;
+ }
+ }
+ return idx;
+ }
+
+ @Override
+ public void close() {
+ session.close();
+ graph.close();
+ }
+}
diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/ImageBatch.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/datasets/ImageBatch.java
similarity index 87%
rename from tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/ImageBatch.java
rename to tensorflow-examples/src/main/java/org/tensorflow/model/examples/datasets/ImageBatch.java
index 5b23cf2..61100cb 100644
--- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/ImageBatch.java
+++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/datasets/ImageBatch.java
@@ -14,10 +14,11 @@
* limitations under the License.
* =======================================================================
*/
-package org.tensorflow.model.examples.mnist.data;
+package org.tensorflow.model.examples.datasets;
import org.tensorflow.tools.ndarray.ByteNdArray;
+/** Batch of images for batch training. */
public class ImageBatch {
public ByteNdArray images() {
@@ -28,7 +29,7 @@ public ByteNdArray labels() {
return labels;
}
- ImageBatch(ByteNdArray images, ByteNdArray labels) {
+ public ImageBatch(ByteNdArray images, ByteNdArray labels) {
this.images = images;
this.labels = labels;
}
diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/ImageBatchIterator.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/datasets/ImageBatchIterator.java
similarity index 86%
rename from tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/ImageBatchIterator.java
rename to tensorflow-examples/src/main/java/org/tensorflow/model/examples/datasets/ImageBatchIterator.java
index 46165c6..f9f6739 100644
--- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/ImageBatchIterator.java
+++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/datasets/ImageBatchIterator.java
@@ -14,7 +14,7 @@
* limitations under the License.
* =======================================================================
*/
-package org.tensorflow.model.examples.mnist.data;
+package org.tensorflow.model.examples.datasets;
import static org.tensorflow.tools.ndarray.index.Indices.range;
@@ -22,7 +22,8 @@
import org.tensorflow.tools.ndarray.ByteNdArray;
import org.tensorflow.tools.ndarray.index.Index;
-class ImageBatchIterator implements Iterator {
+/** Basic batch iterator across images presented in datset. */
+public class ImageBatchIterator implements Iterator {
@Override
public boolean hasNext() {
@@ -37,7 +38,7 @@ public ImageBatch next() {
return new ImageBatch(images.slice(range), labels.slice(range));
}
- ImageBatchIterator(int batchSize, ByteNdArray images, ByteNdArray labels) {
+ public ImageBatchIterator(int batchSize, ByteNdArray images, ByteNdArray labels) {
this.batchSize = batchSize;
this.images = images;
this.labels = labels;
diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/MnistDataset.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/datasets/mnist/MnistDataset.java
similarity index 86%
rename from tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/MnistDataset.java
rename to tensorflow-examples/src/main/java/org/tensorflow/model/examples/datasets/mnist/MnistDataset.java
index 7f8df7b..caec509 100644
--- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/MnistDataset.java
+++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/datasets/mnist/MnistDataset.java
@@ -14,29 +14,33 @@
* limitations under the License.
* =======================================================================
*/
-package org.tensorflow.model.examples.mnist.data;
+package org.tensorflow.model.examples.datasets.mnist;
-import static org.tensorflow.tools.ndarray.index.Indices.from;
-import static org.tensorflow.tools.ndarray.index.Indices.to;
-
-import java.io.DataInputStream;
-import java.io.IOException;
-import java.util.zip.GZIPInputStream;
+import org.tensorflow.model.examples.datasets.ImageBatch;
+import org.tensorflow.model.examples.datasets.ImageBatchIterator;
import org.tensorflow.tools.Shape;
import org.tensorflow.tools.buffer.DataBuffers;
import org.tensorflow.tools.ndarray.ByteNdArray;
import org.tensorflow.tools.ndarray.NdArrays;
-public class MnistDataset {
+import java.io.DataInputStream;
+import java.io.IOException;
+import java.util.zip.GZIPInputStream;
+import static org.tensorflow.tools.ndarray.index.Indices.from;
+import static org.tensorflow.tools.ndarray.index.Indices.to;
+
+/** Common loader and data preprocessor for MNIST and FashionMNIST datasets. */
+public class MnistDataset {
public static final int NUM_CLASSES = 10;
- public static MnistDataset create(int validationSize) {
+ public static MnistDataset create(int validationSize, String trainingImagesArchive, String trainingLabelsArchive,
+ String testImagesArchive, String testLabelsArchive) {
try {
- ByteNdArray trainingImages = readArchive(TRAINING_IMAGES_ARCHIVE);
- ByteNdArray trainingLabels = readArchive(TRAINING_LABELS_ARCHIVE);
- ByteNdArray testImages = readArchive(TEST_IMAGES_ARCHIVE);
- ByteNdArray testLabels = readArchive(TEST_LABELS_ARCHIVE);
+ ByteNdArray trainingImages = readArchive(trainingImagesArchive);
+ ByteNdArray trainingLabels = readArchive(trainingLabelsArchive);
+ ByteNdArray testImages = readArchive(testImagesArchive);
+ ByteNdArray testLabels = readArchive(testLabelsArchive);
if (validationSize > 0) {
return new MnistDataset(
@@ -87,10 +91,6 @@ public long numValidationExamples() {
return validationLabels.shape().size(0);
}
- private static final String TRAINING_IMAGES_ARCHIVE = "train-images-idx3-ubyte.gz";
- private static final String TRAINING_LABELS_ARCHIVE = "train-labels-idx1-ubyte.gz";
- private static final String TEST_IMAGES_ARCHIVE = "t10k-images-idx3-ubyte.gz";
- private static final String TEST_LABELS_ARCHIVE = "t10k-labels-idx1-ubyte.gz";
private static final int TYPE_UBYTE = 0x08;
private final ByteNdArray trainingImages;
diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/SimpleMnist.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/dense/SimpleMnist.java
similarity index 89%
rename from tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/SimpleMnist.java
rename to tensorflow-examples/src/main/java/org/tensorflow/model/examples/dense/SimpleMnist.java
index 44a665d..fb7e63a 100644
--- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/SimpleMnist.java
+++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/dense/SimpleMnist.java
@@ -14,7 +14,7 @@
* limitations under the License.
* =======================================================================
*/
-package org.tensorflow.model.examples.mnist;
+package org.tensorflow.model.examples.dense;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
@@ -22,8 +22,8 @@
import org.tensorflow.Tensor;
import org.tensorflow.framework.optimizers.GradientDescent;
import org.tensorflow.framework.optimizers.Optimizer;
-import org.tensorflow.model.examples.mnist.data.ImageBatch;
-import org.tensorflow.model.examples.mnist.data.MnistDataset;
+import org.tensorflow.model.examples.datasets.ImageBatch;
+import org.tensorflow.model.examples.datasets.mnist.MnistDataset;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
@@ -36,9 +36,15 @@
import org.tensorflow.types.TInt64;
public class SimpleMnist implements Runnable {
+ private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz";
+ private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz";
+ private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz";
+ private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz";
public static void main(String[] args) {
- MnistDataset dataset = MnistDataset.create(VALIDATION_SIZE);
+ MnistDataset dataset = MnistDataset.create(VALIDATION_SIZE, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE,
+ TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);
+
try (Graph graph = new Graph()) {
SimpleMnist mnist = new SimpleMnist(graph, dataset);
mnist.run();
diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java
index b6f873a..a887b50 100644
--- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java
+++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java
@@ -23,8 +23,12 @@
import org.tensorflow.framework.optimizers.Optimizer;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
-import org.tensorflow.op.core.*;
-import org.tensorflow.op.math.*;
+import org.tensorflow.op.core.Placeholder;
+import org.tensorflow.op.core.Variable;
+import org.tensorflow.op.math.Add;
+import org.tensorflow.op.math.Div;
+import org.tensorflow.op.math.Mul;
+import org.tensorflow.op.math.Pow;
import org.tensorflow.tools.Shape;
import org.tensorflow.types.TFloat32;
@@ -46,7 +50,6 @@ public class LinearRegressionExample {
/**
* This value is used to fill the Y placeholder in prediction.
*/
- private static final float NO_MEANING_VALUE_TO_PUT_IN_PLACEHOLDER = 2000f;
public static final float LEARNING_RATE = 0.1f;
public static final String WEIGHT_VARIABLE_NAME = "weight";
public static final String BIAS_VARIABLE_NAME = "bias";
diff --git a/tensorflow-examples/src/main/resources/fashionmnist/Readme.md b/tensorflow-examples/src/main/resources/fashionmnist/Readme.md
new file mode 100644
index 0000000..95b6f38
--- /dev/null
+++ b/tensorflow-examples/src/main/resources/fashionmnist/Readme.md
@@ -0,0 +1,6 @@
+This dataset is distributed under MIT License and presented in next paper.
+Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms. Han Xiao, Kashif Rasul, Roland Vollgraf. arXiv:1708.07747
+
+The data was downloaded from the FashionMnist Repository
+https://github.com/zalandoresearch/fashion-mnist/tree/master/data/fashion
+
diff --git a/tensorflow-examples/src/main/resources/fashionmnist/t10k-images-idx3-ubyte.gz b/tensorflow-examples/src/main/resources/fashionmnist/t10k-images-idx3-ubyte.gz
new file mode 100644
index 0000000..667844f
Binary files /dev/null and b/tensorflow-examples/src/main/resources/fashionmnist/t10k-images-idx3-ubyte.gz differ
diff --git a/tensorflow-examples/src/main/resources/fashionmnist/t10k-labels-idx1-ubyte.gz b/tensorflow-examples/src/main/resources/fashionmnist/t10k-labels-idx1-ubyte.gz
new file mode 100644
index 0000000..abdddb8
Binary files /dev/null and b/tensorflow-examples/src/main/resources/fashionmnist/t10k-labels-idx1-ubyte.gz differ
diff --git a/tensorflow-examples/src/main/resources/fashionmnist/train-images-idx3-ubyte.gz b/tensorflow-examples/src/main/resources/fashionmnist/train-images-idx3-ubyte.gz
new file mode 100644
index 0000000..e6ee0e3
Binary files /dev/null and b/tensorflow-examples/src/main/resources/fashionmnist/train-images-idx3-ubyte.gz differ
diff --git a/tensorflow-examples/src/main/resources/fashionmnist/train-labels-idx1-ubyte.gz b/tensorflow-examples/src/main/resources/fashionmnist/train-labels-idx1-ubyte.gz
new file mode 100644
index 0000000..9c4aae2
Binary files /dev/null and b/tensorflow-examples/src/main/resources/fashionmnist/train-labels-idx1-ubyte.gz differ
diff --git a/tensorflow-examples/src/main/resources/t10k-images-idx3-ubyte.gz b/tensorflow-examples/src/main/resources/mnist/t10k-images-idx3-ubyte.gz
similarity index 100%
rename from tensorflow-examples/src/main/resources/t10k-images-idx3-ubyte.gz
rename to tensorflow-examples/src/main/resources/mnist/t10k-images-idx3-ubyte.gz
diff --git a/tensorflow-examples/src/main/resources/t10k-labels-idx1-ubyte.gz b/tensorflow-examples/src/main/resources/mnist/t10k-labels-idx1-ubyte.gz
similarity index 100%
rename from tensorflow-examples/src/main/resources/t10k-labels-idx1-ubyte.gz
rename to tensorflow-examples/src/main/resources/mnist/t10k-labels-idx1-ubyte.gz
diff --git a/tensorflow-examples/src/main/resources/train-images-idx3-ubyte.gz b/tensorflow-examples/src/main/resources/mnist/train-images-idx3-ubyte.gz
similarity index 100%
rename from tensorflow-examples/src/main/resources/train-images-idx3-ubyte.gz
rename to tensorflow-examples/src/main/resources/mnist/train-images-idx3-ubyte.gz
diff --git a/tensorflow-examples/src/main/resources/train-labels-idx1-ubyte.gz b/tensorflow-examples/src/main/resources/mnist/train-labels-idx1-ubyte.gz
similarity index 100%
rename from tensorflow-examples/src/main/resources/train-labels-idx1-ubyte.gz
rename to tensorflow-examples/src/main/resources/mnist/train-labels-idx1-ubyte.gz