diff --git a/ml-agents/mlagents/trainers/behavior_id_utils.py b/ml-agents/mlagents/trainers/behavior_id_utils.py index 0cdec55ce9..26d18745b3 100644 --- a/ml-agents/mlagents/trainers/behavior_id_utils.py +++ b/ml-agents/mlagents/trainers/behavior_id_utils.py @@ -1,4 +1,4 @@ -from typing import NamedTuple +from typing import NamedTuple, Optional from urllib.parse import urlparse, parse_qs @@ -42,14 +42,21 @@ def from_name_behavior_id(name_behavior_id: str) -> "BehaviorIdentifiers": ) -def create_name_behavior_id(name: str, team_id: int) -> str: +def create_name_behavior_id( + name: str, team_id: Optional[int] = None, group_id: Optional[int] = None +) -> str: """ - Reconstructs fully qualified behavior name from name and team_id - :param name: brain name - :param team_id: team ID - :return: name_behavior_id - """ - return name + "?team=" + str(team_id) + Reconstructs fully qualified behavior name from name, team_id and group_id + :param name: brain name + :param team_id: team ID + :return: name_behavior_id + """ + final_name = name + if team_id is not None: + final_name += f"?team={team_id}" + if group_id is not None: + final_name += f"&group={group_id}" + return final_name def get_global_agent_id(worker_id: int, agent_id: int) -> str: diff --git a/ml-agents/mlagents/trainers/ghost/trainer.py b/ml-agents/mlagents/trainers/ghost/trainer.py index b2db465c20..2229cab44a 100644 --- a/ml-agents/mlagents/trainers/ghost/trainer.py +++ b/ml-agents/mlagents/trainers/ghost/trainer.py @@ -351,8 +351,13 @@ def add_policy( :param parsed_behavior_id: Behavior ID that the policy should belong to. :param policy: Policy to associate with name_behavior_id. 
""" - name_behavior_id = parsed_behavior_id.behavior_id - self._name_to_parsed_behavior_id[name_behavior_id] = parsed_behavior_id + name_behavior_id = create_name_behavior_id( + parsed_behavior_id.brain_name, team_id=parsed_behavior_id.team_id + ) + # Add policy only based on the team id, not the group id + self._name_to_parsed_behavior_id[ + parsed_behavior_id.behavior_id + ] = parsed_behavior_id self.policies[name_behavior_id] = policy def get_policy(self, name_behavior_id: str) -> Policy: @@ -361,6 +366,10 @@ def get_policy(self, name_behavior_id: str) -> Policy: :param name_behavior_id: Fully qualified behavior name :return: Policy associated with name_behavior_id """ + parsed_behavior_id = BehaviorIdentifiers.from_name_behavior_id(name_behavior_id) + name_behavior_id = create_name_behavior_id( + parsed_behavior_id.brain_name, team_id=parsed_behavior_id.team_id + ) return self.policies[name_behavior_id] def _save_snapshot(self) -> None: