From f929bd7dab9c84add555fad08cb3cb3d8d34ce6c Mon Sep 17 00:00:00 2001 From: He Date: Tue, 30 May 2017 16:03:25 -0400 Subject: [PATCH] fix the non-exiting bug due to deadlock (#4) --- ga3c/ProcessStats.py | 9 +++++++-- ga3c/Server.py | 5 +++++ ga3c/ThreadPredictor.py | 5 ++++- ga3c/ThreadTrainer.py | 7 ++++++- 4 files changed, 22 insertions(+), 4 deletions(-) diff --git a/ga3c/ProcessStats.py b/ga3c/ProcessStats.py index 937e0fb..b46d731 100644 --- a/ga3c/ProcessStats.py +++ b/ga3c/ProcessStats.py @@ -50,6 +50,7 @@ def __init__(self): self.predictor_count = Value('i', 0) self.agent_count = Value('i', 0) self.total_frame_count = 0 + self.exit_flag = Value('i', 0) def FPS(self): # average FPS from the beginning of the training (not current FPS) @@ -67,8 +68,12 @@ def run(self): self.start_time = time.time() first_time = datetime.now() - while True: - episode_time, reward, length = self.episode_log_q.get() + while self.exit_flag.value == 0: + try: + episode_time, reward, length = self.episode_log_q.get(timeout=2) + except: + continue + results_logger.write('%s, %d, %d\n' % (episode_time.strftime("%Y-%m-%d %H:%M:%S"), reward, length)) results_logger.flush() diff --git a/ga3c/Server.py b/ga3c/Server.py index 28d8a46..c906537 100644 --- a/ga3c/Server.py +++ b/ga3c/Server.py @@ -124,9 +124,14 @@ def main(self): time.sleep(0.01) self.dynamic_adjustment.exit_flag = True + self.dynamic_adjustment.join() while self.agents: self.remove_agent() while self.predictors: self.remove_predictor() while self.trainers: self.remove_trainer() + + self.stats.exit_flag.value = True + self.stats.join() + diff --git a/ga3c/ThreadPredictor.py b/ga3c/ThreadPredictor.py index 38c9ed1..fd7036f 100644 --- a/ga3c/ThreadPredictor.py +++ b/ga3c/ThreadPredictor.py @@ -47,7 +47,10 @@ def run(self): dtype=np.float32) while not self.exit_flag: - ids[0], states[0] = self.server.prediction_q.get() + try: + ids[0], states[0] = self.server.prediction_q.get(timeout=2) + except: + continue size = 1 while size < Config.PREDICTION_BATCH_SIZE and not self.server.prediction_q.empty(): diff --git a/ga3c/ThreadTrainer.py b/ga3c/ThreadTrainer.py index 4e364ad..fb9528c 100644 --- a/ga3c/ThreadTrainer.py +++ b/ga3c/ThreadTrainer.py @@ -41,9 +41,14 @@ def __init__(self, server, id): def run(self): while not self.exit_flag: + batch_size = 0 while batch_size <= Config.TRAINING_MIN_BATCH_SIZE: - x_, r_, a_ = self.server.training_q.get() + try: + x_, r_, a_ = self.server.training_q.get(timeout=2) + except: + if self.exit_flag: break + continue if batch_size == 0: x__ = x_; r__ = r_; a__ = a_ else: