[RLlib] New API stack on by default for PPO. #48284

Merged · 23 commits · Nov 3, 2024
10 changes: 9 additions & 1 deletion doc/source/rllib/doc_code/checkpoints.py
@@ -9,7 +9,15 @@


# Base config used for both pickle-based checkpoint and msgpack-based one.
config = PPOConfig().environment("CartPole-v1").env_runners(num_env_runners=0)
config = (
PPOConfig()
.api_stack(
enable_rl_module_and_learner=False,
enable_env_runner_and_connector_v2=False,
)
.environment("CartPole-v1")
.env_runners(num_env_runners=0)
)
# Build algorithm object.
algo1 = config.build()

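With this PR, PPO builds on the new API stack by default, so examples that still exercise the old Policy/pickle-based checkpoints (like the one above) opt out explicitly. A minimal sketch of both directions, assuming the post-PR defaults:

from ray.rllib.algorithms.ppo import PPOConfig

# New default after this PR: no api_stack() call needed; PPO uses
# RLModule/Learner plus EnvRunner/ConnectorV2 out of the box.
new_stack_config = PPOConfig().environment("CartPole-v1")

# Explicit opt-out, as in the checkpoint example above, for code that still
# relies on the old Policy-based (pickle) checkpoint format.
old_stack_config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=False,
        enable_env_runner_and_connector_v2=False,
    )
    .environment("CartPole-v1")
)
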
19 changes: 10 additions & 9 deletions doc/source/rllib/doc_code/custom_gym_env.py
@@ -1,5 +1,6 @@
# __rllib-custom-gym-env-begin__
import gymnasium as gym
import numpy as np

import ray
from ray.rllib.algorithms.ppo import PPOConfig
@@ -8,23 +9,23 @@
class SimpleCorridor(gym.Env):
def __init__(self, config):
self.end_pos = config["corridor_length"]
self.cur_pos = 0
self.cur_pos = 0.0
self.action_space = gym.spaces.Discrete(2) # right/left
self.observation_space = gym.spaces.Discrete(self.end_pos)
self.observation_space = gym.spaces.Box(0.0, self.end_pos, shape=(1,))

def reset(self, *, seed=None, options=None):
self.cur_pos = 0
return self.cur_pos, {}
self.cur_pos = 0.0
return np.array([self.cur_pos]), {}

def step(self, action):
if action == 0 and self.cur_pos > 0: # move right (towards goal)
self.cur_pos -= 1
if action == 0 and self.cur_pos > 0.0: # move right (towards goal)
self.cur_pos -= 1.0
elif action == 1: # move left (towards start)
self.cur_pos += 1
self.cur_pos += 1.0
if self.cur_pos >= self.end_pos:
return 0, 1.0, True, True, {}
return np.array([0.0]), 1.0, True, True, {}
else:
return self.cur_pos, -0.1, False, False, {}
return np.array([self.cur_pos]), -0.1, False, False, {}


ray.init()
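The corridor env now returns np.array observations inside a Box space instead of raw ints from a Discrete space, which is the observation format the new EnvRunner/connector stack expects here. A quick rollout sketch, assuming SimpleCorridor is defined as in the file above and using a hypothetical corridor length of 5:

import numpy as np

# Assumes the SimpleCorridor class from doc/source/rllib/doc_code/custom_gym_env.py above.
env = SimpleCorridor({"corridor_length": 5})
obs, info = env.reset()
assert isinstance(obs, np.ndarray) and obs.shape == (1,)

total_reward = 0.0
terminated = truncated = False
while not (terminated or truncated):
    # Action 1 increments the position until end_pos is reached.
    obs, reward, terminated, truncated, info = env.step(1)
    total_reward += reward

# Four -0.1 steps plus the final +1.0 step -> roughly 0.6.
print(total_reward)
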
26 changes: 15 additions & 11 deletions doc/source/rllib/doc_code/rllib_in_60s.py
@@ -2,20 +2,24 @@

# __rllib-in-60s-begin__
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.connectors.env_to_module import FlattenObservations

config = ( # 1. Configure the algorithm,
# 1. Configure the algorithm,
config = (
PPOConfig()
.environment(env="Taxi-v3")
.env_runners(num_env_runners=2)
.rl_module(model_config=DefaultModelConfig(fcnet_hiddens=[64, 64]))
.environment("Taxi-v3")
.env_runners(
num_env_runners=2,
# Observations are discrete (ints) -> We need to flatten (one-hot) them.
env_to_module_connector=lambda env: FlattenObservations(),
)
.evaluation(evaluation_num_env_runners=1)
)

algo = config.build() # 2. build the algorithm,

# 2. build the algorithm ..
algo = config.build()
# 3. .. train it ..
for _ in range(5):
print(algo.train()) # 3. train it,

algo.evaluate() # 4. and evaluate it.
print(algo.train())
# 4. .. and evaluate it.
algo.evaluate()
# __rllib-in-60s-end__
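
Taxi-v3 emits integer observations from a Discrete(500) space, which is why the config above inserts the FlattenObservations env-to-module connector: it one-hot encodes each observation before it reaches the default RLModule. A plain-numpy illustration of that transformation (not the connector's actual code):

import gymnasium as gym
import numpy as np

env = gym.make("Taxi-v3")
obs, _ = env.reset()  # A plain int in [0, 500).

# Roughly what FlattenObservations produces for a Discrete space:
one_hot = np.zeros(env.observation_space.n, dtype=np.float32)
one_hot[obs] = 1.0
print(one_hot.shape)  # (500,)
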
27 changes: 18 additions & 9 deletions doc/source/rllib/doc_code/rllib_on_ray_readme.py
@@ -1,5 +1,8 @@
# __quick_start_begin__
import gymnasium as gym
import numpy as np
import torch

from ray.rllib.algorithms.ppo import PPOConfig


@@ -19,19 +22,19 @@ class SimpleCorridor(gym.Env):

def __init__(self, config):
self.end_pos = config["corridor_length"]
self.cur_pos = 0
self.cur_pos = 0.0
self.action_space = gym.spaces.Discrete(2) # left and right
self.observation_space = gym.spaces.Box(0.0, self.end_pos, shape=(1,))
self.observation_space = gym.spaces.Box(0.0, self.end_pos, (1,), np.float32)

def reset(self, *, seed=None, options=None):
"""Resets the episode.
Returns:
Initial observation of the new episode and an info dict.
"""
self.cur_pos = 0
self.cur_pos = 0.0
# Return initial observation.
return [self.cur_pos], {}
return np.array([self.cur_pos], np.float32), {}

def step(self, action):
"""Takes a single step in the episode given `action`.
@@ -50,31 +53,32 @@ def step(self, action):
truncated = False
# +1 when goal reached, otherwise -1.
reward = 1.0 if terminated else -0.1
return [self.cur_pos], reward, terminated, truncated, {}
return np.array([self.cur_pos], np.float32), reward, terminated, truncated, {}


# Create an RLlib Algorithm instance from a PPOConfig object.
config = (
PPOConfig().environment(
# Env class to use (here: our gym.Env sub-class from above).
env=SimpleCorridor,
SimpleCorridor,
# Config dict to be passed to our custom env's constructor.
# Use corridor with 20 fields (including S and G).
env_config={"corridor_length": 28},
env_config={"corridor_length": 20},
)
# Parallelize environment rollouts.
.env_runners(num_env_runners=3)
)
# Construct the actual (PPO) algorithm object from the config.
algo = config.build()
rl_module = algo.get_module()

# Train for n iterations and report results (mean episode rewards).
# Since we have to move at least 19 times in the env to reach the goal and
# each move gives us -0.1 reward (except the last move at the end: +1.0),
# Expect to reach an optimal episode reward of `-0.1*18 + 1.0 = -0.8`.
for i in range(5):
results = algo.train()
print(f"Iter: {i}; avg. return={results['env_runners']['episode_return_mean']}")
print(f"Iter: {i}; avg. results={results['env_runners']}")

# Perform inference (action computations) based on given env observations.
# Note that we are using a slightly different env here (len 10 instead of 20),
@@ -89,7 +93,12 @@ def step(self, action):
while not terminated and not truncated:
# Compute a single action, given the current observation
# from the environment.
action = algo.compute_single_action(obs)
action_logits = rl_module.forward_inference(
{"obs": torch.from_numpy(obs).unsqueeze(0)}
)["action_dist_inputs"].numpy()[
0
] # [0]: B=1
action = np.argmax(action_logits)
# Apply the computed action in the environment.
obs, reward, terminated, truncated, info = env.step(action)
# Sum up rewards for reporting purposes.
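Inference in the README example now goes through the RLModule's forward_inference() call instead of the old-stack Algorithm.compute_single_action(). The pattern from the diff, wrapped into a small helper for readability (a sketch; the greedy argmax matches the deterministic inference above and assumes a discrete action space):

import numpy as np
import torch


def compute_greedy_action(rl_module, obs):
    # Add a batch dimension (B=1) and run the inference-only forward pass.
    batch = {"obs": torch.from_numpy(obs).unsqueeze(0)}
    logits = rl_module.forward_inference(batch)["action_dist_inputs"]
    # Drop the batch dimension again and pick the argmax action.
    return int(np.argmax(logits.numpy()[0]))


# Usage with the objects from the script above:
# action = compute_greedy_action(rl_module, obs)
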
52 changes: 17 additions & 35 deletions doc/source/rllib/doc_code/rlmodule_guide.py
@@ -8,15 +8,7 @@

from ray.rllib.algorithms.ppo import PPOConfig

config = (
PPOConfig()
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.framework("torch")
.environment("CartPole-v1")
)
config = PPOConfig().framework("torch").environment("CartPole-v1")

algorithm = config.build()

@@ -235,21 +227,15 @@ def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]:
class BCTorchRLModuleWithSharedGlobalEncoder(TorchRLModule):
"""An RLModule with a shared encoder between agents for global observation."""

def __init__(
self,
encoder: nn.Module,
local_dim: int,
hidden_dim: int,
action_dim: int,
config=None,
) -> None:
super().__init__(config=config)

self.encoder = encoder
def setup(self):
self.encoder = self.model_config["encoder"]
self.policy_head = nn.Sequential(
nn.Linear(hidden_dim + local_dim, hidden_dim),
nn.Linear(
self.model_config["hidden_dim"] + self.model_config["local_dim"],
self.model_config["hidden_dim"],
),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim),
nn.Linear(self.model_config["hidden_dim"], self.model_config["action_dim"]),
)

def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]:
@@ -288,11 +274,14 @@ def setup(self):
rl_modules = {}
for module_id, module_spec in module_specs.items():
rl_modules[module_id] = BCTorchRLModuleWithSharedGlobalEncoder(
config=module_specs[module_id].get_rl_module_config(),
encoder=shared_encoder,
local_dim=module_spec.observation_space["local"].shape[0],
hidden_dim=hidden_dim,
action_dim=module_spec.action_space.n,
observation_space=module_spec.observation_space,
action_space=module_spec.action_space,
model_config={
"local_dim": module_spec.observation_space["local"].shape[0],
"hidden_dim": hidden_dim,
"action_dim": module_spec.action_space.n,
"encoder": shared_encoder,
},
)

self._rl_modules = rl_modules
@@ -345,14 +334,7 @@ def setup(self):
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec

config = (
PPOConfig()
# Enable the new API stack (RLModule and Learner APIs).
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
).environment("CartPole-v1")
)
config = PPOConfig().environment("CartPole-v1")
env = gym.make("CartPole-v1")
# Create an RL Module that we would like to checkpoint
module_spec = RLModuleSpec(
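The guide's custom RLModules now build their networks inside setup() from self.model_config (a plain dict passed at construction time) instead of taking extra constructor arguments. A minimal sketch of that pattern; the class name, layer sizes, and config keys are illustrative, not from the guide:

import torch.nn as nn

from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule


class TinyTorchRLModule(TorchRLModule):
    def setup(self):
        # All hyperparameters arrive through self.model_config.
        hidden = self.model_config.get("hidden_dim", 64)
        self.net = nn.Sequential(
            nn.Linear(self.observation_space.shape[0], hidden),
            nn.ReLU(),
            nn.Linear(hidden, self.action_space.n),
        )

    def _forward_inference(self, batch):
        return {"action_dist_inputs": self.net(batch["obs"])}


# Construction mirrors the direct instantiation in the diff above:
# module = TinyTorchRLModule(
#     observation_space=env.observation_space,
#     action_space=env.action_space,
#     model_config={"hidden_dim": 64},
# )
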
@@ -4,7 +4,14 @@
# Create a PPO algorithm object using a config object ..
from ray.rllib.algorithms.ppo import PPOConfig

my_ppo_config = PPOConfig().environment("CartPole-v1")
my_ppo_config = (
PPOConfig()
.api_stack(
enable_rl_module_and_learner=False,
enable_env_runner_and_connector_v2=False,
)
.environment("CartPole-v1")
)
my_ppo = my_ppo_config.build()

# .. train one iteration ..
@@ -60,21 +67,28 @@
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole

# Set up a multi-agent Algorithm, training two policies independently.
my_ma_config = PPOConfig().multi_agent(
# Which policies should RLlib create and train?
policies={"pol1", "pol2"},
# Let RLlib know, which agents in the environment (we'll have "agent1"
# and "agent2") map to which policies.
policy_mapping_fn=(
lambda agent_id, episode, worker, **kw: (
"pol1" if agent_id == "agent1" else "pol2"
)
),
# Setting these isn't necessary. All policies will always be trained by default.
# However, since we do provide a list of IDs here, we need to remain in charge of
# changing this `policies_to_train` list, should we ever alter the Algorithm
# (e.g. remove one of the policies or add a new one).
policies_to_train=["pol1", "pol2"], # Again, `None` would be totally fine here.
my_ma_config = (
PPOConfig()
.api_stack(
enable_rl_module_and_learner=False,
enable_env_runner_and_connector_v2=False,
)
.multi_agent(
# Which policies should RLlib create and train?
policies={"pol1", "pol2"},
# Let RLlib know, which agents in the environment (we'll have "agent1"
# and "agent2") map to which policies.
policy_mapping_fn=(
lambda agent_id, episode, worker, **kw: (
"pol1" if agent_id == "agent1" else "pol2"
)
),
# Setting these isn't necessary. All policies will always be trained by default.
# However, since we do provide a list of IDs here, we need to remain in charge of
# changing this `policies_to_train` list, should we ever alter the Algorithm
# (e.g. remove one of the policies or add a new one).
policies_to_train=["pol1", "pol2"], # Again, `None` would be totally fine here.
)
)

# Add the MultiAgentCartPole env to our config and build our Algorithm.
@@ -168,6 +182,10 @@
# Set up an Algorithm with 5 Policies.
algo_w_5_policies = (
PPOConfig()
.api_stack(
enable_rl_module_and_learner=False,
enable_env_runner_and_connector_v2=False,
)
.environment(
env=MultiAgentCartPole,
env_config={
@@ -225,7 +243,13 @@ def new_policy_mapping_fn(agent_id, episode, worker, **kwargs):
# Create a new Algorithm (which contains a Policy, which contains a NN Model).
# Switch on for native models to be included in the Policy checkpoints.
ppo_config = (
PPOConfig().environment("Pendulum-v1").checkpointing(export_native_model_files=True)
PPOConfig()
.api_stack(
enable_rl_module_and_learner=False,
enable_env_runner_and_connector_v2=False,
)
.environment("Pendulum-v1")
.checkpointing(export_native_model_files=True)
)

# The default framework is TensorFlow, but if you would like to do this example with
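These hunks keep the old-stack examples working by switching the new API stack off explicitly; the multi-agent pieces themselves are unchanged. The comment in the diff notes that an explicit policies_to_train list only matters if you later alter it. A sketch of the case where that actually happens, freezing pol2 while pol1 keeps training (illustrative, reusing the same old-stack config pieces as above):

from ray.rllib.algorithms.ppo import PPOConfig

frozen_pol2_config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=False,
        enable_env_runner_and_connector_v2=False,
    )
    .multi_agent(
        policies={"pol1", "pol2"},
        policy_mapping_fn=(
            lambda agent_id, episode, worker, **kw: (
                "pol1" if agent_id == "agent1" else "pol2"
            )
        ),
        # Only pol1 keeps learning; pol2 still acts in the env but is frozen.
        policies_to_train=["pol1"],
    )
)
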
20 changes: 18 additions & 2 deletions doc/source/rllib/key-concepts.rst
@@ -73,7 +73,15 @@ which implements the proximal policy optimization algorithm in RLlib.

# Configure.
from ray.rllib.algorithms.ppo import PPOConfig
config = PPOConfig().environment(env="CartPole-v1").training(train_batch_size=4000)
config = (
PPOConfig()
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.environment("CartPole-v1")
.training(train_batch_size_per_learner=4000)
)

# Build.
algo = config.build()
@@ -91,7 +99,15 @@

# Configure.
from ray.rllib.algorithms.ppo import PPOConfig
config = PPOConfig().environment(env="CartPole-v1").training(train_batch_size=4000)
config = (
PPOConfig()
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.environment("CartPole-v1")
.training(train_batch_size_per_learner=4000)
)

# Train via Ray Tune.
tune.run("PPO", config=config)
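Both key-concepts snippets switch from train_batch_size to the new-stack train_batch_size_per_learner and turn the new API stack on explicitly. Because the per-learner batch is multiplied by the number of Learner workers, a sketch like the following keeps the effective total batch at 4000 when scaling learners (my reading of the new-stack semantics; the .learners() call and the scaling arithmetic are not part of this diff):

from ray.rllib.algorithms.ppo import PPOConfig

num_learners = 2
total_train_batch = 4000

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .learners(num_learners=num_learners)
    # Each Learner pulls roughly total_train_batch / num_learners samples per update.
    .training(train_batch_size_per_learner=total_train_batch // max(num_learners, 1))
)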