Skip to content

Commit

Permalink
Add cc to ghost trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
Ervin Teng committed Dec 18, 2020
1 parent 95b3522 commit afd7476
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
23 changes: 15 additions & 8 deletions ml-agents/mlagents/trainers/behavior_id_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import NamedTuple
from typing import NamedTuple, Optional
from urllib.parse import urlparse, parse_qs


Expand Down Expand Up @@ -42,14 +42,21 @@ def from_name_behavior_id(name_behavior_id: str) -> "BehaviorIdentifiers":
)


def create_name_behavior_id(
    name: str, team_id: Optional[int] = None, group_id: Optional[int] = None
) -> str:
    """
    Reconstructs the fully qualified behavior name from a brain name and the
    optional team and group IDs.

    The IDs are appended as a URL-style query string, e.g.
    ``name?team=0&group=1``. An ID that is None is omitted, and when both are
    None the bare name is returned unchanged.

    :param name: brain name
    :param team_id: team ID, or None to leave the team out of the result
    :param group_id: group ID, or None to leave the group out of the result
    :return: name_behavior_id
    """
    query_params = []
    if team_id is not None:
        query_params.append(f"team={team_id}")
    if group_id is not None:
        query_params.append(f"group={group_id}")
    if not query_params:
        return name
    # Always introduce the query with "?" so that a group-only ID still forms
    # a well-formed query string ("name?group=1" rather than "name&group=1").
    return name + "?" + "&".join(query_params)


def get_global_agent_id(worker_id: int, agent_id: int) -> str:
Expand Down
13 changes: 11 additions & 2 deletions ml-agents/mlagents/trainers/ghost/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,13 @@ def add_policy(
:param parsed_behavior_id: Behavior ID that the policy should belong to.
:param policy: Policy to associate with name_behavior_id.
"""
name_behavior_id = parsed_behavior_id.behavior_id
self._name_to_parsed_behavior_id[name_behavior_id] = parsed_behavior_id
name_behavior_id = create_name_behavior_id(
parsed_behavior_id.brain_name, team_id=parsed_behavior_id.team_id
)
# Add policy only based on the team id, not the group id
self._name_to_parsed_behavior_id[
parsed_behavior_id.behavior_id
] = parsed_behavior_id
self.policies[name_behavior_id] = policy

def get_policy(self, name_behavior_id: str) -> Policy:
Expand All @@ -361,6 +366,10 @@ def get_policy(self, name_behavior_id: str) -> Policy:
:param name_behavior_id: Fully qualified behavior name
:return: Policy associated with name_behavior_id
"""
parsed_behavior_id = BehaviorIdentifiers.from_name_behavior_id(name_behavior_id)
name_behavior_id = create_name_behavior_id(
parsed_behavior_id.brain_name, team_id=parsed_behavior_id.team_id
)
return self.policies[name_behavior_id]

def _save_snapshot(self) -> None:
Expand Down

0 comments on commit afd7476

Please sign in to comment.