Skip to content

Commit 1177e2c

Browse files
authored
VGG'11 model on FashionMNIST dataset (#16)
* Added VGG example on Fashion Mnist Dataset * Added licenses * Fixed notes and renamed the packages * Fixed the package name * Fixed the licenses and added citation for FashionMNIST
1 parent 5704d97 commit 1177e2c

18 files changed

+410
-44
lines changed

tensorflow-examples/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
<archive>
5353
<manifest>
5454
<mainClass>
55-
org.tensorflow.model.examples.mnist.SimpleMnist
55+
org.tensorflow.model.examples.dense.SimpleMnist
5656
</mainClass>
5757
</manifest>
5858
</archive>

tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/CnnMnist.java renamed to tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
/*
2-
* Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
33
*
4-
* Licensed under the Apache License, Version 2.0 (the "License");
5-
* you may not use this file except in compliance with the License.
6-
* You may obtain a copy of the License at
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
77
*
8-
* http://www.apache.org/licenses/LICENSE-2.0
8+
* http://www.apache.org/licenses/LICENSE-2.0
99
*
10-
* Unless required by applicable law or agreed to in writing, software
11-
* distributed under the License is distributed on an "AS IS" BASIS,
12-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13-
* See the License for the specific language governing permissions and
14-
* limitations under the License.
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =======================================================================
1516
*/
16-
package org.tensorflow.model.examples.mnist;
17+
package org.tensorflow.model.examples.cnn.lenet;
1718

1819
import java.util.Arrays;
1920
import java.util.logging.Level;
@@ -22,8 +23,8 @@
2223
import org.tensorflow.Operand;
2324
import org.tensorflow.Session;
2425
import org.tensorflow.Tensor;
25-
import org.tensorflow.model.examples.mnist.data.ImageBatch;
26-
import org.tensorflow.model.examples.mnist.data.MnistDataset;
26+
import org.tensorflow.model.examples.datasets.ImageBatch;
27+
import org.tensorflow.model.examples.datasets.mnist.MnistDataset;
2728
import org.tensorflow.op.Op;
2829
import org.tensorflow.op.Ops;
2930
import org.tensorflow.op.core.Constant;
@@ -76,6 +77,11 @@ public class CnnMnist {
7677
public static final String TRAINING_LOSS = "training_loss";
7778
public static final String INIT = "init";
7879

80+
private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz";
81+
private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz";
82+
private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz";
83+
private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz";
84+
7985
public static Graph build(String optimizerName) {
8086
Graph graph = new Graph();
8187

@@ -294,7 +300,8 @@ public static void main(String[] args) {
294300
logger.info(
295301
"Usage: MNISTTest <num-epochs> <minibatch-size> <optimizer-name>");
296302

297-
MnistDataset dataset = MnistDataset.create(0);
303+
MnistDataset dataset = MnistDataset.create(0, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE,
304+
TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);
298305

299306
logger.info("Loaded data.");
300307

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =======================================================================
16+
*/
17+
package org.tensorflow.model.examples.cnn.vgg;
18+
19+
import org.tensorflow.model.examples.datasets.mnist.MnistDataset;
20+
21+
import java.util.logging.Logger;
22+
23+
/**
24+
* Trains and evaluates VGG'11 model on FashionMNIST dataset.
25+
*/
26+
public class VGG11OnFashionMNIST {
27+
// Hyper-parameters
28+
public static final int EPOCHS = 1;
29+
public static final int BATCH_SIZE = 500;
30+
31+
// Fashion MNIST dataset paths
32+
public static final String TRAINING_IMAGES_ARCHIVE = "fashionmnist/train-images-idx3-ubyte.gz";
33+
public static final String TRAINING_LABELS_ARCHIVE = "fashionmnist/train-labels-idx1-ubyte.gz";
34+
public static final String TEST_IMAGES_ARCHIVE = "fashionmnist/t10k-images-idx3-ubyte.gz";
35+
public static final String TEST_LABELS_ARCHIVE = "fashionmnist/t10k-labels-idx1-ubyte.gz";
36+
37+
private static final Logger logger = Logger.getLogger(VGG11OnFashionMNIST.class.getName());
38+
39+
public static void main(String[] args) {
40+
logger.info("Data loading.");
41+
MnistDataset dataset = MnistDataset.create(0, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE, TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);
42+
43+
try (VGGModel vggModel = new VGGModel()) {
44+
logger.info("Model training.");
45+
vggModel.train(dataset, EPOCHS, BATCH_SIZE);
46+
47+
logger.info("Model evaluation.");
48+
vggModel.test(dataset, BATCH_SIZE);
49+
}
50+
}
51+
}

0 commit comments

Comments
 (0)