diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/alexnet/AlexNetModel.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/alexnet/AlexNetModel.java
new file mode 100644
index 0000000..b5fe109
--- /dev/null
+++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/alexnet/AlexNetModel.java
@@ -0,0 +1,294 @@
+/*
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =======================================================================
+ */
+package org.tensorflow.model.examples.cnn.alexnet;
+
+import org.tensorflow.Graph;
+import org.tensorflow.Operand;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+import org.tensorflow.framework.optimizers.Adam;
+import org.tensorflow.framework.optimizers.Optimizer;
+import org.tensorflow.model.examples.datasets.ImageBatch;
+import org.tensorflow.model.examples.datasets.mnist.MnistDataset;
+import org.tensorflow.ndarray.ByteNdArray;
+import org.tensorflow.ndarray.FloatNdArray;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.ndarray.index.Indices;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.*;
+import org.tensorflow.op.math.Add;
+import org.tensorflow.op.math.Mean;
+import org.tensorflow.op.nn.Conv2d;
+import org.tensorflow.op.nn.LocalResponseNormalization;
+import org.tensorflow.op.nn.MaxPool;
+import org.tensorflow.op.nn.Relu;
+import org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits;
+import org.tensorflow.op.random.TruncatedNormal;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TUint8;
+
+import java.util.Arrays;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * Describes the AlexNet model.
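+ *
+ * <p>As built below, the network is a scaled-down AlexNet for 28x28 single-channel
+ * images: five convolutional layers with ReLU activations, interleaved with 2x2
+ * max-pooling and local response normalisation, followed by a fully connected
+ * layer of 500 units and a softmax output over the 26 EMNIST letter classes.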
+ */
+public class AlexNetModel implements AutoCloseable {
+  private static final int PIXEL_DEPTH = 255;
+  private static final int NUM_CHANNELS = 1;
+  private static final int IMAGE_SIZE = 28;
+  private static final int NUM_LABELS = 26;
+  private static final long SEED = 123456789L;
+
+  private static final String PADDING_TYPE = "SAME";
+  public static final String INPUT_NAME = "input";
+  public static final String OUTPUT_NAME = "output";
+  public static final String TARGET = "target";
+  public static final String TRAIN = "train";
+  public static final String TRAINING_LOSS = "training_loss";
+  public static final String INIT = "init";
+
+  private static final Logger logger = Logger.getLogger(AlexNetModel.class.getName());
+
+  private final Graph graph;
+
+  private final Session session;
+
+  public AlexNetModel() {
+    graph = compile();
+    session = new Session(graph);
+  }
+
+  public static Graph compile() {
+    Graph graph = new Graph();
+
+    Ops tf = Ops.create(graph);
+
+    // Inputs
+    Placeholder input = tf.withName(INPUT_NAME).placeholder(TUint8.DTYPE,
+        Placeholder.shape(Shape.of(-1, IMAGE_SIZE, IMAGE_SIZE)));
+    Reshape inputReshaped = tf
+        .reshape(input, tf.array(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS));
+    Placeholder labels = tf.withName(TARGET).placeholder(TUint8.DTYPE);
+
+    // Scale the pixel values from [0, 255] to [-0.5, 0.5].
+    Constant centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f);
+    Constant scalingFactor = tf.constant((float) PIXEL_DEPTH);
+    Operand scaledInput = tf.math
+        .div(tf.math.sub(tf.dtypes.cast(inputReshaped, TFloat32.DTYPE), centeringFactor),
+            scalingFactor);
+
+    // Layer 1
+    Relu relu1 = alexNetConv2DLayer("1", tf, scaledInput, new int[]{11, 11, NUM_CHANNELS, 96}, 96);
+    MaxPool pool1 = alexNetMaxPool(tf, relu1);
+    LocalResponseNormalization norm1 = alexNetModelLRN(tf, pool1);
+
+    // Layer 2
+    Relu relu2 = alexNetConv2DLayer("2", tf, norm1, new int[]{5, 5, 96, 256}, 256);
+    MaxPool pool2 = alexNetMaxPool(tf, relu2);
+    LocalResponseNormalization norm2 = alexNetModelLRN(tf, pool2);
+
+    // Layer 3
+    Relu relu3 = alexNetConv2DLayer("3", tf, norm2, new int[]{3, 3, 256, 384}, 384);
+    LocalResponseNormalization norm3 = alexNetModelLRN(tf, relu3);
+
+    // Layer 4
+    Relu relu4 = alexNetConv2DLayer("4", tf, norm3, new int[]{3, 3, 384, 384}, 384);
+
+    // Layer 5
+    Relu relu5 = alexNetConv2DLayer("5", tf, relu4, new int[]{3, 3, 384, 256}, 256);
+    MaxPool pool5 = alexNetMaxPool(tf, relu5);
+    LocalResponseNormalization norm5 = alexNetModelLRN(tf, pool5);
+
+    Reshape flatten = alexNetFlatten(tf, norm5);
+
+    Add loss = buildFCLayersAndRegularization(tf, labels, flatten);
+
+    Optimizer optimizer = new Adam(graph, 0.001f, 0.9f, 0.999f, 1e-8f);
+
+    optimizer.minimize(loss, TRAIN);
+
+    tf.init();
+
+    return graph;
+  }
+
+  public static Add buildFCLayersAndRegularization(Ops tf, Placeholder labels, Reshape flatten) {
+    // After three 2x2 max-pools the 28x28 input is reduced to 4x4 (28 -> 14 -> 7 -> 4),
+    // and the last convolution produces 256 channels, so the flattened feature
+    // vector has 4 * 4 * 256 = 4096 elements.
+    int fcUnits = 500;
+    int[] fcWeightShape = {4096, fcUnits};
+
+    Variable fc1Weights = tf.variable(tf.math.mul(tf.random
+        .truncatedNormal(tf.array(fcWeightShape), TFloat32.DTYPE,
+            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
+    Variable fc1Biases = tf
+        .variable(tf.fill(tf.array(new int[]{fcUnits}), tf.constant(0.1f)));
+    Relu fcRelu = tf.nn
+        .relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases));
+
+    // Softmax layer
+    Variable fc2Weights = tf.variable(tf.math.mul(tf.random
+        .truncatedNormal(tf.array(fcUnits, NUM_LABELS), TFloat32.DTYPE,
+            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
+    Variable fc2Biases = tf
+        .variable(tf.fill(tf.array(new int[]{NUM_LABELS}), tf.constant(0.1f)));
+
+    Add logits = tf.math.add(tf.linalg.matMul(fcRelu, fc2Weights), fc2Biases);
+
+    // Predicted outputs, registered under OUTPUT_NAME so they can be fetched at inference time.
+    tf.withName(OUTPUT_NAME).nn.softmax(logits);
+
+    // Loss function & regularization
+    OneHot oneHot = tf
+        .oneHot(labels, tf.constant(NUM_LABELS), tf.constant(1.0f), tf.constant(0.0f));
+    SoftmaxCrossEntropyWithLogits batchLoss = tf.nn.raw
+        .softmaxCrossEntropyWithLogits(logits, oneHot);
+    Mean labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0));
+    Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math
+        .add(tf.nn.l2Loss(fc1Biases),
+            tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases))));
+    return tf.withName(TRAINING_LOSS).math
+        .add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f)));
+  }
+
+  public static Reshape alexNetFlatten(Ops tf, Operand features) {
+    // Keeps the batch dimension and collapses the remaining dimensions into one.
+    return tf.reshape(features, tf.concat(Arrays
+        .asList(tf.slice(tf.shape(features), tf.array(new int[]{0}), tf.array(new int[]{1})),
+            tf.array(new int[]{-1})), tf.constant(0)));
+  }
+
+  public static MaxPool alexNetMaxPool(Ops tf, Relu relu) {
+    return tf.nn
+        .maxPool(relu, tf.array(1, 2, 2, 1), tf.array(1, 2, 2, 1),
+            PADDING_TYPE);
+  }
+
+  private static LocalResponseNormalization alexNetModelLRN(Ops tf, MaxPool pool) {
+    return tf.nn.localResponseNormalization(pool);
+  }
+
+  private static LocalResponseNormalization alexNetModelLRN(Ops tf, Relu relu) {
+    return tf.nn.localResponseNormalization(relu);
+  }
+
+  public static Relu alexNetConv2DLayer(String layerName, Ops tf, Operand input, int[] convWeightsShape, int convBiasShape) {
+    Variable convWeights = tf.withName("conv2d_" + layerName).variable(tf.math.mul(tf.random
+        .truncatedNormal(tf.array(convWeightsShape), TFloat32.DTYPE,
+            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
+    Conv2d conv = tf.nn
+        .conv2d(input, convWeights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE);
+    Variable convBias = tf
+        .withName("bias2d_" + layerName).variable(tf.fill(tf.array(new int[]{convBiasShape}), tf.constant(0.0f)));
+    return tf.nn.relu(tf.withName("biasAdd_" + layerName).nn.biasAdd(conv, convBias));
+  }
+
+  public void train(MnistDataset dataset, int epochs, int minibatchSize) {
+    // Initialises the parameters.
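+    // Running the INIT target executes the variable initialisers collected by
+    // tf.init() during compile(); it performs no training step.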
+    session.runner().addTarget(INIT).run();
+    logger.info("Initialised the model parameters");
+
+    int interval = 0;
+    // Train the model
+    for (int i = 0; i < epochs; i++) {
+      for (ImageBatch trainingBatch : dataset.trainingBatches(minibatchSize)) {
+        try (Tensor batchImages = TUint8.tensorOf(trainingBatch.images());
+            Tensor batchLabels = TUint8.tensorOf(trainingBatch.labels());
+            Tensor loss = session.runner()
+                .feed(TARGET, batchLabels)
+                .feed(INPUT_NAME, batchImages)
+                .addTarget(TRAIN)
+                .fetch(TRAINING_LOSS)
+                .run().get(0).expect(TFloat32.DTYPE)) {
+
+          logger.log(Level.INFO,
+              "Iteration = " + interval + ", training loss = " + loss.data().getFloat());
+        }
+        interval++;
+      }
+    }
+  }
+
+  public void test(MnistDataset dataset, int minibatchSize) {
+    int correctCount = 0;
+    // EMNIST letter labels run from 1 to 26, hence the NUM_LABELS + 1 sizing.
+    int[][] confusionMatrix = new int[NUM_LABELS + 1][NUM_LABELS + 1];
+
+    for (ImageBatch testBatch : dataset.testBatches(minibatchSize)) {
+      try (Tensor transformedInput = TUint8.tensorOf(testBatch.images());
+          Tensor outputTensor = session.runner()
+              .feed(INPUT_NAME, transformedInput)
+              .fetch(OUTPUT_NAME).run().get(0).expect(TFloat32.DTYPE)) {
+
+        ByteNdArray labelBatch = testBatch.labels();
+        for (int k = 0; k < labelBatch.shape().size(0); k++) {
+          byte trueLabel = labelBatch.getByte(k);
+          int predLabel = argmax(outputTensor.data().slice(Indices.at(k), Indices.all()));
+          if (predLabel == trueLabel) {
+            correctCount++;
+          }
+
+          confusionMatrix[trueLabel][predLabel]++;
+        }
+      }
+    }
+
+    logger.info("Final accuracy = " + ((float) correctCount) / dataset.numTestingExamples());
+
+    StringBuilder sb = new StringBuilder();
+    sb.append("Label");
+    for (int i = 0; i < confusionMatrix.length; i++) {
+      sb.append(String.format("%1$5s", "" + i));
+    }
+    sb.append("\n");
+
+    for (int i = 0; i < confusionMatrix.length; i++) {
+      sb.append(String.format("%1$5s", "" + i));
+      for (int j = 0; j < confusionMatrix[i].length; j++) {
+        sb.append(String.format("%1$5s", "" + confusionMatrix[i][j]));
+      }
+      sb.append("\n");
+    }
+
+    System.out.println(sb.toString());
+  }
+
+  /**
+   * Finds the maximum probability and returns its index.
+   *
+   * @param probabilities The probabilities.
+   * @return The index of the maximum.
+   */
+  public static int argmax(FloatNdArray probabilities) {
+    float maxVal = Float.NEGATIVE_INFINITY;
+    int idx = 0;
+    for (int i = 0; i < probabilities.shape().size(0); i++) {
+      float curVal = probabilities.getFloat(i);
+      if (curVal > maxVal) {
+        maxVal = curVal;
+        idx = i;
+      }
+    }
+    return idx;
+  }
+
+  @Override
+  public void close() {
+    session.close();
+    graph.close();
+  }
+}
diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/alexnet/AlexNetOnEMNIST.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/alexnet/AlexNetOnEMNIST.java
new file mode 100644
index 0000000..fdd04cd
--- /dev/null
+++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/alexnet/AlexNetOnEMNIST.java
@@ -0,0 +1,51 @@
+/*
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =======================================================================
+ */
+package org.tensorflow.model.examples.cnn.alexnet;
+
+import org.tensorflow.model.examples.datasets.mnist.MnistDataset;
+
+import java.util.logging.Logger;
+
+/**
+ * Trains and evaluates the AlexNet model on the EMNIST Letters dataset.
+ */
+public class AlexNetOnEMNIST {
+  // Hyper-parameters
+  public static final int EPOCHS = 1;
+  public static final int BATCH_SIZE = 500;
+
+  // EMNIST Letters dataset paths
+  public static final String TRAINING_IMAGES_ARCHIVE = "emnist/emnist-letters-train-images-idx3-ubyte.gz";
+  public static final String TRAINING_LABELS_ARCHIVE = "emnist/emnist-letters-train-labels-idx1-ubyte.gz";
+  public static final String TEST_IMAGES_ARCHIVE = "emnist/emnist-letters-test-images-idx3-ubyte.gz";
+  public static final String TEST_LABELS_ARCHIVE = "emnist/emnist-letters-test-labels-idx1-ubyte.gz";
+
+  private static final Logger logger = Logger.getLogger(AlexNetOnEMNIST.class.getName());
+
+  public static void main(String[] args) {
+    logger.info("Data loading.");
+    MnistDataset dataset = MnistDataset.create(0, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE, TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);
+
+    try (AlexNetModel alexNetModel = new AlexNetModel()) {
+      logger.info("Model training.");
+      alexNetModel.train(dataset, EPOCHS, BATCH_SIZE);
+
+      logger.info("Model evaluation.");
+      alexNetModel.test(dataset, BATCH_SIZE);
+    }
+  }
+}
diff --git a/tensorflow-examples/src/main/resources/emnist/emnist-letters-test-images-idx3-ubyte.gz b/tensorflow-examples/src/main/resources/emnist/emnist-letters-test-images-idx3-ubyte.gz
new file mode 100644
index 0000000..221cdbc
Binary files /dev/null and b/tensorflow-examples/src/main/resources/emnist/emnist-letters-test-images-idx3-ubyte.gz differ
diff --git a/tensorflow-examples/src/main/resources/emnist/emnist-letters-test-labels-idx1-ubyte.gz b/tensorflow-examples/src/main/resources/emnist/emnist-letters-test-labels-idx1-ubyte.gz
new file mode 100644
index 0000000..f049ca0
Binary files /dev/null and b/tensorflow-examples/src/main/resources/emnist/emnist-letters-test-labels-idx1-ubyte.gz differ
diff --git a/tensorflow-examples/src/main/resources/emnist/emnist-letters-train-images-idx3-ubyte.gz b/tensorflow-examples/src/main/resources/emnist/emnist-letters-train-images-idx3-ubyte.gz
new file mode 100644
index 0000000..0bd8e9c
Binary files /dev/null and b/tensorflow-examples/src/main/resources/emnist/emnist-letters-train-images-idx3-ubyte.gz differ
diff --git a/tensorflow-examples/src/main/resources/emnist/emnist-letters-train-labels-idx1-ubyte.gz b/tensorflow-examples/src/main/resources/emnist/emnist-letters-train-labels-idx1-ubyte.gz
new file mode 100644
index 0000000..35714d2
Binary files /dev/null and b/tensorflow-examples/src/main/resources/emnist/emnist-letters-train-labels-idx1-ubyte.gz differ