Skip to content

Commit

Permalink
fix the non-exiting bug due to deadlock (NVlabs#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
hma02 committed May 30, 2017
1 parent b57efb4 commit b52664c
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 4 deletions.
9 changes: 7 additions & 2 deletions ga3c/ProcessStats.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self):
self.predictor_count = Value('i', 0)
self.agent_count = Value('i', 0)
self.total_frame_count = 0
self.exit_flag = Value('i', 0)

def FPS(self):
# average FPS from the beginning of the training (not current FPS)
Expand All @@ -67,8 +68,12 @@ def run(self):

self.start_time = time.time()
first_time = datetime.now()
while True:
episode_time, reward, length = self.episode_log_q.get()
while self.exit_flag.value == 0:
try:
episode_time, reward, length = self.episode_log_q.get(timeout=2)
except:
continue

results_logger.write('%s, %d, %d\n' % (episode_time.strftime("%Y-%m-%d %H:%M:%S"), reward, length))
results_logger.flush()

Expand Down
5 changes: 5 additions & 0 deletions ga3c/Server.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,14 @@ def main(self):
time.sleep(0.01)

self.dynamic_adjustment.exit_flag = True
self.dynamic_adjustment.join()
while self.agents:
self.remove_agent()
while self.predictors:
self.remove_predictor()
while self.trainers:
self.remove_trainer()

self.stats.exit_flag.value = True
self.stats.join()

5 changes: 4 additions & 1 deletion ga3c/ThreadPredictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def run(self):
dtype=np.float32)

while not self.exit_flag:
ids[0], states[0] = self.server.prediction_q.get()
try:
ids[0], states[0] = self.server.prediction_q.get(timeout=2)
except:
continue

size = 1
while size < Config.PREDICTION_BATCH_SIZE and not self.server.prediction_q.empty():
Expand Down
7 changes: 6 additions & 1 deletion ga3c/ThreadTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,14 @@ def __init__(self, server, id):

def run(self):
while not self.exit_flag:

batch_size = 0
while batch_size <= Config.TRAINING_MIN_BATCH_SIZE:
x_, r_, a_ = self.server.training_q.get()
try:
x_, r_, a_ = self.server.training_q.get(timeout=2)
except:
if self.exit_flag: break
continue
if batch_size == 0:
x__ = x_; r__ = r_; a__ = a_
else:
Expand Down

0 comments on commit b52664c

Please sign in to comment.