
Make the Agent reset immediately after Done #3291

Merged (12 commits) on Jan 28, 2020
ml-agents/mlagents/trainers/agent_processor.py (21 changes: 18 additions & 3 deletions)
@@ -69,6 +69,7 @@ def add_experiences(
"Policy/Learning Rate", take_action_outputs["learning_rate"]
)

terminated_agents: List[str] = []
# Make unique agent_ids that are global across workers
action_global_agent_ids = [
get_global_agent_id(worker_id, ag_id) for ag_id in previous_action.agent_ids
@@ -144,7 +145,7 @@ def add_experiences(
)
for traj_queue in self.trajectory_queues:
traj_queue.put(trajectory)
del self.experience_buffers[global_id]

if curr_agent_step.done:
self.stats_reporter.add_stat(
"Environment/Cumulative Reward",
@@ -154,8 +155,7 @@
"Environment/Episode Length",
self.episode_steps.get(global_id, 0),
)
del self.episode_steps[global_id]
del self.episode_rewards[global_id]
terminated_agents += [global_id]
elif not curr_agent_step.done:
self.episode_steps[global_id] += 1

@@ -166,6 +166,21 @@
previous_action.agent_ids, take_action_outputs["action"]
)

for terminated_id in terminated_agents:
self._clean_agent_data(terminated_id)

def _clean_agent_data(self, global_id: str) -> None:
Contributor Author (inline comment): @ervteng Tell me what you think

"""
Removes the data for an Agent.
"""
del self.experience_buffers[global_id]
del self.last_take_action_outputs[global_id]
del self.episode_steps[global_id]
del self.episode_rewards[global_id]
del self.last_step_result[global_id]
self.policy.remove_previous_action([global_id])
self.policy.remove_memories([global_id])

def publish_trajectory_queue(
self, trajectory_queue: "AgentManagerQueue[Trajectory]"
) -> None:
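For readers skimming the diff, the essence of the change is: instead of deleting an agent's episode data inline as soon as curr_agent_step.done is seen, the done agents' global ids are collected into terminated_agents, and all per-agent state (experience buffer, episode step and reward counters, last step result, previous action, memories) is removed in one place, _clean_agent_data, once the batch has been processed. Below is a minimal, self-contained sketch of that collect-then-clean pattern; MiniAgentProcessor, its fields, and the step dictionary format are illustrative stand-ins invented for the example, not the real ml-agents AgentProcessor API.

    from collections import defaultdict
    from typing import Dict, List


    class MiniAgentProcessor:
        """Illustrative stand-in for the pattern in this PR: collect the ids of
        agents that finished an episode, then purge all of their bookkeeping in
        one helper after the main loop, so no per-agent state lingers past Done."""

        def __init__(self) -> None:
            self.experience_buffers: Dict[str, List[float]] = defaultdict(list)
            self.episode_steps: Dict[str, int] = defaultdict(int)
            self.episode_rewards: Dict[str, float] = defaultdict(float)

        def add_experiences(self, steps: Dict[str, dict]) -> None:
            # Deletion is deferred: the dicts are not mutated while the loop
            # still needs them for stats, which is what the PR avoids.
            terminated_agents: List[str] = []
            for global_id, step in steps.items():
                self.experience_buffers[global_id].append(step["reward"])
                self.episode_rewards[global_id] += step["reward"]
                if step["done"]:
                    # Stats would be reported here, while the episode data still exists.
                    terminated_agents.append(global_id)
                else:
                    self.episode_steps[global_id] += 1

            # Reset terminated agents immediately, in one place.
            for terminated_id in terminated_agents:
                self._clean_agent_data(terminated_id)

        def _clean_agent_data(self, global_id: str) -> None:
            """Removes every piece of state kept for one agent."""
            self.experience_buffers.pop(global_id, None)
            self.episode_steps.pop(global_id, None)
            self.episode_rewards.pop(global_id, None)


    if __name__ == "__main__":
        proc = MiniAgentProcessor()
        proc.add_experiences({"w0-a1": {"reward": 1.0, "done": True},
                              "w0-a2": {"reward": 0.5, "done": False}})
        assert "w0-a1" not in proc.episode_rewards  # cleaned right after Done
        assert proc.episode_steps["w0-a2"] == 1     # live agent keeps counting

Centralizing the deletions keeps the bookkeeping consistent (nothing is half-deleted when a later stat lookup still needs the data) and gives a single place to also clear policy-side state such as previous actions and memories, which is what the new _clean_agent_data method in the actual diff does.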