Commit 70a936b (1 parent: 4601186): 19 changed files with 4,310 additions and 0 deletions.
New file, 131 lines added:
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A simple script for inspecting checkpoint files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import numpy as np

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import app
from tensorflow.python.platform import flags

FLAGS = None


def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
  """Prints tensors in a checkpoint file.

  If no `tensor_name` is provided, prints the tensor names and shapes
  in the checkpoint file.
  If `tensor_name` is provided, prints the content of the tensor.

  Args:
    file_name: Name of the checkpoint file.
    tensor_name: Name of the tensor in the checkpoint file to print.
    all_tensors: Boolean indicating whether to print all tensors.
  """
  try:
    reader = pywrap_tensorflow.NewCheckpointReader(file_name)
    if all_tensors:
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in var_to_shape_map:
        print("tensor_name: ", key)
        print(reader.get_tensor(key))
    elif not tensor_name:
      print(reader.debug_string().decode("utf-8"))
    else:
      print("tensor_name: ", tensor_name)
      print(reader.get_tensor(tensor_name))
  except Exception as e:  # pylint: disable=broad-except
    print(str(e))
    if "corrupted compressed block contents" in str(e):
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.")


def parse_numpy_printoption(kv_str):
  """Sets a single numpy printoption from a string of the form 'x=y'.

  See documentation on numpy.set_printoptions() for details about what values
  x and y can take. x can be any option listed there other than 'formatter'.

  Args:
    kv_str: A string of the form 'x=y', such as 'threshold=100000'.

  Raises:
    argparse.ArgumentTypeError: If the string couldn't be used to set any
      numpy printoption.
  """
  k_v_str = kv_str.split("=", 1)
  if len(k_v_str) != 2 or not k_v_str[0]:
    raise argparse.ArgumentTypeError("'%s' is not in the form k=v." % kv_str)
  k, v_str = k_v_str
  printoptions = np.get_printoptions()
  if k not in printoptions:
    raise argparse.ArgumentTypeError("'%s' is not a valid printoption." % k)
  v_type = type(printoptions[k])
  if v_type is type(None):
    raise argparse.ArgumentTypeError(
        "Setting '%s' from the command line is not supported." % k)
  try:
    v = (v_type(v_str) if v_type is not bool
         else flags.BooleanParser().Parse(v_str))
  except ValueError as e:
    raise argparse.ArgumentTypeError(str(e))
  np.set_printoptions(**{k: v})


def main(unused_argv):
  if not FLAGS.file_name:
    print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
          "[--tensor_name=tensor_to_print]")
    sys.exit(1)
  else:
    print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name,
                                     FLAGS.all_tensors)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--file_name", type=str, default="", help="Checkpoint filename. "
      "Note, if using Checkpoint V2 format, file_name is the "
      "shared prefix between all files in the checkpoint.")
  parser.add_argument(
      "--tensor_name",
      type=str,
      default="",
      help="Name of the tensor to inspect.")
  parser.add_argument(
      "--all_tensors",
      nargs="?",
      const=True,
      type="bool",
      default=False,
      help="If True, print the values of all the tensors.")
  parser.add_argument(
      "--printoptions",
      nargs="*",
      type=parse_numpy_printoption,
      help="Argument for numpy.set_printoptions(), in the form 'k=v'.")
  FLAGS, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)
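
For quick reference, the script's two modes map directly onto `print_tensors_in_checkpoint_file`; `main` prints the CLI form (`inspect_checkpoint --file_name=checkpoint_file_name [--tensor_name=tensor_to_print]`) when no file name is given. The snippet below is a minimal sketch of calling the function programmatically; the module name `inspect_checkpoint`, the checkpoint prefix `./model.ckpt`, and the tensor name `global_step` are assumptions for illustration, not part of the file above.

import numpy as np

from inspect_checkpoint import print_tensors_in_checkpoint_file  # assumed module name

# Keep large arrays readable; the CLI exposes the same knob via --printoptions.
np.set_printoptions(threshold=100, precision=4)

# With an empty tensor_name and all_tensors=False, this prints every tensor
# name and shape stored in the checkpoint (reader.debug_string()).
print_tensors_in_checkpoint_file('./model.ckpt', tensor_name='', all_tensors=False)

# With a tensor_name, it prints that tensor's full contents.
print_tensors_in_checkpoint_file('./model.ckpt', tensor_name='global_step',
                                 all_tensors=False)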
New file, 58 lines added:
# -*- coding: utf-8 -*-

"""
This script produces a training batch from a TFRecord dataset.
"""
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import tensorflow as tf

from datasets import sythtextprovider
import tf_utils
from preprocessing import ssd_vgg_preprocessing

slim = tf.contrib.slim


def get_batch(dataset_dir,
              num_readers,
              batch_size,
              out_shape,
              net,
              anchors,
              num_preprocessing_threads,
              file_pattern='*.tfrecord',
              is_training=True):
    """Builds a queued batch of preprocessed images and encoded SSD targets.

    Reads TFRecords from `dataset_dir`, preprocesses each image, encodes the
    ground-truth boxes against the network anchors, and batches the result.
    """
    dataset = sythtextprovider.get_datasets(dataset_dir, file_pattern=file_pattern)

    # Parallel readers feeding a shuffling queue of raw examples.
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=num_readers,
        common_queue_capacity=20 * batch_size,
        common_queue_min=10 * batch_size,
        shuffle=True)

    [image, shape, glabels, gbboxes] = provider.get(['image', 'shape',
                                                     'object/label',
                                                     'object/bbox'])

    # Preprocess the image and its ground-truth boxes (resizing, plus
    # augmentation when is_training is True).
    image, glabels, gbboxes, num = \
        ssd_vgg_preprocessing.preprocess_image(image, glabels, gbboxes,
                                               out_shape, is_training=is_training)

    # Encode the ground-truth boxes against the anchors of every feature layer.
    gclasses, glocalisations, gscores = \
        net.bboxes_encode(glabels, gbboxes, anchors, num)

    batch_shape = [1] + [len(anchors)] * 3

    # Assemble mini-batches with a prefetch queue.
    r = tf.train.batch(
        tf_utils.reshape_list([image, gclasses, glocalisations, gscores]),
        batch_size=batch_size,
        num_threads=num_preprocessing_threads,
        capacity=5 * batch_size)

    b_image, b_gclasses, b_glocalisations, b_gscores = \
        tf_utils.reshape_list(r, batch_shape)

    return [b_image, b_gclasses, b_glocalisations, b_gscores]
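
A minimal sketch of how `get_batch` is typically driven is shown below. The network object, its anchors, the module name `load_batch`, and the dataset path are assumptions (an SSD-TensorFlow style API, i.e. a net exposing `bboxes_encode` and a list of per-layer anchors); the only hard requirements from this file are that the directory contains files matching `file_pattern` and that queue runners are started before evaluating the batch tensors.

import tensorflow as tf

from load_batch import get_batch  # assumed module name for the file above


def run_one_batch(ssd_net, anchors, dataset_dir='data/'):
    """Pulls a single preprocessed batch. `ssd_net` and `anchors` come from the
    caller's SSD-style model definition (assumed, not defined in this file)."""
    b_image, b_gclasses, b_glocalisations, b_gscores = get_batch(
        dataset_dir,
        num_readers=2,
        batch_size=32,
        out_shape=(300, 300),
        net=ssd_net,
        anchors=anchors,
        num_preprocessing_threads=4,
        is_training=True)

    with tf.Session() as sess:
        # The batch tensors are fed by queues, so queue runners must be started.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        images = sess.run(b_image)      # one batch of preprocessed images
        coord.request_stop()
        coord.join(threads)
        return images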
New file, 1 line added (empty file).
New file, 164 lines added:
# Copyright 2015 Paul Balanca. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implement some custom layers, not provided by TensorFlow.

Tries to follow as much as possible the style/standards used in
tf.contrib.layers.
"""
import tensorflow as tf

from tensorflow.contrib.framework.python.ops import add_arg_scope
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.layers.python.layers import utils
from tensorflow.python.ops import nn
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope


def abs_smooth(x):
    """Smoothed absolute function. Useful to compute an L1 smooth error.

    Defined as:
        x^2 / 2       if abs(x) < 1
        abs(x) - 0.5  if abs(x) >= 1
    Here we use a differentiable definition based on min(x) and abs(x). Clearly
    not optimal, but good enough for our purpose!
    """
    absx = tf.abs(x)
    minx = tf.minimum(absx, 1)
    r = 0.5 * ((absx - 1) * minx + absx)
    return r


@add_arg_scope
def l2_normalization(
        inputs,
        scaling=False,
        scale_initializer=init_ops.ones_initializer(),
        reuse=None,
        variables_collections=None,
        outputs_collections=None,
        data_format='NHWC',
        trainable=True,
        scope=None):
    """Implement L2 normalization on every feature (i.e. spatial normalization).

    Should be extended in some near future to other dimensions, providing a more
    flexible normalization framework.

    Args:
      inputs: a 4-D tensor with dimensions [batch_size, height, width, channels].
      scaling: whether or not to add a post scaling operation along the dimensions
        which have been normalized.
      scale_initializer: An initializer for the weights.
      reuse: whether or not the layer and its variables should be reused. To be
        able to reuse the layer scope must be given.
      variables_collections: optional list of collections for all the variables or
        a dictionary containing a different list of collection per variable.
      outputs_collections: collection to add the outputs.
      data_format: NHWC or NCHW data format.
      trainable: If `True` also add variables to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
      scope: Optional scope for `variable_scope`.

    Returns:
      A `Tensor` representing the output of the operation.
    """
    with variable_scope.variable_scope(
            scope, 'L2Normalization', [inputs], reuse=reuse) as sc:
        inputs_shape = inputs.get_shape()
        inputs_rank = inputs_shape.ndims
        dtype = inputs.dtype.base_dtype
        if data_format == 'NHWC':
            # norm_dim = tf.range(1, inputs_rank-1)
            norm_dim = tf.range(inputs_rank-1, inputs_rank)
            params_shape = inputs_shape[-1:]
        elif data_format == 'NCHW':
            # norm_dim = tf.range(2, inputs_rank)
            norm_dim = tf.range(1, 2)
            params_shape = (inputs_shape[1])

        # Normalize every pixel's feature vector along the channel dimension.
        outputs = nn.l2_normalize(inputs, norm_dim, epsilon=1e-12)
        # Additional scaling.
        if scaling:
            scale_collections = utils.get_variable_collections(
                variables_collections, 'scale')
            scale = variables.model_variable('gamma',
                                             shape=params_shape,
                                             dtype=dtype,
                                             initializer=scale_initializer,
                                             collections=scale_collections,
                                             trainable=trainable)
            if data_format == 'NHWC':
                outputs = tf.multiply(outputs, scale)
            elif data_format == 'NCHW':
                scale = tf.expand_dims(scale, axis=-1)
                scale = tf.expand_dims(scale, axis=-1)
                outputs = tf.multiply(outputs, scale)
                # outputs = tf.transpose(outputs, perm=(0, 2, 3, 1))

        return utils.collect_named_outputs(outputs_collections,
                                           sc.original_name_scope, outputs)


@add_arg_scope
def pad2d(inputs,
          pad=(0, 0),
          mode='CONSTANT',
          data_format='NHWC',
          trainable=True,
          scope=None):
    """2D Padding layer, adding a symmetric padding to H and W dimensions.

    Aims to mimic padding in Caffe and MXNet, helping the port of models to
    TensorFlow. Tries to follow the naming convention of `tf.contrib.layers`.

    Args:
      inputs: 4D input Tensor;
      pad: 2-Tuple with padding values for H and W dimensions;
      mode: Padding mode. C.f. `tf.pad`;
      data_format: NHWC or NCHW data format.
    """
    with tf.name_scope(scope, 'pad2d', [inputs]):
        # Padding shape.
        if data_format == 'NHWC':
            paddings = [[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]]
        elif data_format == 'NCHW':
            paddings = [[0, 0], [0, 0], [pad[0], pad[0]], [pad[1], pad[1]]]
        net = tf.pad(inputs, paddings, mode=mode)
        return net


@add_arg_scope
def channel_to_last(inputs,
                    data_format='NHWC',
                    scope=None):
    """Move the channel axis to the last dimension.

    Allows providing a single output format regardless of the input data format.

    Args:
      inputs: Input Tensor;
      data_format: NHWC or NCHW.

    Returns:
      Input in NHWC format.
    """
    with tf.name_scope(scope, 'channel_to_last', [inputs]):
        if data_format == 'NHWC':
            net = inputs
        elif data_format == 'NCHW':
            net = tf.transpose(inputs, perm=(0, 2, 3, 1))
        return net
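
To show how these layers compose, here is a small usage sketch in TF 1.x graph mode (the module name `custom_layers` is assumed for the file above): pad a feature map, L2-normalize it with a learned per-channel scale as SSD does for its conv4_3 feature map, and bring an NCHW tensor back to NHWC.

import numpy as np
import tensorflow as tf

import custom_layers  # assumed module name for the file above

# Dummy NHWC feature map: batch of 2, 8x8 spatial, 16 channels.
inputs = tf.placeholder(tf.float32, shape=(2, 8, 8, 16))

# Symmetric 1-pixel padding on H and W, mimicking Caffe/MXNet padding.
padded = custom_layers.pad2d(inputs, pad=(1, 1))              # -> (2, 10, 10, 16)

# L2-normalize each pixel's feature vector and apply a learned per-channel
# scale ('gamma'), the normalization SSD applies to its conv4_3 feature map.
normed = custom_layers.l2_normalization(padded, scaling=True)

# Move the channel axis back to the last dimension for an NCHW tensor.
nchw = tf.transpose(inputs, perm=(0, 3, 1, 2))                # (2, 16, 8, 8)
nhwc = custom_layers.channel_to_last(nchw, data_format='NCHW')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {inputs: np.random.rand(2, 8, 8, 16).astype(np.float32)}
    print(sess.run(normed, feed_dict=feed).shape)             # (2, 10, 10, 16)
    print(sess.run(nhwc, feed_dict=feed).shape)               # (2, 8, 8, 16)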