Merged
2 changes: 1 addition & 1 deletion doc/source/rllib-concepts.rst
@@ -24,4 +24,4 @@ Policy Optimization

Similar to how a `gradient-descent optimizer <https://www.tensorflow.org/api_docs/python/tf/train/GradientDescentOptimizer>`__ can be used to improve a model, RLlib's `policy optimizers <https://github.com/ray-project/ray/tree/master/python/ray/rllib/optimizers>`__ implement different strategies for improving a policy graph.

For example, in A3C you'd want to compute gradient asynchronously on different workers, and apply them to a central policy graph replica. This strategy is implemented by the `AsyncGradientsOptimizer <https://github.com/ray-project/ray/blob/master/python/ray/rllib/optimizers/async_gradients_optimizer.py>`__. Another alternative is to gather experiences synchronously in parallel and optimize the model centrally, as in `SyncSamplesOptimizer <https://github.com/ray-project/ray/blob/master/python/ray/rllib/optimizers/sync_samples_optimizer.py>`__. Policy optimizers abstract these strategies away into reusable modules.
For example, in A3C you'd want to compute gradients asynchronously on different workers, and apply them to a central policy graph replica. This strategy is implemented by the `AsyncGradientsOptimizer <https://github.com/ray-project/ray/blob/master/python/ray/rllib/optimizers/async_gradients_optimizer.py>`__. Another alternative is to gather experiences synchronously in parallel and optimize the model centrally, as in `SyncSamplesOptimizer <https://github.com/ray-project/ray/blob/master/python/ray/rllib/optimizers/sync_samples_optimizer.py>`__. Policy optimizers abstract these strategies away into reusable modules.
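As an illustrative sketch (not part of this diff), a policy optimizer is constructed from a local evaluator, a set of remote evaluators, and a small config dict, and is then driven by repeatedly calling ``step()``. The evaluators here are assumed to already exist; the PPO agent diff further down shows how an agent builds them via ``make_local_evaluator`` and ``make_remote_evaluators``:

    from ray.rllib.optimizers import SyncSamplesOptimizer

    # local_evaluator / remote_evaluators are assumed to have been built by an agent.
    optimizer = SyncSamplesOptimizer(
        local_evaluator, remote_evaluators, {"num_sgd_iter": 10})
    for _ in range(100):
        optimizer.step()  # gather experiences from workers, update the local policy graph
        print("steps sampled so far:", optimizer.num_steps_sampled)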
8 changes: 4 additions & 4 deletions doc/source/rllib-env.rst
@@ -39,19 +39,19 @@ There are two ways to scale experience collection with Gym environments:

1. **Vectorization within a single process:** Though many envs can achieve very high frame rates per core, their throughput is limited in practice by policy evaluation between steps. For example, even small TensorFlow models incur a couple of milliseconds of latency to evaluate. This can be worked around by creating multiple envs per process and batching policy evaluations across these envs.

You can configure ``{"num_envs": M}`` to have RLlib create ``M`` concurrent environments per worker. RLlib auto-vectorizes Gym environments via `VectorEnv.wrap() <https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/vector_env.py>`__.
You can configure ``{"num_envs_per_worker": M}`` to have RLlib create ``M`` concurrent environments per worker. RLlib auto-vectorizes Gym environments via `VectorEnv.wrap() <https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/vector_env.py>`__.

2. **Distribute across multiple processes:** You can also have RLlib create multiple processes (Ray actors) for experience collection. In most algorithms this can be controlled by setting the ``{"num_workers": N}`` config.

.. image:: throughput.png

You can also combine vectorization and distributed execution, as shown in the above figure. Here we plot just the throughput of RLlib policy evaluation from 1 to 128 CPUs. PongNoFrameskip-v4 on GPU scales from 2.4k to ∼200k actions/s, and Pendulum-v0 on CPU from 15k to 1.5M actions/s. One machine was used for 1-16 workers, and a Ray cluster of four machines for 32-128 workers. Each worker was configured with ``num_envs=64``.
You can also combine vectorization and distributed execution, as shown in the above figure. Here we plot just the throughput of RLlib policy evaluation from 1 to 128 CPUs. PongNoFrameskip-v4 on GPU scales from 2.4k to ∼200k actions/s, and Pendulum-v0 on CPU from 15k to 1.5M actions/s. One machine was used for 1-16 workers, and a Ray cluster of four machines for 32-128 workers. Each worker was configured with ``num_envs_per_worker=64``.
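As a rough sketch (not from this diff), both knobs can be combined in a single Tune experiment; the trainer name, stopping criterion, and worker counts below are placeholders chosen only for illustration:

    import ray
    from ray.tune import run_experiments

    ray.init()
    run_experiments({
        "pong_throughput": {
            "run": "PPO",
            "env": "PongNoFrameskip-v4",
            "stop": {"time_total_s": 600},
            "config": {
                "num_workers": 16,           # 16 sampling processes (Ray actors)
                "num_envs_per_worker": 64,   # 64 vectorized envs inside each process
            },
        },
    })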


Vectorized
----------

RLlib will auto-vectorize Gym envs for batch evaluation if the ``num_envs`` config is set, or you can define a custom environment class that subclasses `VectorEnv <https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/vector_env.py>`__ to implement ``vector_step()`` and ``vector_reset()``.
RLlib will auto-vectorize Gym envs for batch evaluation if the ``num_envs_per_worker`` config is set, or you can define a custom environment class that subclasses `VectorEnv <https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/vector_env.py>`__ to implement ``vector_step()`` and ``vector_reset()``.
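As a hedged sketch (not part of this diff) of the subclassing route, the example below wraps several Gym copies behind the two methods named above; the real ``VectorEnv`` interface may require additional methods (e.g. per-index resets), so treat this as an outline only:

    import gym
    from ray.rllib.env.vector_env import VectorEnv

    class CartPoleVectorEnv(VectorEnv):
        """Runs several CartPole copies and steps them as a batch."""

        def __init__(self, num_envs):
            self.envs = [gym.make("CartPole-v0") for _ in range(num_envs)]
            self.num_envs = num_envs
            self.observation_space = self.envs[0].observation_space
            self.action_space = self.envs[0].action_space

        def vector_reset(self):
            return [env.reset() for env in self.envs]

        def vector_step(self, actions):
            obs, rewards, dones, infos = [], [], [], []
            for env, action in zip(self.envs, actions):
                o, r, d, i = env.step(action)
                obs.append(o)
                rewards.append(r)
                dones.append(d)
                infos.append(i)
            return obs, rewards, dones, infos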

Multi-Agent
-----------
@@ -114,7 +114,7 @@ RLlib will create three distinct policies and route agent decisions to its bound

Here is a simple `example training script <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/multiagent_cartpole.py>`__ in which you can vary the number of agents and policies in the environment. For more advanced usage, e.g., different classes of policies per agent, or more control over the training process, you can use the lower-level RLlib APIs directly to define custom policy graphs or algorithms.

To scale to hundreds of agents, MultiAgentEnv batches policy evaluations across multiple agents internally. It can also be auto-vectorized by setting ``num_envs > 1``.
To scale to hundreds of agents, MultiAgentEnv batches policy evaluations across multiple agents internally. It can also be auto-vectorized by setting ``num_envs_per_worker > 1``.
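For concreteness, here is a hedged sketch (not part of this diff) of a tiny multi-agent env; it assumes the dict-keyed ``MultiAgentEnv`` interface with a ``"__all__"`` done flag, which may differ in detail from the version in this branch:

    import gym
    from ray.rllib.env.multi_agent_env import MultiAgentEnv

    class TwoAgentCartPole(MultiAgentEnv):
        """Two independent CartPole agents stepped through one env object."""

        def __init__(self):
            self.agents = {0: gym.make("CartPole-v0"), 1: gym.make("CartPole-v0")}

        def reset(self):
            return {i: env.reset() for i, env in self.agents.items()}

        def step(self, action_dict):
            obs, rew, done, info = {}, {}, {}, {}
            for i, action in action_dict.items():
                obs[i], rew[i], done[i], info[i] = self.agents[i].step(action)
            done["__all__"] = all(done.values())
            return obs, rew, done, info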

Serving
-------
4 changes: 2 additions & 2 deletions python/ray/rllib/agents/a3c/a3c.py
@@ -80,11 +80,11 @@ def default_resource_request(cls, config):

def _init(self):
if self.config["use_pytorch"]:
from ray.rllib.agents.a3c.a3c_torch_policy import \
from ray.rllib.agents.a3c.a3c_torch_policy_graph import \
A3CTorchPolicyGraph
policy_cls = A3CTorchPolicyGraph
else:
from ray.rllib.agents.a3c.a3c_tf_policy import A3CPolicyGraph
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
policy_cls = A3CPolicyGraph

self.local_evaluator = self.make_local_evaluator(
4 changes: 2 additions & 2 deletions python/ray/rllib/agents/agent.py
@@ -20,7 +20,7 @@
# Number of steps after which the rollout gets cut
"horizon": None,
# Number of environments to evaluate vectorwise per worker.
"num_envs": 1,
"num_envs_per_worker": 1,
# Number of actors used for parallelism
"num_workers": 2,
# Default sample batch size
@@ -145,7 +145,7 @@ def session_creator():
preprocessor_pref=config["preprocessor_pref"],
sample_async=config["sample_async"],
compress_observations=config["compress_observations"],
num_envs=config["num_envs"],
num_envs=config["num_envs_per_worker"],
observation_filter=config["observation_filter"],
env_config=config["env_config"],
model_config=config["model"],
73 changes: 29 additions & 44 deletions python/ray/rllib/agents/ppo/ppo.py
@@ -3,14 +3,13 @@
from __future__ import print_function

import os
import numpy as np
import pickle

import ray
from ray.rllib.agents import Agent, with_common_config
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicyGraph
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.utils import FilterManager
from ray.rllib.optimizers.multi_gpu_optimizer import LocalMultiGPUOptimizer
from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer
from ray.tune.trial import Resources

DEFAULT_CONFIG = with_common_config({
@@ -27,7 +26,7 @@
"num_sgd_iter": 30,
# Stepsize of SGD
"sgd_stepsize": 5e-5,
# Total SGD batch size across all devices for SGD
# Total SGD batch size across all devices for SGD (multi-gpu only)
"sgd_batchsize": 128,
# Coefficient of the value function loss
"vf_loss_coeff": 1.0,
@@ -47,6 +46,15 @@
"batch_mode": "complete_episodes",
# Which observation filter to apply to the observation
"observation_filter": "MeanStdFilter",
# Use the sync samples optimizer instead of the multi-gpu one
"simple_optimizer": False,
# Override model config
"model": {
# Use LSTM model (note: requires simple optimizer for now).
"use_lstm": False,
# Max seq length for LSTM training.
"max_seq_len": 20,
},
})
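A hedged usage sketch of the new keys above (not part of this diff); ``PPOAgent`` is the assumed entry point, and unspecified keys fall back to ``DEFAULT_CONFIG``:

    import ray
    from ray.rllib.agents.ppo import PPOAgent

    ray.init()
    agent = PPOAgent(env="CartPole-v0", config={
        # LSTM training currently requires the sync samples optimizer.
        "simple_optimizer": True,
        "model": {"use_lstm": True, "max_seq_len": 20},
    })
    for _ in range(3):
        print(agent.train())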


@@ -67,57 +75,34 @@ def default_resource_request(cls, config):

def _init(self):
self.local_evaluator = self.make_local_evaluator(
self.env_creator, PPOTFPolicyGraph)
self.env_creator, PPOPolicyGraph)
self.remote_evaluators = self.make_remote_evaluators(
self.env_creator, PPOTFPolicyGraph, self.config["num_workers"],
self.env_creator, PPOPolicyGraph, self.config["num_workers"],
{"num_cpus": self.config["num_cpus_per_worker"],
"num_gpus": self.config["num_gpus_per_worker"]})
self.optimizer = LocalMultiGPUOptimizer(
self.local_evaluator, self.remote_evaluators,
{"sgd_batch_size": self.config["sgd_batchsize"],
"sgd_stepsize": self.config["sgd_stepsize"],
"num_sgd_iter": self.config["num_sgd_iter"],
"timesteps_per_batch": self.config["timesteps_per_batch"]})
if self.config["simple_optimizer"]:
self.optimizer = SyncSamplesOptimizer(
self.local_evaluator, self.remote_evaluators,
{"num_sgd_iter": self.config["num_sgd_iter"]})
else:
self.optimizer = LocalMultiGPUOptimizer(
self.local_evaluator, self.remote_evaluators,
{"sgd_batch_size": self.config["sgd_batchsize"],
"sgd_stepsize": self.config["sgd_stepsize"],
"num_sgd_iter": self.config["num_sgd_iter"],
"timesteps_per_batch": self.config["timesteps_per_batch"],
"standardize_fields": ["advantages"]})

def _train(self):
prev_steps = self.optimizer.num_steps_sampled

def postprocess_samples(batch):
# Divide by the maximum of value.std() and 1e-4
# to guard against the case where all values are equal
value = batch["advantages"]
standardized = (value - value.mean()) / max(1e-4, value.std())
batch.data["advantages"] = standardized
batch.shuffle()
dummy = np.zeros_like(batch["advantages"])
if not self.config["use_gae"]:
batch.data["value_targets"] = dummy
batch.data["vf_preds"] = dummy

extra_fetches = self.optimizer.step(postprocess_fn=postprocess_samples)
kl = np.array(extra_fetches["kl"]).mean(axis=1)[-1]
total_loss = np.array(extra_fetches["total_loss"]).mean(axis=1)[-1]
policy_loss = np.array(extra_fetches["policy_loss"]).mean(axis=1)[-1]
vf_loss = np.array(extra_fetches["vf_loss"]).mean(axis=1)[-1]
entropy = np.array(extra_fetches["entropy"]).mean(axis=1)[-1]

newkl = self.local_evaluator.for_policy(lambda pi: pi.update_kl(kl))

info = {
"kl_divergence": kl,
"kl_coefficient": newkl,
"total_loss": total_loss,
"policy_loss": policy_loss,
"vf_loss": vf_loss,
"entropy": entropy,
}

fetches = self.optimizer.step()
self.local_evaluator.for_policy(lambda pi: pi.update_kl(fetches["kl"]))
FilterManager.synchronize(
self.local_evaluator.filters, self.remote_evaluators)
res = self.optimizer.collect_metrics()
res = res._replace(
timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
info=dict(info, **res.info))
info=dict(fetches, **res.info))
return res
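A sketch of what the ``update_kl()`` call above typically does; the adaptive-KL rule and its constants are assumptions based on the standard PPO recipe, not taken from this diff:

    def update_kl_sketch(kl_coeff_val, sampled_kl, kl_target):
        """Return an adapted KL penalty coefficient after one optimization round."""
        if sampled_kl > 2.0 * kl_target:
            kl_coeff_val *= 1.5   # policy moved too far, penalize KL more strongly
        elif sampled_kl < 0.5 * kl_target:
            kl_coeff_val *= 0.5   # policy barely moved, relax the penalty
        return kl_coeff_val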

def _stop(self):
python/ray/rllib/agents/ppo/ppo_policy_graph.py (renamed from ppo_tf_policy.py)
@@ -11,7 +11,7 @@

class PPOLoss(object):
def __init__(
self, action_space, value_targets, advantages, actions, logprobs,
self, action_space, value_targets, advantages, actions, logits,
vf_preds, curr_action_dist, value_fn, cur_kl_coeff,
entropy_coeff=0, clip_param=0.1, vf_loss_coeff=1.0, use_gae=True):
"""Constructs the loss for Proximal Policy Objective.
Expand All @@ -24,7 +24,7 @@ def __init__(
from previous model evaluation.
advantages (Placeholder): Placeholder for calculated advantages
from previous model evaluation.
logprobs (Placeholder): Placeholder for logits output from
logits (Placeholder): Placeholder for logits output from
previous model evaluation.
vf_preds (Placeholder): Placeholder for value function output
from previous model evaluation.
@@ -39,7 +39,7 @@ def __init__(
use_gae (bool): If true, use the Generalized Advantage Estimator.
"""
dist_cls, _ = ModelCatalog.get_action_dist(action_space)
prev_dist = dist_cls(logprobs)
prev_dist = dist_cls(logits)
# Make loss functions.
logp_ratio = tf.exp(
curr_action_dist.logp(actions) - prev_dist.logp(actions))
@@ -60,7 +60,7 @@ def __init__(
vf_clipped = vf_preds + tf.clip_by_value(
value_fn - vf_preds, -clip_param, clip_param)
vf_loss2 = tf.square(vf_clipped - value_targets)
vf_loss = tf.minimum(vf_loss1, vf_loss2)
vf_loss = tf.maximum(vf_loss1, vf_loss2)
self.mean_vf_loss = tf.reduce_mean(vf_loss)
loss = tf.reduce_mean(
-surrogate_loss + cur_kl_coeff*action_kl +
@@ -73,7 +73,7 @@ def __init__(
self.loss = loss
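A note on the ``tf.minimum`` to ``tf.maximum`` change above: with value-function clipping, the loss takes the larger of the clipped and unclipped squared errors, i.e. a pessimistic bound. A small NumPy sketch of the same computation, with made-up numbers (not part of this diff):

    import numpy as np

    value_targets = np.array([1.0, 2.0])
    vf_preds = np.array([0.5, 2.5])   # value predictions from the previous iteration
    value_fn = np.array([2.0, 1.0])   # current value function outputs
    clip_param = 0.3

    vf_loss1 = np.square(value_fn - value_targets)
    vf_clipped = vf_preds + np.clip(value_fn - vf_preds, -clip_param, clip_param)
    vf_loss2 = np.square(vf_clipped - value_targets)
    vf_loss = np.maximum(vf_loss1, vf_loss2)  # pessimistic bound; was minimum before this fix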


class PPOTFPolicyGraph(TFPolicyGraph):
class PPOPolicyGraph(TFPolicyGraph):
def __init__(self, observation_space, action_space,
config, existing_inputs=None):
"""
@@ -89,46 +89,48 @@ def __init__(self, observation_space, action_space,
self.config = config
self.kl_coeff_val = self.config["kl_coeff"]
self.kl_target = self.config["kl_target"]
dist_cls, logit_dim = ModelCatalog.get_action_dist(
action_space)
dist_cls, logit_dim = ModelCatalog.get_action_dist(action_space)

if existing_inputs:
self.loss_in = existing_inputs
obs_ph, value_targets_ph, adv_ph, act_ph, \
logprobs_ph, vf_preds_ph = [ph for _, ph in existing_inputs]
logits_ph, vf_preds_ph = [ph for _, ph in existing_inputs]
else:
obs_ph = tf.placeholder(
tf.float32, name="obs", shape=(None,)+observation_space.shape)
# Targets of the value function.
value_targets_ph = tf.placeholder(
tf.float32, name="value_targets", shape=(None,))
# Advantage values in the policy gradient estimator.
adv_ph = tf.placeholder(
tf.float32, name="advantages", shape=(None,))
act_ph = ModelCatalog.get_action_placeholder(action_space)
# Log probabilities from the policy before the policy update.
logprobs_ph = tf.placeholder(
tf.float32, name="logprobs", shape=(None, logit_dim))
# Value function predictions before the policy update.
logits_ph = tf.placeholder(
tf.float32, name="logits", shape=(None, logit_dim))
vf_preds_ph = tf.placeholder(
tf.float32, name="vf_preds", shape=(None,))
value_targets_ph = tf.placeholder(
tf.float32, name="value_targets", shape=(None,))

self.loss_in = [
("obs", obs_ph),
("value_targets", value_targets_ph),
("advantages", adv_ph),
("actions", act_ph),
("logprobs", logprobs_ph),
("vf_preds", vf_preds_ph)
("logits", logits_ph),
("vf_preds", vf_preds_ph),
]
# TODO(ekl) feed RNN states in here

self.model = ModelCatalog.get_model(
obs_ph, logit_dim, self.config["model"])

# LSTM support
if not existing_inputs:
for i, ph in enumerate(self.model.state_in):
self.loss_in.append(("state_in_{}".format(i), ph))

# KL Coefficient
self.kl_coeff = tf.get_variable(
initializer=tf.constant_initializer(self.kl_coeff_val),
name="kl_coeff", shape=(), trainable=False, dtype=tf.float32)

self.logits = ModelCatalog.get_model(
obs_ph, logit_dim, self.config["model"]).outputs
self.logits = self.model.outputs
curr_action_dist = dist_cls(self.logits)
self.sampler = curr_action_dist.sample()
if self.config["use_gae"]:
@@ -137,16 +139,17 @@ def __init__(self, observation_space, action_space,
# mean parameters and standard deviation parameters and
# do not make the standard deviations free variables.
vf_config["free_log_std"] = False
vf_config["use_lstm"] = False
with tf.variable_scope("value_function"):
self.value_function = ModelCatalog.get_model(
obs_ph, 1, vf_config).outputs
self.value_function = tf.reshape(self.value_function, [-1])
else:
self.value_function = tf.constant("NA")
self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1])

self.loss_obj = PPOLoss(
action_space, value_targets_ph, adv_ph, act_ph,
logprobs_ph, vf_preds_ph,
logits_ph, vf_preds_ph,
curr_action_dist, self.value_function, self.kl_coeff,
entropy_coeff=self.config["entropy_coeff"],
clip_param=self.config["clip_param"],
@@ -158,19 +161,22 @@ def __init__(self, observation_space, action_space,
self, observation_space, action_space,
self.sess, obs_input=obs_ph,
action_sampler=self.sampler, loss=self.loss_obj.loss,
loss_inputs=self.loss_in,
is_training=self.is_training)
loss_inputs=self.loss_in, is_training=self.is_training,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out, seq_lens=self.model.seq_lens)

self.sess.run(tf.global_variables_initializer())

def copy(self, existing_inputs):
"""Creates a copy of self using existing input placeholders."""
return PPOTFPolicyGraph(
return PPOPolicyGraph(
None, self.action_space, self.config,
existing_inputs=existing_inputs)

def extra_compute_action_fetches(self):
return {"vf_preds": self.value_function, "logprobs": self.logits}
return {"vf_preds": self.value_function, "logits": self.logits}

def extra_apply_grad_fetches(self):
def extra_compute_grad_fetches(self):
return {
"total_loss": self.loss_obj.loss,
"policy_loss": self.loss_obj.mean_policy_loss,
@@ -194,6 +200,12 @@ def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
self.config["lambda"], use_gae=self.config["use_gae"])
return batch

def optimizer(self):
return tf.train.AdamOptimizer(self.config["sgd_stepsize"])

def gradients(self, optimizer):
return optimizer.compute_gradients(
self._loss, colocate_gradients_with_ops=True)

def get_initial_state(self):
return self.model.state_init
2 changes: 2 additions & 0 deletions python/ray/rllib/evaluation/postprocessing.py
@@ -44,6 +44,8 @@ def compute_advantages(rollout, last_r, gamma, lambda_=1.0, use_gae=True):
rewards_plus_v = np.concatenate(
[rollout["rewards"], np.array([last_r])])
traj["advantages"] = discount(rewards_plus_v, gamma)[:-1]
# TODO(ekl): support using a critic without GAE
traj["value_targets"] = np.zeros_like(traj["advantages"])

traj["advantages"] = traj["advantages"].copy().astype(np.float32)
