Refine tf mnist #47

Merged (3 commits, Jan 14, 2018)
11 changes: 7 additions & 4 deletions fluid/mnist.py
@@ -13,8 +13,9 @@

SEED = 1
DTYPE = "float32"

# the random seed must be set before configuring the network.
fluid.default_startup_program().random_seed = SEED
# fluid.default_startup_program().random_seed = SEED


def parse_args():
@@ -149,10 +150,12 @@ def run_benchmark(model, args):
(pass_id, batch_id, loss, 1 - acc, (end - start) / 1000))

pass_end = time.time()
train_avg_acc = accuracy.eval(exe)
test_avg_acc = eval_test(exe, accuracy, avg_cost)
pass_acc = accuracy.eval(exe)
print("pass=%d, test_avg_acc=%f, test_avg_acc=%f, elapse=%f" %
(pass_id, pass_acc, test_avg_acc, (pass_end - pass_start) / 1000))

print("pass=%d, train_avg_acc=%f, test_avg_acc=%f, elapse=%f" %
(pass_id, train_avg_acc, test_avg_acc,
(pass_end - pass_start) / 1000))


if __name__ == '__main__':
269 changes: 142 additions & 127 deletions tensorflow/mnist.py
@@ -8,142 +8,157 @@

import tensorflow as tf
import paddle.v2 as paddle
import paddle.v2.fluid as fluid

BATCH_SIZE = 128
PASS_NUM = 5
SEED = 1
DTYPE = tf.float32


def normal_scale(size, channels):
scale = (2.0 / (size**2 * channels))**0.5
return scale


# NOTE(dzhwinter): TensorFlow uses the Philox algorithm as its normal random
# generator; fetch Paddle's random values here for comparison.
def paddle_random_normal(shape, loc=.0, scale=1., seed=1, dtype="float32"):
program = fluid.framework.Program()
block = program.global_block()
w = block.create_var(
dtype="float32",
shape=shape,
lod_level=0,
name="param",
initializer=fluid.initializer.NormalInitializer(
loc=.0, scale=scale, seed=seed))
place = fluid.CPUPlace()
exe = fluid.Executor(place)
out = exe.run(program, fetch_list=[w])
return np.array(out[0])


train_reader = paddle.batch(paddle.dataset.mnist.train(), batch_size=BATCH_SIZE)
images = tf.placeholder(DTYPE, shape=(None, 28, 28, 1))
labels = tf.placeholder(tf.int64, shape=(None, ))

# conv layer
arg = tf.convert_to_tensor(
np.transpose(
paddle_random_normal(
[20, 1, 5, 5], scale=normal_scale(5, 1), seed=SEED, dtype=DTYPE),
axes=[2, 3, 1, 0]))
conv1_weights = tf.Variable(arg)
conv1_bias = tf.Variable(tf.zeros([20]), dtype=DTYPE)
conv1 = tf.nn.conv2d(
images, conv1_weights, strides=[1, 1, 1, 1], padding="VALID")
relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_bias))
pool1 = tf.nn.max_pool(
relu1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")

arg = tf.convert_to_tensor(
np.transpose(
paddle_random_normal(
[50, 20, 5, 5], scale=normal_scale(5, 20), seed=SEED, dtype=DTYPE),
axes=[2, 3, 1, 0]))
conv2_weights = tf.Variable(arg)
conv2_bias = tf.Variable(tf.zeros([50]), dtype=DTYPE)
conv2 = tf.nn.conv2d(
pool1, conv2_weights, strides=[1, 1, 1, 1], padding="VALID")
relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_bias))
pool2 = tf.nn.max_pool(
relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")

pool_shape = pool2.get_shape().as_list()
hidden_dim = reduce(lambda a, b: a * b, pool_shape[1:], 1)
reshape = tf.reshape(pool2, shape=(tf.shape(pool2)[0], hidden_dim))

# fc layer
# NOTE(dzhwinter) : paddle has a NCHW data format, tensorflow has a NHWC data format
# need to convert the fc weight
paddle_weight = paddle_random_normal(
[hidden_dim, 10],
scale=normal_scale(hidden_dim, 10),
seed=SEED,
dtype=DTYPE)
new_shape = pool_shape[-1:] + pool_shape[1:-1] + [10]
paddle_weight = np.reshape(paddle_weight, new_shape)
paddle_weight = np.transpose(paddle_weight, [1, 2, 0, 3])

arg = tf.convert_to_tensor(np.reshape(paddle_weight, [hidden_dim, 10]))
fc_weights = tf.Variable(arg, dtype=DTYPE)
fc_bias = tf.Variable(tf.zeros([10]), dtype=DTYPE)
logits = tf.matmul(reshape, fc_weights) + fc_bias

# cross entropy

prediction = tf.nn.softmax(logits)

one_hot_labels = tf.one_hot(labels, depth=10)
cost = -tf.reduce_sum(tf.log(prediction) * one_hot_labels, [1])
avg_cost = tf.reduce_mean(cost)

correct = tf.equal(tf.argmax(prediction, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
g_accuracy = tf.metrics.accuracy(labels, tf.argmax(prediction, axis=1))

opt = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.9, beta2=0.999)
train_op = opt.minimize(avg_cost)


def eval_test():
def parse_args():
parser = argparse.ArgumentParser("mnist model benchmark.")
parser.add_argument(
'--batch_size', type=int, default=128, help='The minibatch size.')
parser.add_argument(
'--iterations', type=int, default=35, help='The number of minibatches.')
parser.add_argument(
'--pass_num', type=int, default=5, help='The number of passes.')
parser.add_argument(
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help='The device type.')
args = parser.parse_args()
return args


def run_benchmark(args):
def weight_variable(dtype, shape):
initial = tf.truncated_normal(shape, stddev=0.1, dtype=dtype)
return tf.Variable(initial)

def bias_variable(dtype, shape):
initial = tf.constant(0.1, shape=shape, dtype=dtype)
return tf.Variable(initial)

device = '/cpu:0' if args.device == 'CPU' else '/device:GPU:0'
with tf.device(device):
images = tf.placeholder(DTYPE, shape=(None, 28, 28, 1))
labels = tf.placeholder(tf.int64, shape=(None, ))

# conv1, relu, pool1
conv1_weights = weight_variable(DTYPE, [5, 5, 1, 20])
conv1_bias = bias_variable(DTYPE, [20])
conv1 = tf.nn.conv2d(
images, conv1_weights, strides=[1, 1, 1, 1], padding="VALID")
relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_bias))
pool1 = tf.nn.max_pool(
relu1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")

# conv2, relu, pool2
conv2_weights = weight_variable(DTYPE, [5, 5, 20, 50])
conv2_bias = bias_variable(DTYPE, [50])
conv2 = tf.nn.conv2d(
pool1, conv2_weights, strides=[1, 1, 1, 1], padding="VALID")
relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_bias))
pool2 = tf.nn.max_pool(
relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")

# FC
pool_shape = pool2.get_shape().as_list()
hidden_dim = reduce(lambda a, b: a * b, pool_shape[1:], 1)
reshape = tf.reshape(pool2, shape=(tf.shape(pool2)[0], hidden_dim))
fc_weights = weight_variable(DTYPE, [hidden_dim, 10])
fc_bias = bias_variable(DTYPE, [10])
logits = tf.matmul(reshape, fc_weights) + fc_bias

# Get prediction
prediction = tf.nn.softmax(logits)

# Loss
one_hot_labels = tf.one_hot(labels, depth=10)
cost = -tf.reduce_sum(tf.log(prediction) * one_hot_labels, [1])
avg_cost = tf.reduce_mean(cost)

# Get accuracy
correct = tf.equal(tf.argmax(prediction, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

# metrics, g_accuracy
with tf.variable_scope("reset_metrics_accuracy_scope") as scope:
g_accuracy = tf.metrics.accuracy(
labels, tf.argmax(
prediction, axis=1))
vars = tf.contrib.framework.get_variables(
scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
g_accuracy_reset_op = tf.variables_initializer(vars)

# Optimizer
opt = tf.train.AdamOptimizer(
learning_rate=0.001, beta1=0.9, beta2=0.999)
train_op = opt.minimize(avg_cost)
# train_op = tf.train.AdamOptimizer(1e-4).minimize(avg_cost)

train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=args.batch_size)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
for batch_id, data in enumerate(test_reader()):
images_data = np.array(
map(lambda x: np.transpose(x[0].reshape([1, 28, 28]), axes=[1,2,0]), data)).astype("float32")
labels_data = np.array(map(lambda x: x[1], data)).astype("int64")
_, loss, acc, g_acc = sess.run(
[train_op, avg_cost, accuracy, g_accuracy],
feed_dict={images: images_data,
labels: labels_data})
return g_acc[1]


config = tf.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
with tf.Session(config=config) as sess:
init_g = tf.global_variables_initializer()
init_l = tf.local_variables_initializer()
sess.run(init_g)
sess.run(init_l)
for pass_id in range(PASS_NUM):
pass_start = time.time()
for batch_id, data in enumerate(train_reader()):
paddle.dataset.mnist.test(), batch_size=args.batch_size)

def eval_test():
sess.run(g_accuracy_reset_op)
for batch_id, data in enumerate(test_reader()):
images_data = np.array(
map(lambda x: np.transpose(x[0].reshape([1, 28, 28]), axes=[1,2,0]), data)).astype("float32")
labels_data = np.array(map(lambda x: x[1], data)).astype("int64")
start = time.time()

_, loss, acc, g_acc = sess.run(
[train_op, avg_cost, accuracy, g_accuracy],
feed_dict={images: images_data,
labels: labels_data})
end = time.time()

print("pass=%d, batch=%d, loss=%f, error=%f, elapse=%f" %
(pass_id, batch_id, loss, 1 - acc, (end - start) / 1000))
pass_end = time.time()
test_avg_acc = eval_test()
print("pass=%d, training_avg_accuracy=%f, test_avg_acc=%f, elapse=%f" %
(pass_id, g_acc[1], test_avg_acc, (pass_end - pass_start) / 1000))
return g_acc[1]
@pkuyym (Collaborator) commented on Jan 11, 2018:

Be careful here: tf.metrics.accuracy accumulates the accuracy across calls. Please make sure to reset the accumulating variables each time validation is run. You can refer to tensorflow/tensorflow#4814 (comment).

The PR author (Collaborator) replied:

Done. Thx!
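
For reference, a minimal standalone sketch of the reset pattern discussed above, assuming TensorFlow 1.x; the scope name, placeholder shapes, and feed values are illustrative only and not part of this diff:

import tensorflow as tf

labels = tf.placeholder(tf.int64, shape=(None, ))
predictions = tf.placeholder(tf.int64, shape=(None, ))

with tf.variable_scope("streaming_acc") as scope:
    # tf.metrics.accuracy keeps its running total/count in LOCAL_VARIABLES.
    acc, acc_update = tf.metrics.accuracy(labels, predictions)
    # Collect the accumulator variables created inside this scope so they
    # can be zeroed before every validation pass.
    metric_vars = tf.contrib.framework.get_variables(
        scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
    reset_op = tf.variables_initializer(metric_vars)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(reset_op)  # reset before each evaluation pass
    sess.run(acc_update, feed_dict={labels: [1, 0, 1], predictions: [1, 1, 1]})
    print(sess.run(acc))  # accuracy over this evaluation only (2/3 here)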


config = tf.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
with tf.Session(config=config) as sess:
init_g = tf.global_variables_initializer()
init_l = tf.local_variables_initializer()
sess.run(init_g)
sess.run(init_l)
for pass_id in range(args.pass_num):
sess.run(g_accuracy_reset_op)

pass_start = time.time()
for batch_id, data in enumerate(train_reader()):
images_data = np.array(
map(lambda x: np.transpose(x[0].reshape([1, 28, 28]), axes=[1,2,0]), data)).astype("float32")
labels_data = np.array(map(lambda x: x[1], data)).astype(
"int64")

start = time.time()
_, loss, acc, g_acc = sess.run(
[train_op, avg_cost, accuracy, g_accuracy],
feed_dict={images: images_data,
labels: labels_data})
end = time.time()

print("pass=%d, batch=%d, loss=%f, error=%f, elapse=%f" %
(pass_id, batch_id, loss, 1 - acc, (end - start) / 1000))

pass_end = time.time()
test_avg_acc = eval_test()

print(
"pass=%d, training_avg_accuracy=%f, test_avg_acc=%f, elapse=%f"
% (pass_id, g_acc[1], test_avg_acc,
(pass_end - pass_start) / 1000))


def print_arguments(args):
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')


if __name__ == '__main__':
args = parse_args()
print_arguments(args)
run_benchmark(args)
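
For reference, the refactored script would presumably be launched as: python tensorflow/mnist.py --device GPU --batch_size 128 --pass_num 5 (the flag names and defaults come from parse_args above).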