Added get/set_param functions to RL algos to support fast pbt (#251)
* Added get/set_param functions to support fast PBT without restarting training from scratch (see the usage sketch below).
* Updated release notes. Fixed SAC checkpoint saving with save_frequency.
* Fixed SAC weight loading crash.
ViktorM authored Sep 22, 2023
1 parent 990b478 commit 66ce12f
Showing 12 changed files with 152 additions and 59 deletions.
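
For context, a minimal sketch of how a PBT controller could drive the hooks added in this commit. Only restore(fn, set_epoch=...), get_param() and set_param() come from this change; the agent object, the checkpoint path and the 'lr' parameter name are illustrative assumptions (SAC leaves get_param/set_param as stubs in this commit).

def exploit_and_explore(agent, better_checkpoint, perturb=1.2):
    # Copy weights and optimizer state from a stronger population member,
    # but keep this worker's epoch/frame counters so training does not restart.
    agent.restore(better_checkpoint, set_epoch=False)

    # Mutate a hyperparameter in place through the new get/set API.
    lr = agent.get_param('lr')          # 'lr' is an assumed parameter name
    agent.set_param('lr', lr * perturb)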
2 changes: 2 additions & 0 deletions README.md
@@ -303,6 +303,8 @@ Additional environment supported properties and functions
* Fixed bug with SAC not saving weights with save_frequency.
* Added multi-node training support for GPU-accelerated training environments like Isaac Gym. No changes in training scripts are required. Thanks to @ankurhanda and @ArthurAllshire for assistance in implementation.
* Added evaluation feature for inferencing during training. Checkpoints from training process can be automatically picked up and updated in the inferencing process when enabled.
* Added get/set API for runtime update of rl training parameters. Thanks to @ArthurAllshire for the initial version of fast PBT code.
* Fixed SAC not loading weights properly.

1.6.0

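A short usage note on the set_epoch flag referenced in the release entries above (agent and the checkpoint paths are placeholders; only the keyword argument itself comes from this commit):

agent.restore('runs/ant/nn/last_ant.pth')                    # resume training: epoch/frame are taken from the checkpoint
agent.restore('runs/ant/nn/winner.pth', set_epoch=False)     # PBT-style weight swap: keep the current epoch/frame
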
9 changes: 4 additions & 5 deletions rl_games/algos_torch/a2c_continuous.py
@@ -7,11 +7,10 @@

from torch import optim
import torch
from torch import nn
import numpy as np
import gym


class A2CAgent(a2c_common.ContinuousA2CBase):

def __init__(self, base_name, params):
a2c_common.ContinuousA2CBase.__init__(self, base_name, params)
obs_shape = self.obs_shape
@@ -68,9 +67,9 @@ def save(self, fn):
state = self.get_full_state_weights()
torch_ext.save_checkpoint(fn, state)

def restore(self, fn):
def restore(self, fn, set_epoch=True):
checkpoint = torch_ext.load_checkpoint(fn)
self.set_full_state_weights(checkpoint)
self.set_full_state_weights(checkpoint, set_epoch=set_epoch)

def get_masked_action_values(self, obs, action_masks):
assert False
4 changes: 2 additions & 2 deletions rl_games/algos_torch/a2c_discrete.py
@@ -70,9 +70,9 @@ def save(self, fn):
state = self.get_full_state_weights()
torch_ext.save_checkpoint(fn, state)

def restore(self, fn):
def restore(self, fn, set_epoch=True):
checkpoint = torch_ext.load_checkpoint(fn)
self.set_full_state_weights(checkpoint)
self.set_full_state_weights(checkpoint, set_epoch=set_epoch)

def get_masked_action_values(self, obs, action_masks):
processed_obs = self._preproc_obs(obs['obs'])
4 changes: 4 additions & 0 deletions rl_games/algos_torch/players.py
@@ -16,6 +16,7 @@ def rescale_actions(low, high, action):


class PpoPlayerContinuous(BasePlayer):

def __init__(self, params):
BasePlayer.__init__(self, params)
self.network = self.config['network']
@@ -81,7 +82,9 @@ def restore(self, fn):
def reset(self):
self.init_rnn()


class PpoPlayerDiscrete(BasePlayer):

def __init__(self, params):
BasePlayer.__init__(self, params)

@@ -185,6 +188,7 @@ def reset(self):


class SACPlayer(BasePlayer):

def __init__(self, params):
BasePlayer.__init__(self, params)
self.network = self.config['network']
89 changes: 57 additions & 32 deletions rl_games/algos_torch/sac_agent.py
@@ -38,6 +38,7 @@ def __init__(self, base_name, params):
self.num_steps_per_episode = config.get("num_steps_per_episode", 1)
self.normalize_input = config.get("normalize_input", False)

# TODO: double-check! To use bootstrap instead?
self.max_env_steps = config.get("max_env_steps", 1000) # temporary, in future we will use other approach

print(self.batch_size, self.num_actors, self.num_agents)
@@ -60,7 +61,6 @@ def __init__(self, base_name, params):
'action_dim': self.env_info["action_space"].shape[0],
'actions_num' : self.actions_num,
'input_shape' : obs_shape,
'normalize_input' : self.normalize_input,
'normalize_input': self.normalize_input,
}
self.model = self.network.build(net_config)
@@ -88,12 +88,8 @@ def __init__(self, base_name, params):
self.target_entropy = self.target_entropy_coef * -self.env_info['action_space'].shape[0]
print("Target entropy", self.target_entropy)

self.step = 0
self.algo_observer = config['features']['observer']

# TODO: Is there a better way to get the maximum number of episodes?
self.max_episodes = torch.ones(self.num_actors, device=self._device)*self.num_steps_per_episode

def load_networks(self, params):
builder = model_builder.ModelBuilder()
self.config['network'] = builder.load(params)
@@ -133,6 +129,8 @@ def base_init(self, base_name, config):
self.max_epochs = self.config.get('max_epochs', -1)
self.max_frames = self.config.get('max_frames', -1)

self.save_freq = config.get('save_frequency', 0)

self.network = config['network']
self.rewards_shaper = config['reward_shaper']
self.num_agents = self.env_info.get('agents', 1)
@@ -146,10 +144,10 @@ def base_init(self, base_name, config):
self.min_alpha = torch.tensor(np.log(1)).float().to(self._device)

self.frame = 0
self.epoch_num = 0
self.update_time = 0
self.last_mean_rewards = -100500
self.last_mean_rewards = -1000000000
self.play_time = 0
self.epoch_num = 0

# TODO: put it into the separate class
pbt_str = ''
@@ -205,17 +203,8 @@ def alpha(self):
def device(self):
return self._device

def get_full_state_weights(self):
state = self.get_weights()

state['steps'] = self.step
state['actor_optimizer'] = self.actor_optimizer.state_dict()
state['critic_optimizer'] = self.critic_optimizer.state_dict()
state['log_alpha_optimizer'] = self.log_alpha_optimizer.state_dict()

return state

def get_weights(self):
print("Loading weights")
state = {'actor': self.model.sac_network.actor.state_dict(),
'critic': self.model.sac_network.critic.state_dict(),
'critic_target': self.model.sac_network.critic_target.state_dict()}
@@ -233,17 +222,45 @@ def set_weights(self, weights):
if self.normalize_input and 'running_mean_std' in weights:
self.model.running_mean_std.load_state_dict(weights['running_mean_std'])

def set_full_state_weights(self, weights):
def get_full_state_weights(self):
print("Loading full weights")
state = self.get_weights()

state['epoch'] = self.epoch_num
state['frame'] = self.frame
state['actor_optimizer'] = self.actor_optimizer.state_dict()
state['critic_optimizer'] = self.critic_optimizer.state_dict()
state['log_alpha_optimizer'] = self.log_alpha_optimizer.state_dict()

return state

def set_full_state_weights(self, weights, set_epoch=True):
self.set_weights(weights)

self.step = weights['step']
if set_epoch:
self.epoch_num = weights['epoch']
self.frame = weights['frame']

self.actor_optimizer.load_state_dict(weights['actor_optimizer'])
self.critic_optimizer.load_state_dict(weights['critic_optimizer'])
self.log_alpha_optimizer.load_state_dict(weights['log_alpha_optimizer'])

def restore(self, fn):
self.last_mean_rewards = weights.get('last_mean_rewards', -1000000000)

if self.vec_env is not None:
env_state = weights.get('env_state', None)
self.vec_env.set_env_state(env_state)

def restore(self, fn, set_epoch=True):
print("SAC restore")
checkpoint = torch_ext.load_checkpoint(fn)
self.set_full_state_weights(checkpoint)
self.set_full_state_weights(checkpoint, set_epoch=set_epoch)

def get_param(self, param_name):
pass

def set_param(self, param_name, param_value):
pass

def get_masked_action_values(self, obs, action_masks):
assert False
@@ -334,6 +351,7 @@ def preproc_obs(self, obs):
if isinstance(obs, dict):
obs = obs['obs']
obs = self.model.norm_obs(obs)

return obs

def cast_obs(self, obs):
@@ -348,7 +366,7 @@ def cast_obs(self, obs):

return obs

# todo: move to common utils
# TODO: move to common utils
def obs_to_tensors(self, obs):
obs_is_dict = isinstance(obs, dict)
if obs_is_dict:
@@ -359,6 +377,7 @@ def obs_to_tensors(self, obs):
upd_obs = self.cast_obs(obs)
if not obs_is_dict or 'obs' not in obs:
upd_obs = {'obs' : upd_obs}

return upd_obs

def _obs_to_tensors_internal(self, obs):
@@ -368,18 +387,19 @@ def _obs_to_tensors_internal(self, obs):
upd_obs[key] = self._obs_to_tensors_internal(value)
else:
upd_obs = self.cast_obs(obs)

return upd_obs

def preprocess_actions(self, actions):
if not self.is_tensor_obses:
actions = actions.cpu().numpy()

return actions

def env_step(self, actions):
actions = self.preprocess_actions(actions)
obs, rewards, dones, infos = self.vec_env.step(actions) # (obs_space) -> (n, obs_space)

self.step += self.num_actors
if self.is_tensor_obses:
return self.obs_to_tensors(obs), rewards.to(self._device), dones.to(self._device), infos
else:
@@ -415,7 +435,7 @@ def extract_actor_stats(self, actor_losses, entropies, alphas, alpha_losses, act
def clear_stats(self):
self.game_rewards.clear()
self.game_lengths.clear()
self.mean_rewards = self.last_mean_rewards = -100500
self.mean_rewards = self.last_mean_rewards = -1000000000
self.algo_observer.after_clear_stats()

def play_steps(self, random_exploration = False):
@@ -431,6 +451,11 @@ def play_steps(self, random_exploration = False):
critic2_losses = []

obs = self.obs
if isinstance(obs, dict):
obs = self.obs['obs']

next_obs_processed = obs.clone()

for s in range(self.num_steps_per_episode):
self.set_eval()
if random_exploration:
@@ -466,16 +491,17 @@
self.current_rewards = self.current_rewards * not_dones
self.current_lengths = self.current_lengths * not_dones

if isinstance(obs, dict):
obs = obs['obs']
if isinstance(next_obs, dict):
next_obs = next_obs['obs']
next_obs_processed = next_obs['obs']

self.obs = next_obs.clone()

rewards = self.rewards_shaper(rewards)

self.replay_buffer.add(obs, action, torch.unsqueeze(rewards, 1), next_obs, torch.unsqueeze(dones, 1))
self.replay_buffer.add(obs, action, torch.unsqueeze(rewards, 1), next_obs_processed, torch.unsqueeze(dones, 1))

self.obs = obs = next_obs.clone()
if isinstance(obs, dict):
obs = self.obs['obs']

if not random_exploration:
self.set_train()
@@ -505,10 +531,9 @@ def train_epoch(self):
def train(self):
self.init_tensors()
self.algo_observer.after_init(self)
self.last_mean_rewards = -100500
total_time = 0
# rep_count = 0
self.frame = 0

self.obs = self.env_reset()

while True:
@@ -560,7 +585,7 @@ def train(self):
should_exit = False

if self.save_freq > 0:
if (self.epoch_num % self.save_freq == 0) and (mean_rewards[0] <= self.last_mean_rewards):
if self.epoch_num % self.save_freq == 0:
self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))

if mean_rewards > self.last_mean_rewards and self.epoch_num >= self.save_best_after:
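
To summarize the new SAC checkpoint flow above, a minimal in-memory round trip (src_agent and dst_agent are assumed to be two constructed SAC agents; everything else mirrors the diff):

state = src_agent.get_full_state_weights()
# state holds 'actor', 'critic', 'critic_target' (plus 'running_mean_std' when normalize_input is on),
# and now also 'epoch', 'frame' and the three optimizer state dicts.
dst_agent.set_full_state_weights(state, set_epoch=False)  # copy weights/optimizers, keep dst_agent's counters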