Add quantization aware training (QAT) support for Resnet 50 #584

Merged (8 commits) on Jul 8, 2020
130 changes: 130 additions & 0 deletions TensorFlow/Classification/ConvNets/export_frozen_graph.py
@@ -0,0 +1,130 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os

import tensorflow as tf

import horovod.tensorflow as hvd
from model import resnet

tf.app.flags.DEFINE_string(
    'model_name', 'resnet50', 'The name of the architecture to save. The default is the name '
    'that was used to train the model.')

tf.app.flags.DEFINE_integer(
    'image_size', 224,
    'The image size to use, otherwise use the model default_image_size.')

tf.app.flags.DEFINE_integer(
    'num_classes', 1001,
    'The number of classes to predict.')

tf.app.flags.DEFINE_integer(
    'batch_size', None,
    'Batch size for the exported model. Defaults to "None" so the batch size can '
    'be specified at model runtime.')

tf.app.flags.DEFINE_string('input_format', 'NCHW',
                           'The data format of the input tensor (NCHW or NHWC).')

tf.app.flags.DEFINE_string('compute_format', 'NCHW',
                           'The data format used by the layers in the model (NCHW or NHWC).')

tf.app.flags.DEFINE_string('checkpoint', '',
                           'The trained model checkpoint.')

tf.app.flags.DEFINE_string(
    'output_file', '', 'Where to save the resulting file to.')

tf.app.flags.DEFINE_bool(
    'quantize', False, 'Whether to build a quantization-aware (fake-quantized) graph.')

tf.app.flags.DEFINE_bool(
    'symmetric', False, 'Whether to use symmetric quantization.')

tf.app.flags.DEFINE_bool(
    'use_qdq', False, 'Use QuantizeAndDequantize ops instead of fake-quant ops.')

tf.app.flags.DEFINE_bool(
    'use_final_conv', False, 'Whether the model uses a final 1x1 convolution instead of a dense '
    'classification layer.')

tf.app.flags.DEFINE_bool('write_text_graphdef', False,
                         'Whether to write a text version of the graphdef.')

FLAGS = tf.app.flags.FLAGS


def main(_):

    # Initialize Horovod (TODO: Remove dependency on horovod for freezing graphs)
    hvd.init()

    if not FLAGS.output_file:
        raise ValueError('You must supply the path to save to with --output_file')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default() as graph:
        if FLAGS.input_format == 'NCHW':
            input_shape = [FLAGS.batch_size, 3, FLAGS.image_size, FLAGS.image_size]
        else:
            input_shape = [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3]
        input_images = tf.placeholder(name='input', dtype=tf.float32, shape=input_shape)

        resnet50_config = resnet.model_architectures[FLAGS.model_name]
        network = resnet.ResnetModel(FLAGS.model_name,
                                     FLAGS.num_classes,
                                     resnet50_config['layers'],
                                     resnet50_config['widths'],
                                     resnet50_config['expansions'],
                                     FLAGS.compute_format,
                                     FLAGS.input_format)
        probs, logits = network.build_model(
            input_images,
            training=False,
            reuse=False,
            use_final_conv=FLAGS.use_final_conv)

        if FLAGS.quantize:
            tf.contrib.quantize.experimental_create_eval_graph(symmetric=FLAGS.symmetric,
                                                               use_qdq=FLAGS.use_qdq)

        # Define the saver and restore the checkpoint
        saver = tf.train.Saver()
        with tf.Session() as sess:
            if FLAGS.checkpoint:
                saver.restore(sess, FLAGS.checkpoint)
            else:
                sess.run(tf.global_variables_initializer())
            graph_def = graph.as_graph_def()
            frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, [probs.op.name])

        # Write out the frozen graph
        tf.io.write_graph(
            frozen_graph_def,
            os.path.dirname(FLAGS.output_file),
            os.path.basename(FLAGS.output_file),
            as_text=FLAGS.write_text_graphdef)


if __name__ == '__main__':
    tf.app.run()
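
For context, a minimal sketch (not part of this PR) of how the resulting frozen graph could be loaded and run. The output path is a placeholder, and the tensor names are assumptions based on the placeholder name 'input' above and the softmax op defined under the 'resnet50/output' scope in resnet.py:

# Sketch only: names and paths below are illustrative assumptions.
import numpy as np
import tensorflow as tf

with tf.io.gfile.GFile('resnet50_qat_frozen.pb', 'rb') as f:  # hypothetical --output_file
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
    inputs = graph.get_tensor_by_name('input:0')
    probs = graph.get_tensor_by_name('resnet50/output/softmax:0')  # assumed op name
    with tf.Session(graph=graph) as sess:
        batch = np.random.rand(1, 3, 224, 224).astype(np.float32)  # NCHW, image_size=224
        print(sess.run(probs, feed_dict={inputs: batch}).shape)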
14 changes: 12 additions & 2 deletions TensorFlow/Classification/ConvNets/main.py
@@ -94,6 +94,12 @@
use_static_loss_scaling=FLAGS.use_static_loss_scaling,
use_cosine_lr=FLAGS.use_cosine_lr,
is_benchmark=FLAGS.mode == 'training_benchmark',
use_final_conv=FLAGS.use_final_conv,
quantize=FLAGS.quantize,
symmetric=FLAGS.symmetric,
quant_delay=FLAGS.quant_delay,
use_qdq=FLAGS.use_qdq,
finetune_checkpoint=FLAGS.finetune_checkpoint,
)

if FLAGS.mode in ["train_and_evaluate", 'evaluate', 'inference_benchmark']:
@@ -110,7 +116,11 @@
batch_size=FLAGS.batch_size,
log_every_n_steps=FLAGS.display_every,
is_benchmark=FLAGS.mode == 'inference_benchmark',
export_dir=FLAGS.export_dir
export_dir=FLAGS.export_dir,
quantize=FLAGS.quantize,
symmetric=FLAGS.symmetric,
use_final_conv=FLAGS.use_final_conv,
use_qdq=FLAGS.use_qdq
)

if FLAGS.mode == 'predict':
Expand All @@ -124,4 +134,4 @@
raise NotImplementedError("Only single GPU inference is implemented.")

elif not hvd_utils.is_using_hvd() or hvd.rank() == 0:
runner.predict(FLAGS.to_predict)
runner.predict(FLAGS.to_predict, quantize=FLAGS.quantize, symmetric=FLAGS.symmetric, use_qdq=FLAGS.use_qdq, use_final_conv=FLAGS.use_final_conv)
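
With these flags plumbed through main.py, quantization-aware training would be launched with something like the following hypothetical invocation; the flag names match those referenced above, but the remaining training arguments required by main.py are not shown in this diff:

python main.py --mode=train_and_evaluate --quantize --symmetric --use_qdq --quant_delay=0 --finetune_checkpoint=/path/to/converted.ckpt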
6 changes: 4 additions & 2 deletions TensorFlow/Classification/ConvNets/model/layers/conv2d.py
@@ -31,7 +31,8 @@ def conv2d(
use_bias=True,
kernel_initializer=tf.variance_scaling_initializer(),
bias_initializer=tf.zeros_initializer(),
trainable=True
trainable=True,
name=None
):

if data_format not in ['NHWC', 'NCHW']:
@@ -52,7 +53,8 @@
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
trainable=trainable,
activation=None
activation=None,
name=name
)

return net
77 changes: 54 additions & 23 deletions TensorFlow/Classification/ConvNets/model/resnet.py
@@ -21,6 +21,7 @@
import tensorflow as tf

import horovod.tensorflow as hvd
import dllogger

from model import layers
from model import blocks
@@ -183,8 +184,12 @@ def __call__(self, features, labels, mode, params):
probs, logits = self.build_model(
features,
training=mode == tf.estimator.ModeKeys.TRAIN,
reuse=False
reuse=False,
use_final_conv=params['use_final_conv']
)

if mode!=tf.estimator.ModeKeys.PREDICT:
logits = tf.squeeze(logits)

y_preds = tf.argmax(logits, axis=1, output_type=tf.int32)

@@ -196,16 +201,25 @@
tf.identity(logits, name="logits_ref")
tf.identity(probs, name="probs_ref")
tf.identity(y_preds, name="y_preds_ref")

#if mode == tf.estimator.ModeKeys.TRAIN:
#
# assert (len(tf.trainable_variables()) == 161)
#
#else:
#
# assert (len(tf.trainable_variables()) == 0)



        if mode == tf.estimator.ModeKeys.TRAIN and params['quantize']:
            dllogger.log(data={"QUANTIZATION AWARE TRAINING ENABLED": True}, step=tuple())
            if params['symmetric']:
                dllogger.log(data={"MODE": "USING SYMMETRIC MODE"}, step=tuple())
                tf.contrib.quantize.experimental_create_training_graph(tf.get_default_graph(), symmetric=True, use_qdq=params['use_qdq'], quant_delay=params['quant_delay'])
            else:
                dllogger.log(data={"MODE": "USING ASYMMETRIC MODE"}, step=tuple())
                tf.contrib.quantize.create_training_graph(tf.get_default_graph(), quant_delay=params['quant_delay'], use_qdq=params['use_qdq'])

        # Fix for restoring variables during fine-tuning of Resnet-50
        if 'finetune_checkpoint' in params.keys():
            train_vars = tf.trainable_variables()
            train_var_dict = {}
            for var in train_vars:
                train_var_dict[var.op.name] = var
            dllogger.log(data={"Restoring variables from checkpoint": params['finetune_checkpoint']}, step=tuple())
            tf.train.init_from_checkpoint(params['finetune_checkpoint'], train_var_dict)

if mode == tf.estimator.ModeKeys.PREDICT:

predictions = {'classes': y_preds, 'probabilities': probs}
@@ -352,7 +366,7 @@ def _stage(tensors):



def build_model(self, inputs, training=True, reuse=False):
def build_model(self, inputs, training=True, reuse=False, use_final_conv=False):

with var_storage.model_variable_scope(
self.model_hparams.model_name,
@@ -416,20 +430,37 @@ def build_model(self, inputs, training=True, reuse=False):

with tf.variable_scope("output"):
net = layers.reduce_mean(
net, keepdims=False, data_format=self.model_hparams.compute_format, name='spatial_mean')

logits = layers.dense(
inputs=net,
units=self.model_hparams.n_classes,
use_bias=True,
trainable=training,
kernel_initializer=self.dense_hparams.kernel_initializer,
bias_initializer=self.dense_hparams.bias_initializer)
net, keepdims=use_final_conv, data_format=self.model_hparams.compute_format, name='spatial_mean')

if use_final_conv:
logits = layers.conv2d(
net,
n_channels=self.model_hparams.n_classes,
kernel_size=(1, 1),
strides=(1, 1),
padding='SAME',
data_format=self.model_hparams.compute_format,
dilation_rate=(1, 1),
use_bias=True,
kernel_initializer=self.dense_hparams.kernel_initializer,
bias_initializer=self.dense_hparams.bias_initializer,
trainable=training,
name='dense'
)
else:
logits = layers.dense(
inputs=net,
units=self.model_hparams.n_classes,
use_bias=True,
trainable=training,
kernel_initializer=self.dense_hparams.kernel_initializer,
bias_initializer=self.dense_hparams.bias_initializer)

if logits.dtype != tf.float32:
logits = tf.cast(logits, tf.float32)

probs = layers.softmax(logits, name="softmax", axis=1)

axis = 3 if self.model_hparams.compute_format=="NHWC" and use_final_conv else 1
probs = layers.softmax(logits, name="softmax", axis=axis)

return probs, logits

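As background, here is a minimal, generic sketch of the tf.contrib.quantize workflow that the training-graph rewrite above hooks into. It uses a toy model and only baseline TF 1.x arguments; the use_qdq keyword (and, in the non-experimental call, symmetric) used above may rely on extensions available in NVIDIA's TensorFlow container rather than stock TF:

# Sketch only: a toy graph showing where the QAT rewrite fits.
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    # Build the forward pass first -- the QAT rewrite must be applied after the
    # forward graph exists and before the optimizer is created.
    x = tf.placeholder(tf.float32, [None, 224, 224, 3], name='input')
    labels = tf.placeholder(tf.int32, [None])
    net = tf.layers.conv2d(x, 32, 3, activation=tf.nn.relu)
    net = tf.reduce_mean(net, axis=[1, 2])
    logits = tf.layers.dense(net, 1001)
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))

    # Insert fake-quantization nodes; quant_delay keeps the graph in float for
    # the first N steps so the network can stabilize before quantization starts.
    tf.contrib.quantize.create_training_graph(input_graph=g, quant_delay=1000)

    train_op = tf.train.MomentumOptimizer(0.01, 0.9).minimize(loss)

At export time the model is rebuilt in inference mode and create_eval_graph (or the experimental symmetric variant) is applied before freezing, which is what export_frozen_graph.py above does.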
74 changes: 74 additions & 0 deletions TensorFlow/Classification/ConvNets/postprocess_ckpt.py
@@ -0,0 +1,74 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
import numpy as np
import argparse
import os

def process_checkpoint(input_ckpt, output_ckpt_path, dense_layer):
    """
    Loads an RN50 checkpoint whose final layer is a dense (fully connected) layer
    and transforms that layer into a 1x1 convolution layer: the dense-layer weights
    are reshaped into the weights of a 1x1 conv layer.
    Args:
        input_ckpt: Path to the input RN50 checkpoint that has a dense classification layer.
        output_ckpt_path: Path at which to write the converted checkpoint.
        dense_layer: Name of the dense-layer kernel variable to convert.
    Returns:
        None. A new checkpoint with a 1x1 conv layer as the classification layer is generated.
    """
    with tf.Session() as sess:
        # Load all the variables
        all_vars = tf.train.list_variables(input_ckpt)
        # Capture the dense-layer weights and reshape them into a 4D tensor, i.e. the
        # weights of a 1x1 convolution layer. This replaces the dense (FC) layer with
        # a 1x1 conv layer.
        dense_layer_value = 0.
        new_var_list = []
        for var in all_vars:
            curr_var = tf.train.load_variable(input_ckpt, var[0])
            if var[0] == dense_layer:
                dense_layer_value = curr_var
            else:
                new_var_list.append(tf.Variable(curr_var, name=var[0]))

        dense_layer_shape = [1, 1, 2048, 1001]
        new_var_value = np.reshape(dense_layer_value, dense_layer_shape)
        new_var = tf.Variable(new_var_value, name=dense_layer)
        new_var_list.append(new_var)

        sess.run(tf.global_variables_initializer())
        tf.train.Saver(var_list=new_var_list).save(sess, output_ckpt_path, write_meta_graph=False, write_state=False)
        print("Rewriting checkpoint completed")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, required=True, help='Path to a pretrained RN50 checkpoint with a dense layer')
    parser.add_argument('--dense_layer', type=str, default='resnet50/output/dense/kernel')
    parser.add_argument('--output', type=str, default='output_dir', help='Output directory in which to store the new checkpoint')
    args = parser.parse_args()

    input_ckpt = args.input
    # Create an output directory
    os.mkdir(args.output)

    new_ckpt = 'new.ckpt'
    new_ckpt_path = os.path.join(args.output, new_ckpt)
    with open(os.path.join(args.output, "checkpoint"), 'w') as file:
        file.write("model_checkpoint_path: " + "\"" + new_ckpt + "\"")

    # Process the input checkpoint, apply the transform, and generate a new checkpoint.
    process_checkpoint(input_ckpt, new_ckpt_path, args.dense_layer)
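
For intuition, a small self-contained check (not part of the PR) that reshaping a [2048, 1001] dense kernel into a [1, 1, 2048, 1001] kernel preserves the layer's output: a 1x1 convolution over a 1x1 spatial map is the same matrix multiply. Note also that build_model above names the 1x1 convolution 'dense', which appears intended so that the converted variable name (resnet50/output/dense/kernel by default) still matches at restore time.

# Sketch only: verifies the dense -> 1x1 conv reshape numerically.
import numpy as np

features = np.random.rand(8, 2048).astype(np.float32)    # pooled features, batch of 8
dense_w = np.random.rand(2048, 1001).astype(np.float32)  # dense kernel
conv_w = dense_w.reshape(1, 1, 2048, 1001)                # 1x1 conv kernel, HWIO layout

dense_out = features @ dense_w                            # dense layer output
conv_out = np.einsum('nc,hwco->no', features, conv_w)     # 1x1 conv on a 1x1 feature map
print(np.allclose(dense_out, conv_out, atol=1e-5))        # expected: True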