Skip to content

Commit 42e5f2a

Browse files
committed
Refactor mnist convnet for multi GPU usage
1 parent a566134 commit 42e5f2a

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

mnist_convnet/config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,7 @@ hooks:
1919
- ComputeStats:
2020
variables: [loss, accuracy]
2121
- LogVariables
22+
- LogProfile
2223
- StopAfter:
2324
minutes: 2
25+

mnist_convnet/convnet.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,19 @@
66
class SimpleConvNet(cxtf.BaseModel):
77

88
def _create_model(self):
9-
images = tf.placeholder(tf.float32, shape=[None, 28, 28], name='images')
9+
images = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name='images')
1010
labels = tf.placeholder(tf.int64, shape=[None], name='labels')
1111

1212
with tf.variable_scope('conv1'):
13-
net = tf.expand_dims(images, -1)
14-
net = K.layers.Conv2D(20, 5)(net)
13+
net = K.layers.Conv2D(64, 5)(images)
1514
net = K.layers.MaxPool2D()(net)
1615
with tf.variable_scope('conv2'):
17-
net = K.layers.Conv2D(50, 3)(net)
16+
net = K.layers.Conv2D(128, 3)(net)
1817
net = K.layers.MaxPool2D()(net)
1918
with tf.variable_scope('dense3'):
2019
net = K.layers.Flatten()(net)
2120
net = K.layers.Dropout(0.4).apply(net, training=self.is_training)
22-
net = K.layers.Dense(100)(net)
21+
net = K.layers.Dense(64)(net)
2322
with tf.variable_scope('dense4'):
2423
logits = K.layers.Dense(10, activation=None)(net)
2524

mnist_convnet/mnist_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def _load_data(self) -> None:
3535
with gzip.open(file_path, 'rb') as file:
3636
if 'images' in key:
3737
_, _, rows, cols = struct.unpack(">IIII", file.read(16))
38-
self._data[key] = np.frombuffer(file.read(), dtype=np.uint8).reshape(-1, rows, cols)
38+
self._data[key] = np.frombuffer(file.read(), dtype=np.uint8).reshape(-1, rows, cols, 1)
3939
else:
4040
_ = struct.unpack(">II", file.read(8))
4141
self._data[key] = np.frombuffer(file.read(), dtype=np.int8)

0 commit comments

Comments
 (0)