support loading stable-baselines3 models from Hugging Face
huangshiyu13 authored Sep 20, 2023
2 parents c44eb6f + 25c9c3c commit 8797b09
Showing 13 changed files with 395 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -160,3 +160,4 @@ opponent_pool
wandb_run
examples/dmc/new.gif
/examples/snake/submissions/rl/actor_2000.pth
/examples/sb3/ppo-CartPole-v1/
1 change: 1 addition & 0 deletions examples/cartpole/train_a2c.py
@@ -58,6 +58,7 @@ def evaluation():
            action, _ = agent.act(obs, deterministic=True)
            obs, r, done, info = env.step(action)
            total_step += 1
            total_reward += np.mean(r)
            if total_step % 50 == 0:
                print(f"{total_step}: reward:{np.mean(r)}")
        env.close()
28 changes: 28 additions & 0 deletions examples/sb3/README.md
@@ -0,0 +1,28 @@
Load and use stable-baselines3 models from Hugging Face.

## Installation

```bash
pip install huggingface-tool
pip install rl_zoo3
```

## Download the sb3 model from Hugging Face

```bash
htool save-repo sb3/ppo-CartPole-v1 ppo-CartPole-v1
```
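
Optionally, you can sanity-check the downloaded checkpoint with stable-baselines3 itself before loading it through OpenRL. This is a minimal sketch, assuming stable-baselines3 is available (it is installed, e.g., as a dependency of rl_zoo3):

```python
import numpy as np
from stable_baselines3 import PPO

# Load the downloaded checkpoint directly in SB3 and query one greedy action
# for a dummy CartPole observation (shape (4,)).
model = PPO.load("ppo-CartPole-v1/ppo-CartPole-v1.zip")
action, _ = model.predict(np.zeros(4, dtype=np.float32), deterministic=True)
print(action)
```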

## Use OpenRL to load and evaluate the sb3-trained model

```bash
python test_model.py
```

## Use OpenRL to load and continue training the sb3-trained model

```bash
python train_ppo.py
```


25 changes: 25 additions & 0 deletions examples/sb3/ppo.yaml
@@ -0,0 +1,25 @@
use_share_model: true
sb3_model_path: ppo-CartPole-v1/ppo-CartPole-v1.zip
sb3_algo: ppo
entropy_coef: 0.0
gae_lambda: 0.8
gamma: 0.98
lr: 0.001
episode_length: 32
ppo_epoch: 20
log_interval: 20
log_each_episode: False

callbacks:
  - id: "EvalCallback"
    args: {
      "eval_env": { "id": "CartPole-v1", "env_num": 5 }, # how many envs to set up for evaluation
      "n_eval_episodes": 20, # how many episodes to run for each evaluation
      "eval_freq": 500, # how often to run evaluation
      "log_path": "./results/eval_log_path", # where to save the evaluation results
      "best_model_save_path": "./results/best_model/", # where to save the best model
      "deterministic": True, # whether to use deterministic actions
      "render": False, # whether to render the env
      "asynchronous": True, # whether to run evaluation asynchronously
      "stop_logic": "OR", # the logic to stop training: OR means training stops when any one of the conditions is met, AND means training stops when all conditions are met
    }
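
The `sb3_model_path` and `sb3_algo` entries map onto the new command-line options added to `openrl/configs/config.py` below. As a minimal, illustrative sketch of how the values surface after parsing (assuming it is run from `examples/sb3/`):

```python
from openrl.configs.config import create_config_parser

# Parse the YAML above; its top-level keys become attributes on the config object.
cfg = create_config_parser().parse_args(["--config", "ppo.yaml"])
print(cfg.sb3_model_path)  # ppo-CartPole-v1/ppo-CartPole-v1.zip
print(cfg.sb3_algo)        # ppo
```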
78 changes: 78 additions & 0 deletions examples/sb3/test_model.py
@@ -0,0 +1,78 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""

# Use OpenRL to load a stable-baselines3 model for testing

import numpy as np
import torch

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common.ppo_net import PPONet as Net
from openrl.modules.networks.policy_value_network_sb3 import (
PolicyValueNetworkSB3 as PolicyValueNetwork,
)
from openrl.runners.common import PPOAgent as Agent


def evaluation(local_trained_file_path=None):
    # begin to test

    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])

    # Create an environment for testing with 9 parallel environments.
    # Set render_mode to "group_human" to visualize; it is overridden to None here to disable rendering.
    render_mode = "group_human"
    render_mode = None
    env = make("CartPole-v1", render_mode=render_mode, env_num=9, asynchronous=True)
    model_dict = {"model": PolicyValueNetwork}
    net = Net(
        env,
        cfg=cfg,
        model_dict=model_dict,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    # initialize the agent
    agent = Agent(
        net,
    )
    if local_trained_file_path is not None:
        agent.load(local_trained_file_path)
    # Attach the interaction environment to the trained agent.
    agent.set_env(env)
    # Reset the environment and get the initial observations and environment info.
    obs, info = env.reset()
    done = False

    total_step = 0
    total_reward = 0.0
    while not np.any(done):
        # Predict the next action from the current observation.
        action, _ = agent.act(obs, deterministic=True)
        obs, r, done, info = env.step(action)
        total_step += 1
        total_reward += np.mean(r)
        if total_step % 50 == 0:
            print(f"{total_step}: reward:{np.mean(r)}")
    env.close()
    print("total step:", total_step)
    print("total reward:", total_reward)


if __name__ == "__main__":
    evaluation()
57 changes: 57 additions & 0 deletions examples/sb3/train_ppo.py
@@ -0,0 +1,57 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
import numpy as np
import torch
from test_model import evaluation

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common.ppo_net import PPONet as Net
from openrl.modules.networks.policy_value_network_sb3 import (
PolicyValueNetworkSB3 as PolicyValueNetwork,
)
from openrl.runners.common import PPOAgent as Agent


def train_agent():
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])

    env = make("CartPole-v1", env_num=8, asynchronous=True)

    model_dict = {"model": PolicyValueNetwork}
    net = Net(
        env,
        cfg=cfg,
        model_dict=model_dict,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    # initialize the trainer
    agent = Agent(net)
    # start training, with the total number of training steps set to 100,000
    agent.train(total_time_steps=100000)
    env.close()

    agent.save("./ppo_sb3_agent")


if __name__ == "__main__":
    train_agent()
    evaluation(local_trained_file_path="./ppo_sb3_agent")
5 changes: 4 additions & 1 deletion openrl/algorithms/ppo.py
@@ -196,7 +196,8 @@ def cal_value_loss(
            ).sum() / active_masks_batch.sum()
        else:
            value_loss = value_loss.mean()

        # print(value_loss)
        # import pdb;pdb.set_trace()
        return value_loss

    def to_single_np(self, input):
@@ -209,8 +210,10 @@ def construct_loss_list(self, policy_loss, dist_entropy, value_loss, turn_on):
        final_p_loss = policy_loss - dist_entropy * self.entropy_coef

        loss_list.append(final_p_loss)

        final_v_loss = value_loss * self.value_loss_coef
        loss_list.append(final_v_loss)

        return loss_list

    def prepare_loss(
20 changes: 20 additions & 0 deletions openrl/configs/config.py
@@ -40,6 +40,20 @@ def create_config_parser():

    parser.add_argument("--callbacks", type=List[dict])

    # For Stable-baselines3
    parser.add_argument(
        "--sb3_model_path",
        type=str,
        default=None,
        help="stable-baselines3 model path",
    )
    parser.add_argument(
        "--sb3_algo",
        type=str,
        default=None,
        help="stable-baselines3 algorithm",
    )

    # For Hierarchical RL
    parser.add_argument(
        "--step_difference",
@@ -811,6 +825,12 @@ def create_config_parser():
        default=5,
        help="time duration between two consecutive log printings.",
    )
    parser.add_argument(
        "--log_each_episode",
        type=bool,
        default=True,
        help="Whether to log each episode number.",
    )
    parser.add_argument(
        "--use_rich_handler",
        type=bool,
1 change: 1 addition & 0 deletions openrl/drivers/onpolicy_driver.py
@@ -258,6 +258,7 @@ def act(
            values = np.zeros([self.n_rollout_threads, self.num_agents, 1])
        else:
            values = np.array(np.split(_t2n(value), self.n_rollout_threads))

        actions = np.array(np.split(_t2n(action), self.n_rollout_threads))
        action_log_probs = np.array(
            np.split(_t2n(action_log_prob), self.n_rollout_threads)
3 changes: 2 additions & 1 deletion openrl/drivers/rl_driver.py
@@ -149,7 +149,8 @@ def run(self) -> None:
        self.reset_and_buffer_init()
        self.real_step = 0
        for episode in range(episodes):
            self.logger.info("Episode: {}/{}".format(episode, episodes))
            if self.cfg.log_each_episode:
                self.logger.info("Episode: {}/{}".format(episode, episodes))
            self.episode = episode
            continue_training = self._inner_loop()
            if not continue_training: