Skip to content

Commit

Permalink
fix the non-exiting bug due to deadlock (NVlabs#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
hma02 committed May 30, 2017
1 parent b147d5c commit b1d3d67
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 19 deletions.
16 changes: 14 additions & 2 deletions ga3c/ProcessAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ def predict(self, state):
# put the state in the prediction q
self.prediction_q.put((self.id, state))
# wait for the prediction to come back
p, v = self.wait_q.get()
try:
p, v = self.wait_q.get(10)
except:
return None, None

return p, v

def select_action(self, prediction):
Expand All @@ -90,13 +94,21 @@ def run_episode(self):
time_count = 0
reward_sum = 0.0

while not done:
while not done and self.exit_flag.value == 0:

# very first few frames
if self.env.current_state is None:
self.env.step(0) # 0 == NOOP
continue

prediction, value = self.predict(self.env.current_state)
if prediction is None and value is None:
if self.exit_flag.value !=0:
break
else:
print("Warning: couldn't get prediction. Giving up.")
continue

action = self.select_action(prediction)
reward, done = self.env.step(action)
reward_sum += reward
Expand Down
9 changes: 7 additions & 2 deletions ga3c/ProcessStats.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self):
self.predictor_count = Value('i', 0)
self.agent_count = Value('i', 0)
self.total_frame_count = 0
self.exit_flag = Value('i', 0)

def FPS(self):
# average FPS from the beginning of the training (not current FPS)
Expand All @@ -67,8 +68,12 @@ def run(self):

self.start_time = time.time()
first_time = datetime.now()
while True:
episode_time, reward, length = self.episode_log_q.get()
while self.exit_flag.value == 0:
try:
episode_time, reward, length = self.episode_log_q.get(timeout=0.1)
except:
continue

results_logger.write('%s, %d, %d\n' % (episode_time.strftime("%Y-%m-%d %H:%M:%S"), reward, length))
results_logger.flush()

Expand Down
38 changes: 28 additions & 10 deletions ga3c/Server.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,27 +63,36 @@ def add_agent(self):
self.agents[-1].start()

# Signal agent(s) to stop through their shared exit_flag, wait for them with
# join(), and drop an entry from self.agents.
# NOTE(review): this text comes from a flattened diff view, so the first three
# statements look like the pre-change body and the loops after the blank line
# like the post-change body - confirm against the repository which one is live.
# Run together as written, the final pop() would raise IndexError when only a
# single agent was registered.
def remove_agent(self):
# Variant 1: stop, join and remove only the most recently added agent.
self.agents[-1].exit_flag.value = True
self.agents[-1].join()
self.agents.pop()

# Variant 2: raise the exit flag on every agent before joining any of them,
# so that all of them are signalled before the first join() starts waiting.
for p in self.agents:
p.exit_flag.value = True
for p in self.agents:
p.join()
# NOTE(review): only one entry is popped although every agent was stopped;
# presumably the caller loops until self.agents is empty - verify.
self.agents.pop()

def add_predictor(self):
    """Create a ThreadPredictor, register it in self.predictors and start it.

    The new predictor receives this server and its index in the list
    (the list length before insertion) as constructor arguments.
    """
    predictor = ThreadPredictor(self, len(self.predictors))
    self.predictors.append(predictor)
    predictor.start()

# Signal predictor(s) to stop by setting their exit_flag attribute, wait for
# them with join(), and drop an entry from self.predictors.
# NOTE(review): this text comes from a flattened diff view, so the first three
# statements look like the pre-change body and the loops after the blank line
# like the post-change body - confirm against the repository which one is live.
# Run together as written, the final pop() would raise IndexError when only a
# single predictor was registered.
def remove_predictor(self):
# Variant 1: stop, join and remove only the most recently added predictor.
# exit_flag is assigned directly here (a plain attribute, no `.value`).
self.predictors[-1].exit_flag = True
self.predictors[-1].join()
self.predictors.pop()

# Variant 2: flag every predictor before joining any of them, so that all of
# them are signalled before the first join() starts waiting.
for p in self.predictors:
p.exit_flag = True
for p in self.predictors:
p.join()
# NOTE(review): only one entry is popped although every predictor was stopped;
# presumably the caller loops until self.predictors is empty - verify.
self.predictors.pop()

def add_trainer(self):
    """Create a ThreadTrainer, register it in self.trainers and start it.

    The new trainer receives this server and its index in the list
    (the list length before insertion) as constructor arguments.
    """
    trainer = ThreadTrainer(self, len(self.trainers))
    self.trainers.append(trainer)
    trainer.start()

# Signal trainer(s) to stop by setting their exit_flag attribute, wait for
# them with join(), and drop an entry from self.trainers.
# NOTE(review): this text comes from a flattened diff view, so the first three
# statements look like the pre-change body and the loops after the blank line
# like the post-change body - confirm against the repository which one is live.
# Run together as written, the final pop() would raise IndexError when only a
# single trainer was registered.
def remove_trainer(self):
# Variant 1: stop, join and remove only the most recently added trainer.
self.trainers[-1].exit_flag = True
self.trainers[-1].join()
self.trainers.pop()

# Variant 2: flag every trainer before joining any of them, so that all of
# them are signalled before the first join() starts waiting.
for p in self.trainers:
p.exit_flag = True
for p in self.trainers:
p.join()
# NOTE(review): only one entry is popped although every trainer was stopped;
# presumably the caller loops until self.trainers is empty - verify.
self.trainers.pop()

def train_model(self, x_, r_, a_, trainer_id):
self.model.train(x_, r_, a_, trainer_id)
Expand Down Expand Up @@ -122,11 +131,20 @@ def main(self):
self.stats.should_save_model.value = 0

time.sleep(0.01)


print('Finished. Exiting subprocesses ...')
join_start=time.time()
self.dynamic_adjustment.exit_flag = True
self.dynamic_adjustment.join()
while self.agents:
self.remove_agent()
while self.predictors:
self.remove_predictor()
while self.trainers:
self.remove_trainer()
self.stats.exit_flag.value = True
self.stats.join()
print('Exit. Joining takes %.2f s' % (time.time()-join_start))



16 changes: 12 additions & 4 deletions ga3c/ThreadPredictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,21 @@ def run(self):
dtype=np.float32)

while not self.exit_flag:
ids[0], states[0] = self.server.prediction_q.get()
try:
ids[0], states[0] = self.server.prediction_q.get(timeout=0.1)
except:
continue

size = 1
while size < Config.PREDICTION_BATCH_SIZE and not self.server.prediction_q.empty():
ids[size], states[size] = self.server.prediction_q.get()
size += 1

try:
ids[size], states[size] = self.server.prediction_q.get(timeout=0.1)
size += 1
except:
if self.exit_flag: break

if self.exit_flag: break

batch = states[:size]
p, v = self.server.model.predict_p_and_v(batch)

Expand Down
7 changes: 6 additions & 1 deletion ga3c/ThreadTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,14 @@ def __init__(self, server, id):

def run(self):
while not self.exit_flag:

batch_size = 0
while batch_size <= Config.TRAINING_MIN_BATCH_SIZE:
x_, r_, a_ = self.server.training_q.get()
try:
x_, r_, a_ = self.server.training_q.get(timeout=0.1)
except:
if self.exit_flag: break
continue
if batch_size == 0:
x__ = x_; r__ = r_; a__ = a_
else:
Expand Down

0 comments on commit b1d3d67

Please sign in to comment.