Skip to content

Commit

Permalink
Move globals handling out of util/coordinator.py
Browse files Browse the repository at this point in the history
  • Loading branch information
reuben committed Nov 9, 2018
1 parent 38b5447 commit 4f8266b
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 207 deletions.
29 changes: 17 additions & 12 deletions DeepSpeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
from tensorflow.contrib.lite.python import tflite_convert
from tensorflow.python.tools import freeze_graph
from util.audio import audiofile_to_input_vector
from util.config import C, initialize_globals
from util.feeding import DataSet, ModelFeeder
from util.logging import *
from util.flags import create_flags, FLAGS
from util.coordinator import C, initialize_globals
from util.coordinator import TrainingCoordinator
from util.preprocess import preprocess
from util.text import Alphabet

Expand Down Expand Up @@ -367,6 +368,10 @@ def train(server=None):
If no server provided, it performs single process training.
'''

# Initializing and starting the training coordinator
coord = TrainingCoordinator(C.is_chief)
coord.start()

# Create a variable to hold the global_step.
# It will automagically get incremented by the optimizer.
global_step = tf.Variable(0, trainable=False, name='global_step')
Expand All @@ -384,7 +389,7 @@ def train(server=None):
train_set = DataSet(train_data,
FLAGS.train_batch_size,
limit=FLAGS.limit_train,
next_index=lambda i: C.COORD.get_next_index('train'))
next_index=lambda i: coord.get_next_index('train'))

# Reading validation set
dev_data = preprocess(FLAGS.dev_files.split(','),
Expand All @@ -397,7 +402,7 @@ def train(server=None):
dev_set = DataSet(dev_data,
FLAGS.dev_batch_size,
limit=FLAGS.limit_dev,
next_index=lambda i: C.COORD.get_next_index('dev'))
next_index=lambda i: coord.get_next_index('dev'))

# Combining all sets to a multi set model feeder
model_feeder = ModelFeeder(train_set,
Expand Down Expand Up @@ -502,14 +507,14 @@ def update_progressbar(set_name):
update_progressbar.total_jobs = None
update_progressbar.current_job_index = 0

current_epoch = C.COORD._epoch-1
current_epoch = coord._epoch-1

if set_name == "train":
log_info('Training epoch %i...' % current_epoch)
update_progressbar.total_jobs = C.COORD._num_jobs_train
update_progressbar.total_jobs = coord._num_jobs_train
else:
log_info('Validating epoch %i...' % current_epoch)
update_progressbar.total_jobs = C.COORD._num_jobs_dev
update_progressbar.total_jobs = coord._num_jobs_dev

# recreate pbar
update_progressbar.pbar = progressbar.ProgressBar(max_value=update_progressbar.total_jobs,
Expand Down Expand Up @@ -542,10 +547,10 @@ def update_progressbar(set_name):
# Retrieving global_step from the (potentially restored) model
model_feeder.set_data_set(no_dropout_feed_dict, model_feeder.train)
step = session.run(global_step, feed_dict=no_dropout_feed_dict)
C.COORD.start_coordination(model_feeder, step)
coord.start_coordination(model_feeder, step)

# Get the first job
job = C.COORD.get_job()
job = coord.get_job()

while job and not session.should_stop():
log_debug('Computing %s...' % job)
Expand Down Expand Up @@ -606,7 +611,7 @@ def update_progressbar(set_name):

# Send the current job to coordinator and receive the next one
log_debug('Sending %s...' % job)
job = C.COORD.next_job(job)
job = coord.next_job(job)

if update_progressbar.pbar:
update_progressbar.pbar.finish()
Expand Down Expand Up @@ -634,6 +639,9 @@ def update_progressbar(set_name):
' or removing the contents of {0}.'.format(FLAGS.checkpoint_dir))
sys.exit(1)

# Stopping the coordinator
coord.stop()


def test():
# Reading test set
Expand Down Expand Up @@ -926,9 +934,6 @@ def main(_):
if len(FLAGS.one_shot_infer):
do_single_file_inference(FLAGS.one_shot_infer)

# Stopping the coordinator
C.COORD.stop()

if __name__ == '__main__' :
create_flags()
tf.app.run(main)
2 changes: 1 addition & 1 deletion evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from attrdict import AttrDict
from collections import namedtuple
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from util.config import C, initialize_globals
from util.flags import create_flags
from util.coordinator import C, initialize_globals
from util.logging import log_debug, log_info, log_warn, log_error
from multiprocessing import Pool, cpu_count
from six.moves import zip, range
Expand Down
Loading

0 comments on commit 4f8266b

Please sign in to comment.