From 7fae8734459a4ecd232bb8c63bfada7bed52ede3 Mon Sep 17 00:00:00 2001 From: Vincent-Pierre BERGES Date: Wed, 31 Mar 2021 13:28:47 -0700 Subject: [PATCH] =?UTF-8?q?[=F0=9F=90=9B=20=F0=9F=94=A8=20]Adding=20the=20?= =?UTF-8?q?ELO=20to=20the=20GlobalTrainingStatus=20(#5202)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Adding the ELO to the GlobalTrainingStatus * Update ml-agents/mlagents/trainers/ghost/trainer.py Co-authored-by: andrewcoh <54679309+andrewcoh@users.noreply.github.com> Co-authored-by: andrewcoh <54679309+andrewcoh@users.noreply.github.com> (cherry picked from commit 9c3dc4542ce2fa35f25dfa7809a7e1c2ee06001e) --- ml-agents/mlagents/trainers/ghost/trainer.py | 11 +++++++++-- ml-agents/mlagents/trainers/training_status.py | 1 + 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/ml-agents/mlagents/trainers/ghost/trainer.py b/ml-agents/mlagents/trainers/ghost/trainer.py index 2449734b75..faa04f1312 100644 --- a/ml-agents/mlagents/trainers/ghost/trainer.py +++ b/ml-agents/mlagents/trainers/ghost/trainer.py @@ -18,6 +18,7 @@ BehaviorIdentifiers, create_name_behavior_id, ) +from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType logger = get_logger(__name__) @@ -128,8 +129,11 @@ def __init__( self.last_swap: int = 0 self.last_team_change: int = 0 - # Chosen because it is the initial ELO in Chess - self.initial_elo: float = self_play_parameters.initial_elo + self.initial_elo = GlobalTrainingStatus.get_parameter_state( + self.brain_name, StatusType.ELO + ) + if self.initial_elo is None: + self.initial_elo = self_play_parameters.initial_elo self.policy_elos: List[float] = [self.initial_elo] * ( self.window + 1 ) # for learning policy @@ -323,6 +327,9 @@ def save_model(self) -> None: """ Forwarding call to wrapped trainers save_model. """ + GlobalTrainingStatus.set_parameter_state( + self.brain_name, StatusType.ELO, self.current_elo + ) self.trainer.save_model() def create_policy( diff --git a/ml-agents/mlagents/trainers/training_status.py b/ml-agents/mlagents/trainers/training_status.py index 41ea9e907e..6d69093411 100644 --- a/ml-agents/mlagents/trainers/training_status.py +++ b/ml-agents/mlagents/trainers/training_status.py @@ -20,6 +20,7 @@ class StatusType(Enum): STATS_METADATA = "metadata" CHECKPOINTS = "checkpoints" FINAL_CHECKPOINT = "final_checkpoint" + ELO = "elo" @attr.s(auto_attribs=True)