
Commit

ahainaut committed Feb 8, 2021
1 parent 14fe80d commit a3e899f
Showing 1 changed file with 31 additions and 38 deletions.
69 changes: 31 additions & 38 deletions muzero.py
@@ -149,13 +149,11 @@ def train(self, log_in_tensorboard=True):

         # Initialize workers
         self.training_worker = trainer.Trainer.options(
-            num_cpus=0,
-            num_gpus=num_gpus_per_worker if self.config.train_on_gpu else 0,
+            num_cpus=0, num_gpus=num_gpus_per_worker if self.config.train_on_gpu else 0,
         ).remote(self.checkpoint, self.config)

         self.shared_storage_worker = shared_storage.SharedStorage.remote(
-            self.checkpoint,
-            self.config,
+            self.checkpoint, self.config,
         )
         self.shared_storage_worker.set_info.remote("terminate", False)

@@ -174,10 +172,7 @@ def train(self, log_in_tensorboard=True):
                 num_cpus=0,
                 num_gpus=num_gpus_per_worker if self.config.selfplay_on_gpu else 0,
             ).remote(
-                self.checkpoint,
-                self.Game,
-                self.config,
-                self.config.seed + seed,
+                self.checkpoint, self.Game, self.config, self.config.seed + seed,
             )
             for seed in range(self.config.num_workers)
         ]
@@ -208,8 +203,7 @@ def logging_loop(self, num_gpus):
         """
         # Launch the test worker to get performance metrics
         self.test_worker = self_play.SelfPlay.options(
-            num_cpus=0,
-            num_gpus=num_gpus,
+            num_cpus=0, num_gpus=num_gpus,
         ).remote(
             self.checkpoint,
             self.Game,
@@ -237,8 +231,7 @@ def logging_loop(self, num_gpus):
         )
         # Save model representation
         writer.add_text(
-            "Model summary",
-            self.summary,
+            "Model summary", self.summary,
         )
         # Loop for updating the training performance
         counter = 0
@@ -263,34 +256,24 @@ def logging_loop(self, num_gpus):
             while info["training_step"] < self.config.training_steps:
                 info = ray.get(self.shared_storage_worker.get_info.remote(keys))
                 writer.add_scalar(
-                    "1.Total_reward/1.Total_reward",
-                    info["total_reward"],
-                    counter,
+                    "1.Total_reward/1.Total_reward", info["total_reward"], counter,
                 )
                 writer.add_scalar(
-                    "1.Total_reward/2.Mean_value",
-                    info["mean_value"],
-                    counter,
+                    "1.Total_reward/2.Mean_value", info["mean_value"], counter,
                 )
                 writer.add_scalar(
-                    "1.Total_reward/3.Episode_length",
-                    info["episode_length"],
-                    counter,
+                    "1.Total_reward/3.Episode_length", info["episode_length"], counter,
                 )
                 writer.add_scalar(
-                    "1.Total_reward/4.MuZero_reward",
-                    info["muzero_reward"],
-                    counter,
+                    "1.Total_reward/4.MuZero_reward", info["muzero_reward"], counter,
                 )
                 writer.add_scalar(
                     "1.Total_reward/5.Opponent_reward",
                     info["opponent_reward"],
                     counter,
                 )
                 writer.add_scalar(
-                    "2.Workers/1.Self_played_games",
-                    info["num_played_games"],
-                    counter,
+                    "2.Workers/1.Self_played_games", info["num_played_games"], counter,
                 )
                 writer.add_scalar(
                     "2.Workers/2.Training_steps", info["training_step"], counter
@@ -330,7 +313,12 @@ def logging_loop(self, num_gpus):
             # Persist replay buffer to disk
             print("\n\nPersisting replay buffer games to disk...")
             pickle.dump(
-                self.replay_buffer,
+                {
+                    "buffer": self.replay_buffer,
+                    "num_played_games": self.checkpoint["num_played_games"],
+                    "num_played_steps": self.checkpoint["num_played_steps"],
+                    "num_reanalysed_games": self.checkpoint["num_reanalysed_games"],
+                },
                 open(os.path.join(self.config.results_path, "replay_buffer.pkl"), "wb"),
             )

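With this commit, replay_buffer.pkl no longer contains the raw replay buffer alone: it holds a dict bundling the buffer with the self-play counters taken from the checkpoint. A minimal sketch of inspecting such a file after training (the path is illustrative; the real file is written under config.results_path):

    import pickle

    # Illustrative path; adapt to your own results directory.
    with open("results/cartpole/replay_buffer.pkl", "rb") as f:
        saved = pickle.load(f)

    print(saved["num_played_games"], "games,", saved["num_played_steps"], "steps played")
    print(saved["num_reanalysed_games"], "reanalysed games")
    replay_buffer = saved["buffer"]  # the buffer object itself, unchanged in content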
@@ -378,20 +366,15 @@ def test(
         opponent = opponent if opponent else self.config.opponent
         muzero_player = muzero_player if muzero_player else self.config.muzero_player
         self_play_worker = self_play.SelfPlay.options(
-            num_cpus=0,
-            num_gpus=num_gpus,
+            num_cpus=0, num_gpus=num_gpus,
         ).remote(self.checkpoint, self.Game, self.config, numpy.random.randint(10000))
         results = []
         for i in range(num_tests):
             print(f"Testing {i+1}/{num_tests}")
             results.append(
                 ray.get(
                     self_play_worker.play_game.remote(
-                        0,
-                        0,
-                        render,
-                        opponent,
-                        muzero_player,
+                        0, 0, render, opponent, muzero_player,
                     )
                 )
             )
@@ -433,7 +416,18 @@ def load_model(self, checkpoint_path=None, replay_buffer_path=None):
         if replay_buffer_path:
             if os.path.exists(replay_buffer_path):
                 with open(replay_buffer_path, "rb") as f:
-                    self.replay_buffer = pickle.load(f)
+                    replay_buffer_infos = pickle.load(f)
+                self.replay_buffer = replay_buffer_infos["buffer"]
+                self.checkpoint["num_played_steps"] = replay_buffer_infos[
+                    "num_played_steps"
+                ]
+                self.checkpoint["num_played_games"] = replay_buffer_infos[
+                    "num_played_games"
+                ]
+                self.checkpoint["num_reanalysed_games"] = replay_buffer_infos[
+                    "num_reanalysed_games"
+                ]
+
                 print(f"\nInitializing replay buffer with {replay_buffer_path}")
             else:
                 print(
@@ -593,8 +587,7 @@ def load_model_menu(muzero, game_name):
         replay_buffer_path = f"{options[choice]}replay_buffer.pkl"

     muzero.load_model(
-        checkpoint_path=checkpoint_path,
-        replay_buffer_path=replay_buffer_path,
+        checkpoint_path=checkpoint_path, replay_buffer_path=replay_buffer_path,
     )



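Taken together, the dump side (logging_loop) and the load side (load_model) let a run be resumed with its game and step counters intact rather than restarting them at zero. A hedged usage sketch, assuming file names consistent with the rest of this repository (replay_buffer.pkl as above, model.checkpoint for the weights); both paths below are illustrative:

    from muzero import MuZero

    muzero = MuZero("cartpole")  # any game with a config under games/
    muzero.load_model(
        checkpoint_path="results/cartpole/model.checkpoint",      # illustrative path
        replay_buffer_path="results/cartpole/replay_buffer.pkl",  # illustrative path
    )
    muzero.train()  # resumes with the restored num_played_games / num_played_steps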