|
1 | 1 | /*
|
2 |
| - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. |
| 2 | + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
3 | 3 | *
|
4 |
| - * Licensed under the Apache License, Version 2.0 (the "License"); |
5 |
| - * you may not use this file except in compliance with the License. |
6 |
| - * You may obtain a copy of the License at |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
7 | 7 | *
|
8 |
| - * http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
9 | 9 | *
|
10 |
| - * Unless required by applicable law or agreed to in writing, software |
11 |
| - * distributed under the License is distributed on an "AS IS" BASIS, |
12 |
| - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 |
| - * See the License for the specific language governing permissions and |
14 |
| - * limitations under the License. |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + * ======================================================================= |
15 | 16 | */
|
16 |
| -package org.tensorflow.model.examples.mnist; |
| 17 | +package org.tensorflow.model.examples.cnn.lenet; |
17 | 18 |
|
18 | 19 | import java.util.Arrays;
|
19 | 20 | import java.util.logging.Level;
|
|
22 | 23 | import org.tensorflow.Operand;
|
23 | 24 | import org.tensorflow.Session;
|
24 | 25 | import org.tensorflow.Tensor;
|
25 |
| -import org.tensorflow.model.examples.mnist.data.ImageBatch; |
26 |
| -import org.tensorflow.model.examples.mnist.data.MnistDataset; |
| 26 | +import org.tensorflow.model.examples.datasets.ImageBatch; |
| 27 | +import org.tensorflow.model.examples.datasets.mnist.MnistDataset; |
27 | 28 | import org.tensorflow.op.Op;
|
28 | 29 | import org.tensorflow.op.Ops;
|
29 | 30 | import org.tensorflow.op.core.Constant;
|
@@ -76,6 +77,11 @@ public class CnnMnist {
|
76 | 77 | public static final String TRAINING_LOSS = "training_loss";
|
77 | 78 | public static final String INIT = "init";
|
78 | 79 |
|
| 80 | + private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz"; |
| 81 | + private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz"; |
| 82 | + private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz"; |
| 83 | + private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz"; |
| 84 | + |
79 | 85 | public static Graph build(String optimizerName) {
|
80 | 86 | Graph graph = new Graph();
|
81 | 87 |
|
@@ -294,7 +300,8 @@ public static void main(String[] args) {
|
294 | 300 | logger.info(
|
295 | 301 | "Usage: MNISTTest <num-epochs> <minibatch-size> <optimizer-name>");
|
296 | 302 |
|
297 |
| - MnistDataset dataset = MnistDataset.create(0); |
| 303 | + MnistDataset dataset = MnistDataset.create(0, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE, |
| 304 | + TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE); |
298 | 305 |
|
299 | 306 | logger.info("Loaded data.");
|
300 | 307 |
|
|
0 commit comments