From 4f8266be39a25880f75636400406fe8355ce6fc4 Mon Sep 17 00:00:00 2001
From: Reuben Morais
Date: Fri, 9 Nov 2018 20:19:28 -0200
Subject: [PATCH] Move globals handling out of util/coordinator.py

---
 DeepSpeech.py       |  29 +++---
 evaluate.py         |   2 +-
 util/coordinator.py | 244 ++++++++++----------------------------------
 util/flags.py       |   4 +-
 4 files changed, 72 insertions(+), 207 deletions(-)

diff --git a/DeepSpeech.py b/DeepSpeech.py
index 9cc5bf481c..6371daff34 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -20,10 +20,11 @@ from tensorflow.contrib.lite.python import tflite_convert
 from tensorflow.python.tools import freeze_graph
 
 from util.audio import audiofile_to_input_vector
+from util.config import C, initialize_globals
 from util.feeding import DataSet, ModelFeeder
 from util.logging import *
 from util.flags import create_flags, FLAGS
-from util.coordinator import C, initialize_globals
+from util.coordinator import TrainingCoordinator
 from util.preprocess import preprocess
 from util.text import Alphabet
 
@@ -367,6 +368,10 @@ def train(server=None):
     If no server provided, it performs single process training.
     '''
 
+    # Initializing and starting the training coordinator
+    coord = TrainingCoordinator(C.is_chief)
+    coord.start()
+
     # Create a variable to hold the global_step.
     # It will automagically get incremented by the optimizer.
     global_step = tf.Variable(0, trainable=False, name='global_step')
@@ -384,7 +389,7 @@ def train(server=None):
     train_set = DataSet(train_data,
                         FLAGS.train_batch_size,
                         limit=FLAGS.limit_train,
-                        next_index=lambda i: C.COORD.get_next_index('train'))
+                        next_index=lambda i: coord.get_next_index('train'))
 
     # Reading validation set
     dev_data = preprocess(FLAGS.dev_files.split(','),
@@ -397,7 +402,7 @@ def train(server=None):
     dev_set = DataSet(dev_data,
                       FLAGS.dev_batch_size,
                       limit=FLAGS.limit_dev,
-                      next_index=lambda i: C.COORD.get_next_index('dev'))
+                      next_index=lambda i: coord.get_next_index('dev'))
 
     # Combining all sets to a multi set model feeder
     model_feeder = ModelFeeder(train_set,
@@ -502,14 +507,14 @@ def update_progressbar(set_name):
             update_progressbar.total_jobs = None
             update_progressbar.current_job_index = 0
 
-            current_epoch = C.COORD._epoch-1
+            current_epoch = coord._epoch-1
 
             if set_name == "train":
                 log_info('Training epoch %i...' % current_epoch)
-                update_progressbar.total_jobs = C.COORD._num_jobs_train
+                update_progressbar.total_jobs = coord._num_jobs_train
             else:
                 log_info('Validating epoch %i...' % current_epoch)
-                update_progressbar.total_jobs = C.COORD._num_jobs_dev
+                update_progressbar.total_jobs = coord._num_jobs_dev
 
             # recreate pbar
             update_progressbar.pbar = progressbar.ProgressBar(max_value=update_progressbar.total_jobs,
@@ -542,10 +547,10 @@ def update_progressbar(set_name):
             # Retrieving global_step from the (potentially restored) model
             model_feeder.set_data_set(no_dropout_feed_dict, model_feeder.train)
             step = session.run(global_step, feed_dict=no_dropout_feed_dict)
-            C.COORD.start_coordination(model_feeder, step)
+            coord.start_coordination(model_feeder, step)
 
             # Get the first job
-            job = C.COORD.get_job()
+            job = coord.get_job()
 
             while job and not session.should_stop():
                 log_debug('Computing %s...' % job)
@@ -606,7 +611,7 @@ def update_progressbar(set_name):
 
                 # Send the current job to coordinator and receive the next one
                 log_debug('Sending %s...' % job)
-                job = C.COORD.next_job(job)
+                job = coord.next_job(job)
 
             if update_progressbar.pbar:
                 update_progressbar.pbar.finish()
@@ -634,6 +639,9 @@ def update_progressbar(set_name):
                      ' or removing the contents of {0}.'.format(FLAGS.checkpoint_dir))
             sys.exit(1)
 
+    # Stopping the coordinator
+    coord.stop()
+
 
 def test():
     # Reading test set
@@ -926,9 +934,6 @@ def main(_):
     if len(FLAGS.one_shot_infer):
         do_single_file_inference(FLAGS.one_shot_infer)
 
-    # Stopping the coordinator
-    C.COORD.stop()
-
 if __name__ == '__main__' :
     create_flags()
     tf.app.run(main)
diff --git a/evaluate.py b/evaluate.py
index b3f54c9358..8b77e349b4 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -15,8 +15,8 @@ from attrdict import AttrDict
 from collections import namedtuple
 from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
 
+from util.config import C, initialize_globals
 from util.flags import create_flags
-from util.coordinator import C, initialize_globals
 from util.logging import log_debug, log_info, log_warn, log_error
 from multiprocessing import Pool, cpu_count
 from six.moves import zip, range
diff --git a/util/coordinator.py b/util/coordinator.py
index f2fc9036e3..3564fdf931 100644
--- a/util/coordinator.py
+++ b/util/coordinator.py
@@ -1,18 +1,15 @@
 from __future__ import absolute_import, division, print_function
 
-import os
 import pickle
 import tensorflow as tf
 
-from attrdict import AttrDict
 from datetime import datetime
 from six.moves import zip, range, filter, urllib, BaseHTTPServer
 from threading import Thread, Lock
-from util.gpu import get_available_gpus
+from util.config import C
 from util.flags import FLAGS
 from util.logging import *
-from util.text import Alphabet
-from xdg import BaseDirectory as xdg
+
 
 # Execution
 # =========
@@ -104,8 +101,9 @@ class Epoch(object):
     Kwargs:
         set_name (str): the name of the data-set - one of 'train', 'dev'
     '''
-    def __init__(self, index, num_jobs, set_name='train'):
+    def __init__(self, coord, index, num_jobs, set_name='train'):
         self.id = new_id()
+        self.coord = coord
         self.index = index
         self.num_jobs = num_jobs
         self.set_name = set_name
@@ -186,7 +184,7 @@ def done(self):
 
             # if the job was for validation dataset then append it to the COORD's _loss for early stop verification
             if (FLAGS.early_stop is True) and (self.set_name == 'dev'):
-                COORD._dev_losses.append(self.loss)
+                self.coord._dev_losses.append(self.loss)
 
             return True
         return False
@@ -213,57 +211,61 @@ class TrainingCoordinator(object):
     HTTP-forwarded to the chief worker instance.
     '''
 
-    class TrainingCoordinationHandler(BaseHTTPServer.BaseHTTPRequestHandler):
-        '''Handles HTTP requests from remote workers to the Training Coordinator.
-        '''
-        def _send_answer(self, data=None):
-            self.send_response(200)
-            self.send_header('content-type', 'text/plain')
-            self.end_headers()
-            if data:
-                self.wfile.write(data)
-
-        def do_GET(self):
-            if COORD.started:
-                if self.path.startswith(PREFIX_NEXT_INDEX):
-                    index = COORD.get_next_index(self.path[len(PREFIX_NEXT_INDEX):])
-                    if index >= 0:
-                        self._send_answer(str(index).encode("utf-8"))
-                        return
-                elif self.path.startswith(PREFIX_GET_JOB):
-                    job = COORD.get_job(worker=int(self.path[len(PREFIX_GET_JOB):]))
+    def make_handler(coord):
+        class TrainingCoordinationHandler(BaseHTTPServer.BaseHTTPRequestHandler):
+            '''Handles HTTP requests from remote workers to the Training Coordinator.
+            '''
+            def _send_answer(self, data=None):
+                self.send_response(200)
+                self.send_header('content-type', 'text/plain')
+                self.end_headers()
+                if data:
+                    self.wfile.write(data)
+
+            def do_GET(self):
+                if coord.started:
+                    if self.path.startswith(PREFIX_NEXT_INDEX):
+                        index = coord.get_next_index(self.path[len(PREFIX_NEXT_INDEX):])
+                        if index >= 0:
+                            self._send_answer(str(index).encode("utf-8"))
+                            return
+                    elif self.path.startswith(PREFIX_GET_JOB):
+                        job = coord.get_job(worker=int(self.path[len(PREFIX_GET_JOB):]))
+                        if job:
+                            self._send_answer(pickle.dumps(job))
+                            return
+                    self.send_response(204) # end of training
+                else:
+                    self.send_response(202) # not ready yet
+                self.end_headers()
+
+            def do_POST(self):
+                if coord.started:
+                    src = self.rfile.read(int(self.headers['content-length']))
+                    job = coord.next_job(pickle.loads(src))
                     if job:
                         self._send_answer(pickle.dumps(job))
                         return
-                self.send_response(204) # end of training
-            else:
-                self.send_response(202) # not ready yet
-            self.end_headers()
-
-        def do_POST(self):
-            if COORD.started:
-                src = self.rfile.read(int(self.headers['content-length']))
-                job = COORD.next_job(pickle.loads(src))
-                if job:
-                    self._send_answer(pickle.dumps(job))
-                    return
-                self.send_response(204) # end of training
-            else:
-                self.send_response(202) # not ready yet
-            self.end_headers()
+                    self.send_response(204) # end of training
+                else:
+                    self.send_response(202) # not ready yet
+                self.end_headers()
 
-        def log_message(self, format, *args):
-            '''Overriding base method to suppress web handler messages on stdout.
-            '''
-            return
+            def log_message(self, format, *args):
+                '''Overriding base method to suppress web handler messages on stdout.
+                '''
+                return
+
+        return TrainingCoordinationHandler
 
     def __init__(self, is_chief):
         self._init()
         self._lock = Lock()
+        self._thread = None
         self.started = False
         self.is_chief = is_chief
         if is_chief:
-            self._httpd = BaseHTTPServer.HTTPServer((FLAGS.coord_host, FLAGS.coord_port), TrainingCoordinator.TrainingCoordinationHandler)
+            self._httpd = BaseHTTPServer.HTTPServer((FLAGS.coord_host, FLAGS.coord_port), TrainingCoordinator.make_handler(self))
 
     def _reset_counters(self):
         self._index_train = 0
@@ -385,7 +387,7 @@ def _next_epoch(self):
             self._reset_counters()
 
         # Append the training epoch
-        self._epochs_running.append(Epoch(self._epoch, num_jobs_train, set_name='train'))
+        self._epochs_running.append(Epoch(self, self._epoch, num_jobs_train, set_name='train'))
 
         if FLAGS.validation_step > 0 and (FLAGS.validation_step == 1 or self._epoch > 0) and self._epoch % FLAGS.validation_step == 0:
             # The current epoch should also have a validation part
@@ -415,15 +417,15 @@ def start(self):
         if self.is_chief:
             log_debug('Starting coordinator...')
             self._thread = Thread(target=self._httpd.serve_forever)
-            self._thread.daemon = True
+            # self._thread.daemon = True
             self._thread.start()
-            log_debug('Coordinator started.')
+            log_debug('Coordinator started. Thread id {}'.format(self._thread.ident))
 
     def stop(self, wait_for_running_epochs=True):
         '''Stops Training Coordinator. If chief, it waits for all epochs to be
         'done' and then shuts down the web server.
         '''
-        if self.is_chief:
+        if self.is_chief and self._thread:
             if wait_for_running_epochs:
                 while len(self._epochs_running) > 0:
                     log_traffic('Coordinator is waiting for epochs to finish...')
@@ -564,143 +566,3 @@ def next_job(self, job):
         if result:
             result = pickle.loads(result)
         return result
-
-class GlobalConfig:
-    _config = None
-
-    def __getattr__(self, name):
-        if not GlobalConfig._config:
-            raise RuntimeError("Global configuration not yet initialized.")
-        if not hasattr(GlobalConfig._config, name):
-            raise RuntimeError("Configuration option {} not found in config.".format(name))
-        return GlobalConfig._config[name]
-
-C = GlobalConfig()
-
-def initialize_globals():
-    c = AttrDict()
-
-    # ps and worker hosts required for p2p cluster setup
-    FLAGS.ps_hosts = list(filter(len, FLAGS.ps_hosts.split(',')))
-    FLAGS.worker_hosts = list(filter(len, FLAGS.worker_hosts.split(',')))
-
-    # The absolute number of computing nodes - regardless of cluster or single mode
-    c.num_workers = max(1, len(FLAGS.worker_hosts))
-
-    # Create a cluster from the parameter server and worker hosts.
-    c.cluster = tf.train.ClusterSpec({'ps': FLAGS.ps_hosts, 'worker': FLAGS.worker_hosts})
-
-    # If replica numbers are negative, we multiply their absolute values with the number of workers
-    if FLAGS.replicas < 0:
-        FLAGS.replicas = c.num_workers * -FLAGS.replicas
-    if FLAGS.replicas_to_agg < 0:
-        FLAGS.replicas_to_agg = c.num_workers * -FLAGS.replicas_to_agg
-
-    # The device path base for this node
-    c.worker_device = '/job:%s/task:%d' % (FLAGS.job_name, FLAGS.task_index)
-
-    # This node's CPU device
-    c.cpu_device = c.worker_device + '/cpu:0'
-
-    # This node's available GPU devices
-    c.available_devices = [c.worker_device + gpu for gpu in get_available_gpus()]
-
-    # If there is no GPU available, we fall back to CPU based operation
-    if 0 == len(c.available_devices):
-        c.available_devices = [c.cpu_device]
-
-    # Set default dropout rates
-    if FLAGS.dropout_rate2 < 0:
-        FLAGS.dropout_rate2 = FLAGS.dropout_rate
-    if FLAGS.dropout_rate3 < 0:
-        FLAGS.dropout_rate3 = FLAGS.dropout_rate
-    if FLAGS.dropout_rate6 < 0:
-        FLAGS.dropout_rate6 = FLAGS.dropout_rate
-
-    c.no_dropout = [ 0.0 ] * 6
-
-    # Set default checkpoint dir
-    if len(FLAGS.checkpoint_dir) == 0:
-        FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech','checkpoints'))
-
-    # Set default summary dir
-    if len(FLAGS.summary_dir) == 0:
-        FLAGS.summary_dir = xdg.save_data_path(os.path.join('deepspeech','summaries'))
-
-    # Standard session configuration that'll be used for all new sessions.
-    c.session_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement,
-                                      inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
-                                      intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads)
-
-    c.alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path))
-
-    # Geometric Constants
-    # ===================
-
-    # For an explanation of the meaning of the geometric constants, please refer to
-    # doc/Geometry.md
-
-    # Number of MFCC features
-    c.n_input = 26 # TODO: Determine this programatically from the sample rate
-
-    # The number of frames in the context
-    c.n_context = 9 # TODO: Determine the optimal value using a validation data set
-
-    # Number of units in hidden layers
-    c.n_hidden = FLAGS.n_hidden
-
-    c.n_hidden_1 = c.n_hidden
-
-    c.n_hidden_2 = c.n_hidden
-
-    c.n_hidden_5 = c.n_hidden
-
-    # LSTM cell state dimension
-    c.n_cell_dim = c.n_hidden
-
-    # The number of units in the third layer, which feeds in to the LSTM
-    c.n_hidden_3 = c.n_cell_dim
-
-    # The number of characters in the target language plus one
-    c.n_character = c.alphabet.size() + 1 # +1 for CTC blank label
-
-    # The number of units in the sixth layer
-    c.n_hidden_6 = c.n_character
-
-    # Queues that are used to gracefully stop parameter servers.
-    # Each queue stands for one ps. A finishing worker sends a token to each queue before joining/quitting.
-    # Each ps will dequeue as many tokens as there are workers before joining/quitting.
-    # This ensures parameter servers won't quit, if still required by at least one worker and
-    # also won't wait forever (like with a standard `server.join()`).
-    c.done_queues = []
-    for i, ps in enumerate(FLAGS.ps_hosts):
-        # Queues are hosted by their respective owners
-        with tf.device('/job:ps/task:%d' % i):
-            c.done_queues.append(tf.FIFOQueue(1, tf.int32, shared_name=('queue%i' % i)))
-
-    # Placeholder to pass in the worker's index as token
-    c.token_placeholder = tf.placeholder(tf.int32)
-
-    # Enqueue operations for each parameter server
-    c.done_enqueues = [queue.enqueue(token_placeholder) for queue in c.done_queues]
-
-    # Dequeue operations for each parameter server
-    c.done_dequeues = [queue.dequeue() for queue in c.done_queues]
-
-    if len(FLAGS.one_shot_infer) > 0:
-        FLAGS.train = False
-        FLAGS.test = False
-        FLAGS.export_dir = ''
-        if not os.path.exists(FLAGS.one_shot_infer):
-            log_error('Path specified in --one_shot_infer is not a valid file.')
-            exit(1)
-
-    # Determine, if we are the chief worker
-    c.is_chief = len(FLAGS.worker_hosts) == 0 or (FLAGS.task_index == 0 and FLAGS.job_name == 'worker')
-
-    # Initializing and starting the training coordinator
-    c.COORD = TrainingCoordinator(c.is_chief)
-    c.COORD.start()
-
-    GlobalConfig._config = c
-
diff --git a/util/flags.py b/util/flags.py
index b58d8d428b..34456b9741 100644
--- a/util/flags.py
+++ b/util/flags.py
@@ -1,9 +1,7 @@
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function
 
 import tensorflow as tf
 
-from xdg import BaseDirectory as xdg
-
 FLAGS = tf.app.flags.FLAGS
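
Note on the pattern introduced above: the HTTP request handler used to read the
module-global COORD; after this patch, TrainingCoordinator.make_handler(coord)
builds a handler class that captures the coordinator instance in a closure. A
minimal self-contained sketch of that idea (illustrative names only: Handler,
Coordinator, host and port are stand-ins, not code from this patch):

    from six.moves import BaseHTTPServer  # same Py2/Py3 shim the patch uses

    def make_handler(coord):
        # The returned class closes over `coord`, so request handlers can
        # reach the coordinator without any module-level global.
        class Handler(BaseHTTPServer.BaseHTTPRequestHandler):
            def do_GET(self):
                if coord.started:
                    self.send_response(200)
                else:
                    self.send_response(202)  # not ready yet
                self.end_headers()
        return Handler

    class Coordinator(object):
        # Stand-in for TrainingCoordinator: the instance hands itself to the
        # factory instead of relying on a global COORD.
        def __init__(self, host, port):
            self.started = False
            self._httpd = BaseHTTPServer.HTTPServer((host, port),
                                                    make_handler(self))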
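
Caller-side effect, condensed from the DeepSpeech.py hunks: the coordinator's
lifetime is now explicit in train() instead of a side effect of
initialize_globals(). A hedged sketch of the new call pattern (model setup,
session handling and error paths omitted):

    from util.config import C, initialize_globals
    from util.coordinator import TrainingCoordinator

    def train():
        # Initializing and starting the training coordinator
        coord = TrainingCoordinator(C.is_chief)
        coord.start()

        job = coord.get_job()          # get the first job
        while job:
            job = coord.next_job(job)  # send result, receive the next job

        # Stopping the coordinator (if chief, waits for running epochs and
        # then shuts down the web server)
        coord.stop()

    initialize_globals()
    train()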