Commit 6397c77

Merge pull request NVIDIA#584 from peri044/rn50_qat_v2
Add quantization aware training (QAT) support for Resnet 50
2 parents 80f9ee9 + f60f8a8 commit 6397c77

9 files changed, +408 -33 lines changed
@@ -0,0 +1,130 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import os
+
+import tensorflow as tf
+
+import horovod.tensorflow as hvd
+from model import resnet
+
+tf.app.flags.DEFINE_string(
+    'model_name', 'resnet50', 'The name of the architecture to save. The default name was being '
+    'used to train the model')
+
+tf.app.flags.DEFINE_integer(
+    'image_size', 224,
+    'The image size to use, otherwise use the model default_image_size.')
+
+tf.app.flags.DEFINE_integer(
+    'num_classes', 1001,
+    'The number of classes to predict.')
+
+tf.app.flags.DEFINE_integer(
+    'batch_size', None,
+    'Batch size for the exported model. Defaulted to "None" so batch size can '
+    'be specified at model runtime.')
+
+
+tf.app.flags.DEFINE_string('input_format', 'NCHW',
+                           'The data format of the input tensor.')
+
+tf.app.flags.DEFINE_string('compute_format', 'NCHW',
+                           'The data format used by the layers in the model.')
+
+tf.app.flags.DEFINE_string('checkpoint', '',
+                           'The trained model checkpoint.')
+
+tf.app.flags.DEFINE_string(
+    'output_file', '', 'Where to save the resulting file to.')
+
+tf.app.flags.DEFINE_bool(
+    'quantize', False, 'Whether to use a quantized graph or not.')
+
+tf.app.flags.DEFINE_bool(
+    'symmetric', False, 'Whether to use symmetric quantization or not.')
+
+
+tf.app.flags.DEFINE_bool(
+    'use_qdq', False, 'Use quantize-and-dequantize ops instead of fake quant ops.')
+
+tf.app.flags.DEFINE_bool(
+    'use_final_conv', False, 'Whether to use a 1x1 convolution as the final classification layer.')
+
+tf.app.flags.DEFINE_bool('write_text_graphdef', False,
+                         'Whether to write a text version of the graphdef.')
+
+FLAGS = tf.app.flags.FLAGS
+
+
+def main(_):
+
+    # Initialize Horovod (TODO: remove the Horovod dependency for freezing graphs)
+    hvd.init()
+
+    if not FLAGS.output_file:
+        raise ValueError('You must supply the path to save to with --output_file')
+
+    tf.logging.set_verbosity(tf.logging.INFO)
+    with tf.Graph().as_default() as graph:
+        if FLAGS.input_format == 'NCHW':
+            input_shape = [FLAGS.batch_size, 3, FLAGS.image_size, FLAGS.image_size]
+        else:
+            input_shape = [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3]
+        input_images = tf.placeholder(name='input', dtype=tf.float32, shape=input_shape)
+
+        resnet50_config = resnet.model_architectures[FLAGS.model_name]
+        network = resnet.ResnetModel(FLAGS.model_name,
+                                     FLAGS.num_classes,
+                                     resnet50_config['layers'],
+                                     resnet50_config['widths'],
+                                     resnet50_config['expansions'],
+                                     FLAGS.compute_format,
+                                     FLAGS.input_format)
+        probs, logits = network.build_model(
+            input_images,
+            training=False,
+            reuse=False,
+            use_final_conv=FLAGS.use_final_conv)
+
+        if FLAGS.quantize:
+            tf.contrib.quantize.experimental_create_eval_graph(symmetric=FLAGS.symmetric,
+                                                               use_qdq=FLAGS.use_qdq)
+
+        # Define the saver and restore the checkpoint
+        saver = tf.train.Saver()
+        with tf.Session() as sess:
+            if FLAGS.checkpoint:
+                saver.restore(sess, FLAGS.checkpoint)
+            else:
+                sess.run(tf.global_variables_initializer())
+            graph_def = graph.as_graph_def()
+            frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, [probs.op.name])
+
+        # Write out the frozen graph
+        tf.io.write_graph(
+            frozen_graph_def,
+            os.path.dirname(FLAGS.output_file),
+            os.path.basename(FLAGS.output_file),
+            as_text=FLAGS.write_text_graphdef)
+
+
+if __name__ == '__main__':
+    tf.app.run()
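
The new script above builds an (optionally quantized) ResNet-50 eval graph and freezes it into a single GraphDef; its path is not shown in this view. As a quick sanity check after running it, a snippet like the one below can confirm that the quantization rewrite actually took effect. This is a minimal sketch: the file name resnet50_qat_frozen.pb is only an illustrative placeholder, not something fixed by the commit.

import tensorflow as tf

# Load the frozen GraphDef written by the freezing script (illustrative path).
graph_def = tf.GraphDef()
with tf.gfile.GFile('resnet50_qat_frozen.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

# Count the nodes inserted by tf.contrib.quantize: FakeQuant* ops in the default
# mode, QuantizeAndDequantizeV2 ops when the QDQ variant is requested.
quant_nodes = [n.name for n in graph_def.node
               if 'FakeQuant' in n.op or n.op == 'QuantizeAndDequantizeV2']
print('total nodes: %d, quantization nodes: %d' % (len(graph_def.node), len(quant_nodes)))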

TensorFlow/Classification/ConvNets/main.py (+12, -2)

@@ -94,6 +94,12 @@
         use_static_loss_scaling=FLAGS.use_static_loss_scaling,
         use_cosine_lr=FLAGS.use_cosine_lr,
         is_benchmark=FLAGS.mode == 'training_benchmark',
+        use_final_conv=FLAGS.use_final_conv,
+        quantize=FLAGS.quantize,
+        symmetric=FLAGS.symmetric,
+        quant_delay=FLAGS.quant_delay,
+        use_qdq=FLAGS.use_qdq,
+        finetune_checkpoint=FLAGS.finetune_checkpoint,
     )

     if FLAGS.mode in ["train_and_evaluate", 'evaluate', 'inference_benchmark']:
@@ -110,7 +116,11 @@
             batch_size=FLAGS.batch_size,
             log_every_n_steps=FLAGS.display_every,
             is_benchmark=FLAGS.mode == 'inference_benchmark',
-            export_dir=FLAGS.export_dir
+            export_dir=FLAGS.export_dir,
+            quantize=FLAGS.quantize,
+            symmetric=FLAGS.symmetric,
+            use_final_conv=FLAGS.use_final_conv,
+            use_qdq=FLAGS.use_qdq
         )

     if FLAGS.mode == 'predict':
@@ -124,4 +134,4 @@
         raise NotImplementedError("Only single GPU inference is implemented.")

     elif not hvd_utils.is_using_hvd() or hvd.rank() == 0:
-        runner.predict(FLAGS.to_predict)
+        runner.predict(FLAGS.to_predict, quantize=FLAGS.quantize, symmetric=FLAGS.symmetric, use_qdq=FLAGS.use_qdq, use_final_conv=FLAGS.use_final_conv)
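
main.py simply forwards the new QAT-related options into the runner. The corresponding flag definitions live in one of the changed files not shown in this excerpt (most likely the command-line helper), so the sketch below is only an illustration of how flags with the names main.py consumes could be declared; the defaults and help strings are assumptions, not taken from the commit.

import tensorflow as tf

# Illustrative flag declarations matching the names used in main.py above.
tf.app.flags.DEFINE_bool('quantize', False,
                         'Enable quantization aware training (QAT).')
tf.app.flags.DEFINE_bool('symmetric', False,
                         'Use symmetric quantization ranges.')
tf.app.flags.DEFINE_bool('use_qdq', False,
                         'Insert quantize-and-dequantize ops instead of fake quant ops.')
tf.app.flags.DEFINE_bool('use_final_conv', False,
                         'Replace the final dense layer with a 1x1 convolution.')
tf.app.flags.DEFINE_integer('quant_delay', 0,
                            'Number of steps to train before quantization is simulated.')
tf.app.flags.DEFINE_string('finetune_checkpoint', None,
                           'Checkpoint to restore trainable variables from before fine-tuning.')

FLAGS = tf.app.flags.FLAGS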

TensorFlow/Classification/ConvNets/model/layers/conv2d.py (+4, -2)

@@ -31,7 +31,8 @@ def conv2d(
        use_bias=True,
        kernel_initializer=tf.variance_scaling_initializer(),
        bias_initializer=tf.zeros_initializer(),
-       trainable=True
+       trainable=True,
+       name=None
 ):

     if data_format not in ['NHWC', 'NCHW']:
@@ -52,7 +53,8 @@ def conv2d(
         kernel_initializer=kernel_initializer,
         bias_initializer=bias_initializer,
         trainable=trainable,
-        activation=None
+        activation=None,
+        name=name
     )

     return net
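
The only change here is an optional name argument passed through to the underlying TensorFlow layer. Its purpose becomes clear in resnet.py below, where the 1x1 classification conv is created with name='dense': its variables then live under .../dense/..., the same scope a dense-layer checkpoint uses, so a reshaped dense kernel restores without any renaming. A small illustrative sketch, using tf.layers.conv2d directly rather than the repository's conv2d wrapper:

import tensorflow as tf

with tf.variable_scope('resnet50/output'):
    # 1x1 classifier over 2048 pooled channels, expressed as a conv named 'dense'.
    feat = tf.placeholder(tf.float32, shape=[None, 1, 1, 2048])
    logits = tf.layers.conv2d(feat, filters=1001, kernel_size=(1, 1), name='dense')

# Kernel and bias are created under .../dense/, matching the checkpoint layout
# produced by the postprocessing script later in this commit.
print([v.name for v in tf.trainable_variables()])
# ['resnet50/output/dense/kernel:0', 'resnet50/output/dense/bias:0']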

TensorFlow/Classification/ConvNets/model/resnet.py (+54, -23)

@@ -21,6 +21,7 @@
 import tensorflow as tf

 import horovod.tensorflow as hvd
+import dllogger

 from model import layers
 from model import blocks
@@ -183,8 +184,12 @@ def __call__(self, features, labels, mode, params):
         probs, logits = self.build_model(
             features,
             training=mode == tf.estimator.ModeKeys.TRAIN,
-            reuse=False
+            reuse=False,
+            use_final_conv=params['use_final_conv']
         )
+
+        if mode != tf.estimator.ModeKeys.PREDICT:
+            logits = tf.squeeze(logits)

         y_preds = tf.argmax(logits, axis=1, output_type=tf.int32)

@@ -196,16 +201,25 @@ def __call__(self, features, labels, mode, params):
         tf.identity(logits, name="logits_ref")
         tf.identity(probs, name="probs_ref")
         tf.identity(y_preds, name="y_preds_ref")
-
-        #if mode == tf.estimator.ModeKeys.TRAIN:
-        #
-        #    assert (len(tf.trainable_variables()) == 161)
-        #
-        #else:
-        #
-        #    assert (len(tf.trainable_variables()) == 0)
-
-
+
+        if mode == tf.estimator.ModeKeys.TRAIN and params['quantize']:
+            dllogger.log(data={"QUANTIZATION AWARE TRAINING ENABLED": True}, step=tuple())
+            if params['symmetric']:
+                dllogger.log(data={"MODE": "USING SYMMETRIC MODE"}, step=tuple())
+                tf.contrib.quantize.experimental_create_training_graph(tf.get_default_graph(), symmetric=True, use_qdq=params['use_qdq'], quant_delay=params['quant_delay'])
+            else:
+                dllogger.log(data={"MODE": "USING ASYMMETRIC MODE"}, step=tuple())
+                tf.contrib.quantize.create_training_graph(tf.get_default_graph(), quant_delay=params['quant_delay'], use_qdq=params['use_qdq'])

+            # Fix for restoring variables during fine-tuning of Resnet-50
+            if 'finetune_checkpoint' in params.keys():
+                train_vars = tf.trainable_variables()
+                train_var_dict = {}
+                for var in train_vars:
+                    train_var_dict[var.op.name] = var
+                dllogger.log(data={"Restoring variables from checkpoint": params['finetune_checkpoint']}, step=tuple())
+                tf.train.init_from_checkpoint(params['finetune_checkpoint'], train_var_dict)
+
         if mode == tf.estimator.ModeKeys.PREDICT:

             predictions = {'classes': y_preds, 'probabilities': probs}
@@ -352,7 +366,7 @@ def _stage(tensors):



-    def build_model(self, inputs, training=True, reuse=False):
+    def build_model(self, inputs, training=True, reuse=False, use_final_conv=False):

         with var_storage.model_variable_scope(
             self.model_hparams.model_name,
@@ -416,20 +430,37 @@ def build_model(self, inputs, training=True, reuse=False):

             with tf.variable_scope("output"):
                 net = layers.reduce_mean(
-                    net, keepdims=False, data_format=self.model_hparams.compute_format, name='spatial_mean')
-
-                logits = layers.dense(
-                    inputs=net,
-                    units=self.model_hparams.n_classes,
-                    use_bias=True,
-                    trainable=training,
-                    kernel_initializer=self.dense_hparams.kernel_initializer,
-                    bias_initializer=self.dense_hparams.bias_initializer)
+                    net, keepdims=use_final_conv, data_format=self.model_hparams.compute_format, name='spatial_mean')
+
+                if use_final_conv:
+                    logits = layers.conv2d(
+                        net,
+                        n_channels=self.model_hparams.n_classes,
+                        kernel_size=(1, 1),
+                        strides=(1, 1),
+                        padding='SAME',
+                        data_format=self.model_hparams.compute_format,
+                        dilation_rate=(1, 1),
+                        use_bias=True,
+                        kernel_initializer=self.dense_hparams.kernel_initializer,
+                        bias_initializer=self.dense_hparams.bias_initializer,
+                        trainable=training,
+                        name='dense'
+                    )
+                else:
+                    logits = layers.dense(
+                        inputs=net,
+                        units=self.model_hparams.n_classes,
+                        use_bias=True,
+                        trainable=training,
+                        kernel_initializer=self.dense_hparams.kernel_initializer,
+                        bias_initializer=self.dense_hparams.bias_initializer)

             if logits.dtype != tf.float32:
                 logits = tf.cast(logits, tf.float32)
-
-            probs = layers.softmax(logits, name="softmax", axis=1)
+
+            axis = 3 if self.model_hparams.compute_format == "NHWC" and use_final_conv else 1
+            probs = layers.softmax(logits, name="softmax", axis=axis)

             return probs, logits
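
For context, the block above hooks ResNet-50 into TensorFlow 1.x quantization aware training: tf.contrib.quantize rewrites the training graph to insert fake-quantization nodes (delayed by quant_delay), and tf.train.init_from_checkpoint restores only the trainable variables, so the newly added quantization min/max variables do not break restoring an FP32 checkpoint during fine-tuning. A minimal, self-contained sketch of that flow on a toy model follows; it uses the stock contrib.quantize API, while the symmetric and use_qdq keyword arguments seen in the diff come from the experimental variant shipped in NVIDIA's TensorFlow container.

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    images = tf.placeholder(tf.float32, [None, 224, 224, 3], name='input')
    labels = tf.placeholder(tf.int32, [None], name='labels')

    # Toy convolutional classifier standing in for ResNet-50.
    net = tf.layers.conv2d(images, 16, 3, activation=tf.nn.relu)
    net = tf.reduce_mean(net, axis=[1, 2])
    logits = tf.layers.dense(net, 1001)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    # Rewrite the forward/loss graph in place; fake quantization becomes active
    # after quant_delay steps, so training starts from plain float behaviour.
    tf.contrib.quantize.create_training_graph(input_graph=g, quant_delay=1000)

    # The optimizer is created after the rewrite, as contrib.quantize expects.
    train_op = tf.train.MomentumOptimizer(0.01, 0.9).minimize(loss)

At export time the inference graph is rewritten with create_eval_graph (or its experimental counterpart) before freezing, which is exactly what the new graph-freezing script at the top of this commit does.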

@@ -0,0 +1,74 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+import numpy as np
+import argparse
+import os
+
+def process_checkpoint(input_ckpt, output_ckpt_path, dense_layer):
+    """
+    This function loads a RN50 checkpoint with a dense layer as the final
+    classification layer and transforms that dense layer into a 1x1 convolution
+    layer. The weights of the dense layer are reshaped into the weights of the
+    1x1 conv layer.
+    Args:
+        input_ckpt: Path to the input RN50 ckpt which has a dense layer as the classification layer.
+        output_ckpt_path: Path where the new checkpoint is written.
+        dense_layer: Name of the dense layer kernel variable in the checkpoint.
+    Returns:
+        None. A new checkpoint with a 1x1 conv layer as the classification layer is generated.
+    """
+    with tf.Session() as sess:
+        # Load all the variables
+        all_vars = tf.train.list_variables(input_ckpt)
+        # Capture the dense layer weights and reshape them to a 4D tensor, which would be
+        # the weights of a 1x1 convolution layer. This replaces the dense (FC) layer
+        # with a 1x1 conv layer.
+        dense_layer_value = 0.
+        new_var_list = []
+        for var in all_vars:
+            curr_var = tf.train.load_variable(input_ckpt, var[0])
+            if var[0] == dense_layer:
+                dense_layer_value = curr_var
+            else:
+                new_var_list.append(tf.Variable(curr_var, name=var[0]))
+
+        dense_layer_shape = [1, 1, 2048, 1001]
+        new_var_value = np.reshape(dense_layer_value, dense_layer_shape)
+        new_var = tf.Variable(new_var_value, name=dense_layer)
+        new_var_list.append(new_var)
+
+        sess.run(tf.global_variables_initializer())
+        tf.train.Saver(var_list=new_var_list).save(sess, output_ckpt_path, write_meta_graph=False, write_state=False)
+        print("Rewriting checkpoint completed")
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input', type=str, required=True, help='Path to pretrained RN50 checkpoint with dense layer')
+    parser.add_argument('--dense_layer', type=str, default='resnet50/output/dense/kernel')
+    parser.add_argument('--output', type=str, default='output_dir', help="Output directory to store new checkpoint")
+    args = parser.parse_args()
+
+    input_ckpt = args.input
+    # Create an output directory
+    os.mkdir(args.output)
+
+    new_ckpt = 'new.ckpt'
+    new_ckpt_path = os.path.join(args.output, new_ckpt)
+    with open(os.path.join(args.output, "checkpoint"), 'w') as file:
+        file.write("model_checkpoint_path: " + "\"" + new_ckpt + "\"")
+
+    # Process the input checkpoint, apply transforms and generate a new checkpoint.
+    process_checkpoint(input_ckpt, new_ckpt_path, args.dense_layer)
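
The core of this script is the reshape of the dense kernel of shape [2048, 1001] into a 1x1 convolution kernel of shape [1, 1, 2048, 1001]. The small NumPy check below (with tiny stand-in sizes so it runs instantly) illustrates why that reshape preserves the computed logits when the convolution is applied to the 1x1 pooled feature map:

import numpy as np

c_in, c_out = 8, 5                      # stand-ins for 2048 and 1001
features = np.random.randn(1, c_in)     # pooled features, shape [N, C_in]
dense_kernel = np.random.randn(c_in, c_out)

# Dense layer: a plain matrix multiply over the channel dimension.
dense_logits = features @ dense_kernel

# 1x1 convolution with the reshaped kernel on an NHWC [N, 1, 1, C_in] input
# reduces to the same contraction over channels.
conv_kernel = dense_kernel.reshape(1, 1, c_in, c_out)
conv_logits = np.einsum('nhwc,hwck->nk', features.reshape(1, 1, 1, c_in), conv_kernel)

assert np.allclose(dense_logits, conv_logits)
print('dense and 1x1 conv logits match')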
