This repository was archived by the owner on Aug 8, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathpredictor_views_mean.py
105 lines (74 loc) · 3.82 KB
/
predictor_views_mean.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import argparse
import os
import shutil
import sys
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.contrib.learn import RunConfig
from datasets.DatasetFactory import DatasetFactory
from helper.model_helper import get_model_function, get_input_function
from nets import nets_factory
# TF-Slim alias from tf.contrib; note it is not referenced anywhere else in this script.
slim = tf.contrib.slim
def start_prediction(output_directory, data_directory, dataset_name, model_dir, network_name, batch_size,
                     batch_threads, num_classes=None):
    """Load a trained model into an Estimator and run view prediction on the dataset splits.

    Args:
        output_directory: Root directory handed on to ``run_prediction_and_evaluation``.
        data_directory: Directory given to ``DatasetFactory`` as ``data_directory``.
        dataset_name: Name given to ``DatasetFactory`` as ``dataset_name``.
        model_dir: Directory used both by ``get_model_function`` and as the Estimator's ``model_dir``.
        network_name: Network identifier used for the model function and the input size lookup.
        batch_size: Batch size forwarded to the prediction step.
        batch_threads: Number of batching threads forwarded to the prediction step.
        num_classes: Number of classes for the model; when ``None`` it is read from the
            'train' split's ``num_classes()``.
    """
    factory = DatasetFactory(dataset_name=dataset_name, data_directory=data_directory, augment=False)
    num_classes = num_classes if num_classes is not None else factory.get_dataset('train').num_classes()

    # Keep up to 10 checkpoints; step-based checkpoint saving is disabled.
    config = RunConfig(keep_checkpoint_max=10, save_checkpoints_steps=None)
    model_function = get_model_function(model_dir, network_name, num_classes)
    estimator = tf.estimator.Estimator(model_fn=model_function, model_dir=model_dir, config=config, params={})

    input_size = nets_factory.get_input_size(network_name)
    run_prediction_and_evaluation(output_directory, batch_size, batch_threads, factory, estimator, input_size)
def run_prediction_and_evaluation(output_directory, batch_size, batch_threads, dataset_factory, estimator, image_size):
    """Run ``predict_views`` on the 'train' split first, then on the 'test' split."""
    for split_name in ('train', 'test'):
        predict_views(batch_size, batch_threads, dataset_factory, estimator, image_size, output_directory,
                      split_name)
def predict_views(batch_size, batch_threads, dataset_factory, estimator, image_size, output_directory, dataset_part,
                  num_views=3, mean_image_size=(64, 128)):
    """Predict the view of every sample in one split and write the mean image of each view.

    The directory ``output_directory/dataset_part`` is deleted (if present) and recreated. The
    estimator is run over the split. Each sample's source image is re-read from disk, resized and
    added to the accumulator of its predicted view. Finally ``<view>.png`` is written for every view
    that received at least one sample.

    Args:
        batch_size: Batch size for the input function; also the progress-report interval.
        batch_threads: Number of batching threads passed to ``get_input_function``.
        dataset_factory: Object whose ``get_dataset(dataset_part)`` returns the split to predict.
        estimator: Estimator whose predictions carry ``'paths'`` (UTF-8 encoded bytes) and
            ``'views_classifications'`` (the predicted view index).
        image_size: Network input size forwarded to ``get_input_function``.
        output_directory: Root output directory; results go into its ``dataset_part`` subdirectory.
        dataset_part: Name of the split, e.g. ``'train'`` or ``'test'``.
        num_views: Number of view classes; predicted indices must lie in ``[0, num_views)``.
        mean_image_size: ``(width, height)`` that every image is resized to before averaging
            (the ``dsize`` order used by ``cv2.resize``).

    Raises:
        IOError: If an image referenced by a prediction cannot be read.
        ValueError: If a predicted view index is outside ``[0, num_views)``.
    """
    print('Starting views evaluation...')
    dataset = dataset_factory.get_dataset(dataset_part)

    output_directory = os.path.join(output_directory, dataset_part)
    # Start from an empty directory so mean images from an earlier run never survive.
    if os.path.exists(output_directory):
        shutil.rmtree(output_directory)
    os.makedirs(output_directory)

    print('\n\nRunning Prediction for %s' % dataset_part)
    input_function = get_input_function(dataset, batch_size, batch_threads, False, image_size)
    predicted = estimator.predict(input_fn=input_function)
    num_samples = dataset.get_number_of_samples()

    width, height = mean_image_size
    # Wide integer accumulators: summing many 8-bit images would overflow a narrow dtype.
    sum_images = np.zeros([num_views, height, width, 3], dtype=np.longlong)
    counters = np.zeros([num_views], dtype=np.int64)

    processed = 0
    for processed, prediction in enumerate(predicted, start=1):
        original_path = prediction['paths'].decode('UTF-8')
        image = cv2.imread(original_path)
        if image is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise IOError('Could not read image %s' % original_path)
        image = cv2.resize(image, (width, height))

        predicted_view = int(prediction['views_classifications'])
        if not 0 <= predicted_view < num_views:
            # A negative index would otherwise silently wrap around into another view's bucket.
            raise ValueError('Predicted view %d is outside [0, %d)' % (predicted_view, num_views))
        sum_images[predicted_view] += image
        counters[predicted_view] += 1

        if processed % batch_size == 0:
            sys.stdout.write('\r>> Processed %d samples of %d' % (processed, num_samples))
            sys.stdout.flush()
    if processed % batch_size != 0:
        # Report the final count when it did not land on a batch boundary.
        sys.stdout.write('\r>> Processed %d samples of %d' % (processed, num_samples))
        sys.stdout.flush()

    for view in range(num_views):
        count = counters[view]
        if count == 0:
            # The mean of zero images is undefined (0 / 0 -> NaN); report it instead of writing a bogus file.
            print('\nNo samples were predicted as view %d; skipping %d.png' % (view, view))
            continue
        # Round and convert to 8-bit explicitly instead of handing a float64 array to cv2.imwrite.
        mean_image = np.rint(sum_images[view] / float(count)).astype(np.uint8)
        cv2.imwrite(os.path.join(output_directory, str(view) + '.png'), mean_image)
    print('\n\nFinished views prediction.')
def main(argv=None):
    """Parse command-line arguments and run view prediction on the train and test splits.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse use ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', help='Specify the folder with the images to be trained and evaluated', dest='data_directory')
    parser.add_argument('--dataset-name', help='The name of the dataset')
    parser.add_argument('--batch-size', help='The batch size', type=int, default=16)
    parser.add_argument('--batch-threads', help='The number of threads to be used for batching', type=int, default=4)
    # Required: without a model directory the Estimator would not load the trained checkpoint.
    parser.add_argument('--model-dir', help='The model to be loaded', required=True)
    parser.add_argument('--network-name', help='Name of the network')
    # Required: the output path is joined with the split name, which fails on None.
    parser.add_argument('--output', help='Output directory', required=True)
    parser.add_argument('--num-classes', help='Number of classes', type=int, default=None)
    args = parser.parse_args(argv)

    print('Running with command line arguments:')
    print(args)
    print('\n\n')

    start_prediction(args.output, args.data_directory, args.dataset_name, args.model_dir, args.network_name,
                     args.batch_size, args.batch_threads, args.num_classes)
    print('Exiting ...')
if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script, not on import.
    main()