Abstract out rl driver
YiwenAI committed May 10, 2023
1 parent 01bd621 commit 8607c2d
Showing 3 changed files with 143 additions and 108 deletions.
108 changes: 1 addition & 107 deletions openrl/drivers/onpolicy_driver.py
@@ -37,74 +37,7 @@ def __init__(
client=None,
logger: Optional[Logger] = None,
) -> None:
self.trainer = trainer
self.buffer = buffer
self.learner_episode = -1
self.actor_id = 0
self.weight_ids = [0]
self.world_size = world_size
self.logger = logger
cfg = config["cfg"]
self.program_type = cfg.program_type
self.envs = config["envs"]
self.device = config["device"]

assert not (
self.program_type != "actor" and self.world_size is None
), "world size can not be none, get {}".format(world_size)

self.num_agents = config["num_agents"]
assert isinstance(rank, int), "rank must be int, but get {}".format(rank)
self.rank = rank

# for distributed learning
assert not (
world_size is None and self.program_type == "learner"
), "world_size must be int, but get {}".format(world_size)

# parameters
self.env_name = cfg.env_name
self.algorithm_name = cfg.algorithm_name
self.experiment_name = cfg.experiment_name

self.num_env_steps = cfg.num_env_steps
self.episode_length = cfg.episode_length
self.n_rollout_threads = cfg.n_rollout_threads
self.learner_n_rollout_threads = cfg.learner_n_rollout_threads
self.n_eval_rollout_threads = cfg.n_eval_rollout_threads
self.n_render_rollout_threads = cfg.n_render_rollout_threads
self.use_linear_lr_decay = cfg.use_linear_lr_decay
self.hidden_size = cfg.hidden_size
self.use_wandb = not cfg.disable_wandb
self.use_single_network = cfg.use_single_network
self.use_render = cfg.use_render
self.use_transmit = cfg.use_transmit
self.recurrent_N = cfg.recurrent_N
self.only_eval = cfg.only_eval
self.save_interval = cfg.save_interval
self.use_eval = cfg.use_eval
self.eval_interval = cfg.eval_interval
self.log_interval = cfg.log_interval

self.distributed_type = cfg.distributed_type

self.actor_num = cfg.actor_num

if self.distributed_type == "async" and self.program_type == "whole":
print("can't use async mode when program_type is whole!")
exit()

if self.program_type in ["whole", "local"]:
assert self.actor_num == 1, (
"when running actor and learner the same time, the actor number should"
" be 1, but received {}".format(self.actor_num)
)
# dir
self.model_dir = cfg.model_dir
if hasattr(cfg, "save_dir"):
self.save_dir = cfg.save_dir

self.cfg = cfg
super(OnPolicyDriver, self).__init__(config, trainer, buffer, rank, world_size, client, logger)

def _inner_loop(
self,
@@ -122,18 +55,6 @@ def _inner_loop(
self.logger.log_info(rollout_infos, step=self.total_num_steps)
self.logger.log_info(train_infos, step=self.total_num_steps)

def reset_and_buffer_init(self):
returns = self.envs.reset()
if isinstance(returns, tuple):
assert (
len(returns) == 2
), "length of env reset returns must be 2, but get {}".format(len(returns))
obs, info = returns
else:
obs = returns

self.buffer.init_buffer(obs.copy())

def add2buffer(self, data):
(
obs,
@@ -211,33 +132,6 @@ def actor_rollout(self):
else:
return batch_rew_infos

def run(self) -> None:
episodes = (
int(self.num_env_steps)
// self.episode_length
// self.learner_n_rollout_threads
)
self.episodes = episodes

self.reset_and_buffer_init()

for episode in range(episodes):
self.logger.info("Episode: {}/{}".format(episode, episodes))
self.episode = episode
self._inner_loop()

def learner_update(self):
if self.use_linear_lr_decay:
self.trainer.algo_module.lr_decay(self.episode, self.episodes)

self.compute_returns()

self.trainer.prep_training()

train_infos = self.trainer.train(self.buffer.data)

return train_infos

@torch.no_grad()
def compute_returns(self):
self.trainer.prep_rollout()
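For orientation, here is a minimal sketch of how OnPolicyDriver.__init__ reads after this change, reconstructed from the hunks above and from RLDriver's signature below; it is not copied verbatim from the repository, and the parameter defaults are assumed to mirror RLDriver's. All of the shared setup now lives in the base class, so the subclass only forwards its arguments:

# Reconstructed sketch (not verbatim repository code): OnPolicyDriver after the refactor.
from typing import Any, Dict, Optional

from openrl.drivers.rl_driver import RLDriver  # base class introduced by this commit
from openrl.utils.logger import Logger


class OnPolicyDriver(RLDriver):
    def __init__(
        self,
        config: Dict[str, Any],
        trainer,
        buffer,
        rank: int = 0,
        world_size: int = 1,
        client=None,
        logger: Optional[Logger] = None,
    ) -> None:
        # Every attribute previously assigned here (trainer, buffer, cfg fields, ...)
        # is now set in RLDriver.__init__.
        super(OnPolicyDriver, self).__init__(
            config, trainer, buffer, rank, world_size, client, logger
        )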
141 changes: 140 additions & 1 deletion openrl/drivers/rl_driver.py
@@ -15,12 +15,151 @@
# limitations under the License.

""""""
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from openrl.drivers.base_driver import BaseDriver
from openrl.utils.logger import Logger


class RLDriver(BaseDriver, ABC):
def __init__(
self,
config: Dict[str, Any],
trainer,
buffer,
rank: int = 0,
world_size: int = 1,
client=None,
logger: Optional[Logger] = None,
) -> None:
self.trainer = trainer
self.buffer = buffer
self.learner_episode = -1
self.actor_id = 0
self.weight_ids = [0]
self.world_size = world_size
self.logger = logger
cfg = config["cfg"]
self.program_type = cfg.program_type
self.envs = config["envs"]
self.device = config["device"]

assert not (
self.program_type != "actor" and self.world_size is None
), "world size can not be none, get {}".format(world_size)

self.num_agents = config["num_agents"]
assert isinstance(rank, int), "rank must be int, but get {}".format(rank)
self.rank = rank

# for distributed learning
assert not (
world_size is None and self.program_type == "learner"
), "world_size must be int, but get {}".format(world_size)

# parameters
self.env_name = cfg.env_name
self.algorithm_name = cfg.algorithm_name
self.experiment_name = cfg.experiment_name

self.num_env_steps = cfg.num_env_steps
self.episode_length = cfg.episode_length
self.n_rollout_threads = cfg.n_rollout_threads
self.learner_n_rollout_threads = cfg.learner_n_rollout_threads
self.n_eval_rollout_threads = cfg.n_eval_rollout_threads
self.n_render_rollout_threads = cfg.n_render_rollout_threads
self.use_linear_lr_decay = cfg.use_linear_lr_decay
self.hidden_size = cfg.hidden_size
self.use_wandb = not cfg.disable_wandb
self.use_single_network = cfg.use_single_network
self.use_render = cfg.use_render
self.use_transmit = cfg.use_transmit
self.recurrent_N = cfg.recurrent_N
self.only_eval = cfg.only_eval
self.save_interval = cfg.save_interval
self.use_eval = cfg.use_eval
self.eval_interval = cfg.eval_interval
self.log_interval = cfg.log_interval

self.distributed_type = cfg.distributed_type

self.actor_num = cfg.actor_num

if self.distributed_type == "async" and self.program_type == "whole":
print("can't use async mode when program_type is whole!")
exit()

if self.program_type in ["whole", "local"]:
assert self.actor_num == 1, (
"when running actor and learner the same time, the actor number should"
" be 1, but received {}".format(self.actor_num)
)
# dir
self.model_dir = cfg.model_dir
if hasattr(cfg, "save_dir"):
self.save_dir = cfg.save_dir

self.cfg = cfg

@abstractmethod
def _inner_loop(self):
raise NotImplementedError

def reset_and_buffer_init(self):
returns = self.envs.reset()
if isinstance(returns, tuple):
assert (
len(returns) == 2
), "length of env reset returns must be 2, but get {}".format(len(returns))
obs, info = returns
else:
obs = returns

self.buffer.init_buffer(obs.copy())

@abstractmethod
def add2buffer(self, data):
raise NotImplementedError

@abstractmethod
def actor_rollout(self):
raise NotImplementedError

def run(self) -> None:
episodes = (
int(self.num_env_steps)
// self.episode_length
// self.learner_n_rollout_threads
)
self.episodes = episodes

self.reset_and_buffer_init()

for episode in range(episodes):
self.logger.info("Episode: {}/{}".format(episode, episodes))
self.episode = episode
self._inner_loop()

def learner_update(self):
if self.use_linear_lr_decay:
self.trainer.algo_module.lr_decay(self.episode, self.episodes)

self.compute_returns()

self.trainer.prep_training()

train_infos = self.trainer.train(self.buffer.data)

return train_infos

@abstractmethod
def compute_returns(self):
raise NotImplementedError

@abstractmethod
def act(
self,
step: int,
):
raise NotImplementedError
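The new RLDriver base class follows a template-method layout: run() and learner_update() stay concrete, while the environment-facing hooks (_inner_loop, add2buffer, actor_rollout, compute_returns, act) are declared abstract. As a hedged illustration of the contract this imposes, a hypothetical custom driver would look roughly like the sketch below; only the method names and signatures come from the diff, while the class name and all bodies are placeholders.

# Hypothetical subclass, not part of this commit: it shows which hooks a
# concrete driver must now implement on top of RLDriver.
import torch

from openrl.drivers.rl_driver import RLDriver


class MyDriver(RLDriver):
    def _inner_loop(self):
        # one iteration of run(): collect a rollout, then train on the buffer
        rollout_infos = self.actor_rollout()
        train_infos = self.learner_update()  # concrete helper inherited from RLDriver
        self.logger.info("rollout: {}, train: {}".format(rollout_infos, train_infos))

    def add2buffer(self, data):
        # unpack one environment step and insert it into self.buffer
        ...

    def actor_rollout(self):
        # step self.envs for self.episode_length steps, calling self.act(step)
        # and self.add2buffer(...) along the way; return rollout statistics
        ...

    @torch.no_grad()
    def compute_returns(self):
        # bootstrap values and compute returns over the data collected in self.buffer
        ...

    def act(self, step: int):
        # query the policy held by self.trainer for actions at buffer index `step`
        ...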
2 changes: 2 additions & 0 deletions openrl/runners/common/rl_agent.py
@@ -23,6 +23,7 @@
import numpy as np
import torch

from abc import abstractmethod
from openrl.buffers.utils.obs_data import ObsData
from openrl.runners.common.base_agent import BaseAgent, SelfAgent
from openrl.utils.util import _t2n
@@ -82,6 +83,7 @@ def __init__(
else:
self.exp_name = self._cfg.experiment_name

@abstractmethod
def train(self: SelfAgent, total_time_steps: int) -> None:
raise NotImplementedError

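rl_agent.py gets the same treatment: train() is now marked @abstractmethod, so every concrete agent has to override it. A minimal hedged sketch follows, assuming the base class defined in this file is RLAgent (the class name is not visible in the hunk) and that the subclass also satisfies any other abstract methods the base may declare.

# Illustrative only: the train() signature comes from the diff above,
# everything else here is a placeholder.
from openrl.runners.common.rl_agent import RLAgent  # assumed class name


class MyAgent(RLAgent):
    def train(self, total_time_steps: int) -> None:
        # construct a driver and run it for the requested number of steps
        ...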
