diff --git a/tensorflow-examples/pom.xml b/tensorflow-examples/pom.xml index f03d7a9..276a10f 100644 --- a/tensorflow-examples/pom.xml +++ b/tensorflow-examples/pom.xml @@ -52,7 +52,7 @@ - org.tensorflow.model.examples.mnist.SimpleMnist + org.tensorflow.model.examples.dense.SimpleMnist diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/CnnMnist.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java similarity index 88% rename from tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/CnnMnist.java rename to tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java index 0f4f9b5..a590d30 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/CnnMnist.java +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java @@ -1,19 +1,20 @@ /* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= */ -package org.tensorflow.model.examples.mnist; +package org.tensorflow.model.examples.cnn.lenet; import java.util.Arrays; import java.util.logging.Level; @@ -22,8 +23,8 @@ import org.tensorflow.Operand; import org.tensorflow.Session; import org.tensorflow.Tensor; -import org.tensorflow.model.examples.mnist.data.ImageBatch; -import org.tensorflow.model.examples.mnist.data.MnistDataset; +import org.tensorflow.model.examples.datasets.ImageBatch; +import org.tensorflow.model.examples.datasets.mnist.MnistDataset; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Constant; @@ -76,6 +77,11 @@ public class CnnMnist { public static final String TRAINING_LOSS = "training_loss"; public static final String INIT = "init"; + private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz"; + private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz"; + private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz"; + private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz"; + public static Graph build(String optimizerName) { Graph graph = new Graph(); @@ -294,7 +300,8 @@ public static void main(String[] args) { logger.info( "Usage: MNISTTest "); - MnistDataset dataset = MnistDataset.create(0); + MnistDataset dataset = MnistDataset.create(0, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE, + TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE); logger.info("Loaded data."); diff --git 
a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGG11OnFashionMNIST.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGG11OnFashionMNIST.java new file mode 100644 index 0000000..b8c5c26 --- /dev/null +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGG11OnFashionMNIST.java @@ -0,0 +1,51 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ +package org.tensorflow.model.examples.cnn.vgg; + +import org.tensorflow.model.examples.datasets.mnist.MnistDataset; + +import java.util.logging.Logger; + +/** + * Trains and evaluates VGG'11 model on FashionMNIST dataset. 
+ */ +public class VGG11OnFashionMNIST { + // Hyper-parameters + public static final int EPOCHS = 1; + public static final int BATCH_SIZE = 500; + + // Fashion MNIST dataset paths + public static final String TRAINING_IMAGES_ARCHIVE = "fashionmnist/train-images-idx3-ubyte.gz"; + public static final String TRAINING_LABELS_ARCHIVE = "fashionmnist/train-labels-idx1-ubyte.gz"; + public static final String TEST_IMAGES_ARCHIVE = "fashionmnist/t10k-images-idx3-ubyte.gz"; + public static final String TEST_LABELS_ARCHIVE = "fashionmnist/t10k-labels-idx1-ubyte.gz"; + + private static final Logger logger = Logger.getLogger(VGG11OnFashionMNIST.class.getName()); + + public static void main(String[] args) { + logger.info("Data loading."); + MnistDataset dataset = MnistDataset.create(0, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE, TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE); + + try (VGGModel vggModel = new VGGModel()) { + logger.info("Model training."); + vggModel.train(dataset, EPOCHS, BATCH_SIZE); + + logger.info("Model evaluation."); + vggModel.test(dataset, BATCH_SIZE); + } + } +} diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java new file mode 100644 index 0000000..3519517 --- /dev/null +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java @@ -0,0 +1,291 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ +package org.tensorflow.model.examples.cnn.vgg; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; +import org.tensorflow.Tensor; +import org.tensorflow.framework.optimizers.Adam; +import org.tensorflow.framework.optimizers.Optimizer; +import org.tensorflow.model.examples.datasets.ImageBatch; +import org.tensorflow.model.examples.datasets.mnist.MnistDataset; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.OneHot; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.core.Reshape; +import org.tensorflow.op.core.Variable; +import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.Mean; +import org.tensorflow.op.nn.Conv2d; +import org.tensorflow.op.nn.MaxPool; +import org.tensorflow.op.nn.Relu; +import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits; +import org.tensorflow.op.random.TruncatedNormal; +import org.tensorflow.tools.Shape; +import org.tensorflow.tools.ndarray.ByteNdArray; +import org.tensorflow.tools.ndarray.FloatNdArray; +import org.tensorflow.tools.ndarray.index.Indices; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TUint8; + +import java.util.Arrays; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Describes the VGGModel. 
+ */ +public class VGGModel implements AutoCloseable { + private static final int PIXEL_DEPTH = 255; + private static final int NUM_CHANNELS = 1; + private static final int IMAGE_SIZE = 28; + private static final int NUM_LABELS = MnistDataset.NUM_CLASSES; + private static final long SEED = 123456789L; + + private static final String PADDING_TYPE = "SAME"; + public static final String INPUT_NAME = "input"; + public static final String OUTPUT_NAME = "output"; + public static final String TARGET = "target"; + public static final String TRAIN = "train"; + public static final String TRAINING_LOSS = "training_loss"; + public static final String INIT = "init"; + + private static final Logger logger = Logger.getLogger(VGGModel.class.getName()); + + private final Graph graph; + + private final Session session; + + public VGGModel() { + graph = compile(); + session = new Session(graph); + } + + public static Graph compile() { + Graph graph = new Graph(); + + Ops tf = Ops.create(graph); + + // Inputs + Placeholder input = tf.withName(INPUT_NAME).placeholder(TUint8.DTYPE, + Placeholder.shape(Shape.of(-1, IMAGE_SIZE, IMAGE_SIZE))); + Reshape input_reshaped = tf + .reshape(input, tf.array(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)); + Placeholder labels = tf.withName(TARGET).placeholder(TUint8.DTYPE); + + // Scaling the features + Constant centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f); + Constant scalingFactor = tf.constant((float) PIXEL_DEPTH); + Operand scaledInput = tf.math + .div(tf.math.sub(tf.dtypes.cast(input_reshaped, TFloat32.DTYPE), centeringFactor), + scalingFactor); + + Relu relu1 = vggConv2DLayer("1", tf, scaledInput, new int[]{3, 3, NUM_CHANNELS, 32}, 32); + + MaxPool pool1 = vggMaxPool(tf, relu1); + + Relu relu2 = vggConv2DLayer("2", tf, pool1, new int[]{3, 3, 32, 64}, 64); + + MaxPool pool2 = vggMaxPool(tf, relu2); + + Relu relu3 = vggConv2DLayer("3", tf, pool2, new int[]{3, 3, 64, 128}, 128); + Relu relu4 = vggConv2DLayer("4", tf, relu3, new int[]{3, 3, 128, 
128}, 128); + + MaxPool pool3 = vggMaxPool(tf, relu4); + + Relu relu5 = vggConv2DLayer("5", tf, pool3, new int[]{3, 3, 128, 256}, 256); + Relu relu6 = vggConv2DLayer("6", tf, relu5, new int[]{3, 3, 256, 256}, 256); + + MaxPool pool4 = vggMaxPool(tf, relu6); + + Relu relu7 = vggConv2DLayer("7", tf, pool4, new int[]{3, 3, 256, 256}, 256); + Relu relu8 = vggConv2DLayer("8", tf, relu7, new int[]{3, 3, 256, 256}, 256); + + MaxPool pool5 = vggMaxPool(tf, relu8); + + Reshape flatten = vggFlatten(tf, pool5); + + Add loss = buildFCLayersAndRegularization(tf, labels, flatten); + + Optimizer optimizer = new Adam(graph, 0.001f, 0.9f, 0.999f, 1e-8f); + + optimizer.minimize(loss, TRAIN); + + tf.init(); + + return graph; + } + + public static Add buildFCLayersAndRegularization(Ops tf, Placeholder labels, Reshape flatten) { + int fcBiasShape = 100; + int[] fcWeightShape = {256, fcBiasShape}; + + Variable fc1Weights = tf.variable(tf.math.mul(tf.random + .truncatedNormal(tf.array(fcWeightShape), TFloat32.DTYPE, + TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Variable fc1Biases = tf + .variable(tf.fill(tf.array(new int[]{fcBiasShape}), tf.constant(0.1f))); + Relu fcRelu = tf.nn + .relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases)); + + // Softmax layer + Variable fc2Weights = tf.variable(tf.math.mul(tf.random + .truncatedNormal(tf.array(fcBiasShape, NUM_LABELS), TFloat32.DTYPE, + TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Variable fc2Biases = tf + .variable(tf.fill(tf.array(new int[]{NUM_LABELS}), tf.constant(0.1f))); + + Add logits = tf.math.add(tf.linalg.matMul(fcRelu, fc2Weights), fc2Biases); + + // Predicted outputs + tf.withName(OUTPUT_NAME).nn.softmax(logits); + + // Loss function & regularization + OneHot oneHot = tf + .oneHot(labels, tf.constant(10), tf.constant(1.0f), tf.constant(0.0f)); + SoftmaxCrossEntropyWithLogits batchLoss = tf.nn + .softmaxCrossEntropyWithLogits(logits, oneHot); + Mean labelLoss = tf.math.mean(batchLoss.loss(), 
tf.constant(0)); + Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math + .add(tf.nn.l2Loss(fc1Biases), + tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases)))); + return tf.withName(TRAINING_LOSS).math + .add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f))); + } + + public static Reshape vggFlatten(Ops tf, MaxPool pool2) { + return tf.reshape(pool2, tf.concat(Arrays + .asList(tf.slice(tf.shape(pool2), tf.array(new int[]{0}), tf.array(new int[]{1})), + tf.array(new int[]{-1})), tf.constant(0))); + } + + public static MaxPool vggMaxPool(Ops tf, Relu relu1) { + return tf.nn + .maxPool(relu1, tf.array(1, 2, 2, 1), tf.array(1, 2, 2, 1), + PADDING_TYPE); + } + + public static Relu vggConv2DLayer(String layerName, Ops tf, Operand scaledInput, int[] convWeightsL1Shape, int convBiasL1Shape) { + Variable conv1Weights = tf.withName("conv2d_" + layerName).variable(tf.math.mul(tf.random + .truncatedNormal(tf.array(convWeightsL1Shape), TFloat32.DTYPE, + TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Conv2d conv = tf.nn + .conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); + Variable convBias = tf + .withName("bias2d_" + layerName).variable(tf.fill(tf.array(new int[]{convBiasL1Shape}), tf.constant(0.0f))); + return tf.nn.relu(tf.withName("biasAdd_" + layerName).nn.biasAdd(conv, convBias)); + } + + public void train(MnistDataset dataset, int epochs, int minibatchSize) { + // Initialises the parameters. 
+ session.runner().addTarget(INIT).run(); + logger.info("Initialised the model parameters"); + + int interval = 0; + // Train the model + for (int i = 0; i < epochs; i++) { + for (ImageBatch trainingBatch : dataset.trainingBatches(minibatchSize)) { + try (Tensor batchImages = TUint8.tensorOf(trainingBatch.images()); + Tensor batchLabels = TUint8.tensorOf(trainingBatch.labels()); + Tensor loss = session.runner() + .feed(TARGET, batchLabels) + .feed(INPUT_NAME, batchImages) + .addTarget(TRAIN) + .fetch(TRAINING_LOSS) + .run().get(0).expect(TFloat32.DTYPE)) { + + logger.log(Level.INFO, + "Iteration = " + interval + ", training loss = " + loss.data().getFloat()); + + } + interval++; + } + } + } + + public void test(MnistDataset dataset, int minibatchSize) { + int correctCount = 0; + int[][] confusionMatrix = new int[10][10]; + + for (ImageBatch trainingBatch : dataset.testBatches(minibatchSize)) { + try (Tensor transformedInput = TUint8.tensorOf(trainingBatch.images()); + Tensor outputTensor = session.runner() + .feed(INPUT_NAME, transformedInput) + .fetch(OUTPUT_NAME).run().get(0).expect(TFloat32.DTYPE)) { + + ByteNdArray labelBatch = trainingBatch.labels(); + for (int k = 0; k < labelBatch.shape().size(0); k++) { + byte trueLabel = labelBatch.getByte(k); + int predLabel; + + predLabel = argmax(outputTensor.data().slice(Indices.at(k), Indices.all())); + if (predLabel == trueLabel) { + correctCount++; + } + + confusionMatrix[trueLabel][predLabel]++; + } + } + } + + logger.info("Final accuracy = " + ((float) correctCount) / dataset.numTestingExamples()); + + StringBuilder sb = new StringBuilder(); + sb.append("Label"); + for (int i = 0; i < confusionMatrix.length; i++) { + sb.append(String.format("%1$5s", "" + i)); + } + sb.append("\n"); + + for (int i = 0; i < confusionMatrix.length; i++) { + sb.append(String.format("%1$5s", "" + i)); + for (int j = 0; j < confusionMatrix[i].length; j++) { + sb.append(String.format("%1$5s", "" + confusionMatrix[i][j])); + } + 
sb.append("\n"); + } + + System.out.println(sb.toString()); + } + + /** + * Find the maximum probability and return it's index. + * + * @param probabilities The probabilites. + * @return The index of the max. + */ + public static int argmax(FloatNdArray probabilities) { + float maxVal = Float.NEGATIVE_INFINITY; + int idx = 0; + for (int i = 0; i < probabilities.shape().size(0); i++) { + float curVal = probabilities.getFloat(i); + if (curVal > maxVal) { + maxVal = curVal; + idx = i; + } + } + return idx; + } + + @Override + public void close() { + session.close(); + graph.close(); + } +} diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/ImageBatch.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/datasets/ImageBatch.java similarity index 87% rename from tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/ImageBatch.java rename to tensorflow-examples/src/main/java/org/tensorflow/model/examples/datasets/ImageBatch.java index 5b23cf2..61100cb 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/ImageBatch.java +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/datasets/ImageBatch.java @@ -14,10 +14,11 @@ * limitations under the License. * ======================================================================= */ -package org.tensorflow.model.examples.mnist.data; +package org.tensorflow.model.examples.datasets; import org.tensorflow.tools.ndarray.ByteNdArray; +/** Batch of images for batch training. 
*/ public class ImageBatch { public ByteNdArray images() { @@ -28,7 +29,7 @@ public ByteNdArray labels() { return labels; } - ImageBatch(ByteNdArray images, ByteNdArray labels) { + public ImageBatch(ByteNdArray images, ByteNdArray labels) { this.images = images; this.labels = labels; } diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/ImageBatchIterator.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/datasets/ImageBatchIterator.java similarity index 86% rename from tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/ImageBatchIterator.java rename to tensorflow-examples/src/main/java/org/tensorflow/model/examples/datasets/ImageBatchIterator.java index 46165c6..f9f6739 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/ImageBatchIterator.java +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/datasets/ImageBatchIterator.java @@ -14,7 +14,7 @@ * limitations under the License. * ======================================================================= */ -package org.tensorflow.model.examples.mnist.data; +package org.tensorflow.model.examples.datasets; import static org.tensorflow.tools.ndarray.index.Indices.range; @@ -22,7 +22,8 @@ import org.tensorflow.tools.ndarray.ByteNdArray; import org.tensorflow.tools.ndarray.index.Index; -class ImageBatchIterator implements Iterator { +/** Basic batch iterator across images presented in dataset. 
*/ +public class ImageBatchIterator implements Iterator { @Override public boolean hasNext() { @@ -37,7 +38,7 @@ public ImageBatch next() { return new ImageBatch(images.slice(range), labels.slice(range)); } - ImageBatchIterator(int batchSize, ByteNdArray images, ByteNdArray labels) { + public ImageBatchIterator(int batchSize, ByteNdArray images, ByteNdArray labels) { this.batchSize = batchSize; this.images = images; this.labels = labels; diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/MnistDataset.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/datasets/mnist/MnistDataset.java similarity index 86% rename from tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/MnistDataset.java rename to tensorflow-examples/src/main/java/org/tensorflow/model/examples/datasets/mnist/MnistDataset.java index 7f8df7b..caec509 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/MnistDataset.java +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/datasets/mnist/MnistDataset.java @@ -14,29 +14,33 @@ * limitations under the License. 
* ======================================================================= */ -package org.tensorflow.model.examples.mnist.data; +package org.tensorflow.model.examples.datasets.mnist; -import static org.tensorflow.tools.ndarray.index.Indices.from; -import static org.tensorflow.tools.ndarray.index.Indices.to; - -import java.io.DataInputStream; -import java.io.IOException; -import java.util.zip.GZIPInputStream; +import org.tensorflow.model.examples.datasets.ImageBatch; +import org.tensorflow.model.examples.datasets.ImageBatchIterator; import org.tensorflow.tools.Shape; import org.tensorflow.tools.buffer.DataBuffers; import org.tensorflow.tools.ndarray.ByteNdArray; import org.tensorflow.tools.ndarray.NdArrays; -public class MnistDataset { +import java.io.DataInputStream; +import java.io.IOException; +import java.util.zip.GZIPInputStream; +import static org.tensorflow.tools.ndarray.index.Indices.from; +import static org.tensorflow.tools.ndarray.index.Indices.to; + +/** Common loader and data preprocessor for MNIST and FashionMNIST datasets. 
*/ +public class MnistDataset { public static final int NUM_CLASSES = 10; - public static MnistDataset create(int validationSize) { + public static MnistDataset create(int validationSize, String trainingImagesArchive, String trainingLabelsArchive, + String testImagesArchive, String testLabelsArchive) { try { - ByteNdArray trainingImages = readArchive(TRAINING_IMAGES_ARCHIVE); - ByteNdArray trainingLabels = readArchive(TRAINING_LABELS_ARCHIVE); - ByteNdArray testImages = readArchive(TEST_IMAGES_ARCHIVE); - ByteNdArray testLabels = readArchive(TEST_LABELS_ARCHIVE); + ByteNdArray trainingImages = readArchive(trainingImagesArchive); + ByteNdArray trainingLabels = readArchive(trainingLabelsArchive); + ByteNdArray testImages = readArchive(testImagesArchive); + ByteNdArray testLabels = readArchive(testLabelsArchive); if (validationSize > 0) { return new MnistDataset( @@ -87,10 +91,6 @@ public long numValidationExamples() { return validationLabels.shape().size(0); } - private static final String TRAINING_IMAGES_ARCHIVE = "train-images-idx3-ubyte.gz"; - private static final String TRAINING_LABELS_ARCHIVE = "train-labels-idx1-ubyte.gz"; - private static final String TEST_IMAGES_ARCHIVE = "t10k-images-idx3-ubyte.gz"; - private static final String TEST_LABELS_ARCHIVE = "t10k-labels-idx1-ubyte.gz"; private static final int TYPE_UBYTE = 0x08; private final ByteNdArray trainingImages; diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/SimpleMnist.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/dense/SimpleMnist.java similarity index 89% rename from tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/SimpleMnist.java rename to tensorflow-examples/src/main/java/org/tensorflow/model/examples/dense/SimpleMnist.java index 44a665d..fb7e63a 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/SimpleMnist.java +++ 
b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/dense/SimpleMnist.java @@ -14,7 +14,7 @@ * limitations under the License. * ======================================================================= */ -package org.tensorflow.model.examples.mnist; +package org.tensorflow.model.examples.dense; import org.tensorflow.Graph; import org.tensorflow.Operand; @@ -22,8 +22,8 @@ import org.tensorflow.Tensor; import org.tensorflow.framework.optimizers.GradientDescent; import org.tensorflow.framework.optimizers.Optimizer; -import org.tensorflow.model.examples.mnist.data.ImageBatch; -import org.tensorflow.model.examples.mnist.data.MnistDataset; +import org.tensorflow.model.examples.datasets.ImageBatch; +import org.tensorflow.model.examples.datasets.mnist.MnistDataset; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Placeholder; @@ -36,9 +36,15 @@ import org.tensorflow.types.TInt64; public class SimpleMnist implements Runnable { + private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz"; + private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz"; + private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz"; + private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz"; public static void main(String[] args) { - MnistDataset dataset = MnistDataset.create(VALIDATION_SIZE); + MnistDataset dataset = MnistDataset.create(VALIDATION_SIZE, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE, + TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE); + try (Graph graph = new Graph()) { SimpleMnist mnist = new SimpleMnist(graph, dataset); mnist.run(); diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java index b6f873a..a887b50 100644 --- 
a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java @@ -23,8 +23,12 @@ import org.tensorflow.framework.optimizers.Optimizer; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.*; -import org.tensorflow.op.math.*; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.core.Variable; +import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.Div; +import org.tensorflow.op.math.Mul; +import org.tensorflow.op.math.Pow; import org.tensorflow.tools.Shape; import org.tensorflow.types.TFloat32; @@ -46,7 +50,6 @@ public class LinearRegressionExample { /** * This value is used to fill the Y placeholder in prediction. */ - private static final float NO_MEANING_VALUE_TO_PUT_IN_PLACEHOLDER = 2000f; public static final float LEARNING_RATE = 0.1f; public static final String WEIGHT_VARIABLE_NAME = "weight"; public static final String BIAS_VARIABLE_NAME = "bias"; diff --git a/tensorflow-examples/src/main/resources/fashionmnist/Readme.md b/tensorflow-examples/src/main/resources/fashionmnist/Readme.md new file mode 100644 index 0000000..95b6f38 --- /dev/null +++ b/tensorflow-examples/src/main/resources/fashionmnist/Readme.md @@ -0,0 +1,6 @@ +This dataset is distributed under the MIT License and is presented in the following paper. +Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms. Han Xiao, Kashif Rasul, Roland Vollgraf. 
arXiv:1708.07747 + +The data was downloaded from the FashionMnist Repository +https://github.com/zalandoresearch/fashion-mnist/tree/master/data/fashion + diff --git a/tensorflow-examples/src/main/resources/fashionmnist/t10k-images-idx3-ubyte.gz b/tensorflow-examples/src/main/resources/fashionmnist/t10k-images-idx3-ubyte.gz new file mode 100644 index 0000000..667844f Binary files /dev/null and b/tensorflow-examples/src/main/resources/fashionmnist/t10k-images-idx3-ubyte.gz differ diff --git a/tensorflow-examples/src/main/resources/fashionmnist/t10k-labels-idx1-ubyte.gz b/tensorflow-examples/src/main/resources/fashionmnist/t10k-labels-idx1-ubyte.gz new file mode 100644 index 0000000..abdddb8 Binary files /dev/null and b/tensorflow-examples/src/main/resources/fashionmnist/t10k-labels-idx1-ubyte.gz differ diff --git a/tensorflow-examples/src/main/resources/fashionmnist/train-images-idx3-ubyte.gz b/tensorflow-examples/src/main/resources/fashionmnist/train-images-idx3-ubyte.gz new file mode 100644 index 0000000..e6ee0e3 Binary files /dev/null and b/tensorflow-examples/src/main/resources/fashionmnist/train-images-idx3-ubyte.gz differ diff --git a/tensorflow-examples/src/main/resources/fashionmnist/train-labels-idx1-ubyte.gz b/tensorflow-examples/src/main/resources/fashionmnist/train-labels-idx1-ubyte.gz new file mode 100644 index 0000000..9c4aae2 Binary files /dev/null and b/tensorflow-examples/src/main/resources/fashionmnist/train-labels-idx1-ubyte.gz differ diff --git a/tensorflow-examples/src/main/resources/t10k-images-idx3-ubyte.gz b/tensorflow-examples/src/main/resources/mnist/t10k-images-idx3-ubyte.gz similarity index 100% rename from tensorflow-examples/src/main/resources/t10k-images-idx3-ubyte.gz rename to tensorflow-examples/src/main/resources/mnist/t10k-images-idx3-ubyte.gz diff --git a/tensorflow-examples/src/main/resources/t10k-labels-idx1-ubyte.gz b/tensorflow-examples/src/main/resources/mnist/t10k-labels-idx1-ubyte.gz similarity index 100% rename from 
tensorflow-examples/src/main/resources/t10k-labels-idx1-ubyte.gz rename to tensorflow-examples/src/main/resources/mnist/t10k-labels-idx1-ubyte.gz diff --git a/tensorflow-examples/src/main/resources/train-images-idx3-ubyte.gz b/tensorflow-examples/src/main/resources/mnist/train-images-idx3-ubyte.gz similarity index 100% rename from tensorflow-examples/src/main/resources/train-images-idx3-ubyte.gz rename to tensorflow-examples/src/main/resources/mnist/train-images-idx3-ubyte.gz diff --git a/tensorflow-examples/src/main/resources/train-labels-idx1-ubyte.gz b/tensorflow-examples/src/main/resources/mnist/train-labels-idx1-ubyte.gz similarity index 100% rename from tensorflow-examples/src/main/resources/train-labels-idx1-ubyte.gz rename to tensorflow-examples/src/main/resources/mnist/train-labels-idx1-ubyte.gz