From dfb76e63a908da423623bb3ef452000b820b3ef3 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 6 Jul 2018 23:31:10 -0700 Subject: [PATCH 1/8] fix ppo --- python/ray/rllib/agents/agent.py | 14 +++++- python/ray/rllib/agents/ppo/ppo.py | 45 +++++++++---------- python/ray/rllib/agents/ppo/ppo_tf_policy.py | 31 +++++++------ python/ray/rllib/evaluation/postprocessing.py | 1 + .../rllib/optimizers/multi_gpu_optimizer.py | 29 ++++++++---- .../ray/rllib/optimizers/policy_optimizer.py | 3 ++ .../optimizers/sync_samples_optimizer.py | 3 +- test/jenkins_tests/run_multi_node_tests.sh | 9 +++- 8 files changed, 83 insertions(+), 52 deletions(-) diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 9739d1f64d48..a01ff4e119cf 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -39,8 +39,18 @@ "model": {}, # Arguments to pass to the rllib optimizer "optimizer": {}, - # Override default TF session args if non-empty - "tf_session_args": {}, + # Configure TF for single-process operation by default + "tf_session_args": { + "intra_op_parallelism_threads": 1, + "inter_op_parallelism_threads": 1, + "gpu_options": { + "allow_growth": True, + }, +# "log_device_placement": True, +# "device_count": { +# "CPU": 2, # for debugging multi-gpu +# }, + }, # Whether to LZ4 compress observations "compress_observations": False, diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index a83c10f3b969..d625ce60c634 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -11,7 +11,7 @@ from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicyGraph from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.utils import FilterManager -from ray.rllib.optimizers.multi_gpu_optimizer import LocalMultiGPUOptimizer +from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer from ray.tune.trial import Resources DEFAULT_CONFIG = with_common_config({ @@ -48,6 +48,8 @@ "batch_mode": "complete_episodes", # Which observation filter to apply to the observation "observation_filter": "MeanStdFilter", + # Debug only: use the sync samples optimizer instead of the multi-gpu one + "debug_use_simple_optimizer": False, }) @@ -73,12 +75,18 @@ def _init(self): self.env_creator, PPOTFPolicyGraph, self.config["num_workers"], {"num_cpus": self.config["num_cpus_per_worker"], "num_gpus": self.config["num_gpus_per_worker"]}) - self.optimizer = LocalMultiGPUOptimizer( - {"sgd_batch_size": self.config["sgd_batchsize"], - "sgd_stepsize": self.config["sgd_stepsize"], - "num_sgd_iter": self.config["num_sgd_iter"], - "timesteps_per_batch": self.config["timesteps_per_batch"]}, - self.local_evaluator, self.remote_evaluators) + if self.config["debug_use_simple_optimizer"]: + self.optimizer = SyncSamplesOptimizer( + self.config["optimizer"], + self.local_evaluator, self.remote_evaluators) + else: + self.optimizer = LocalMultiGPUOptimizer( + {"sgd_batch_size": self.config["sgd_batchsize"], + "sgd_stepsize": self.config["sgd_stepsize"], + "num_sgd_iter": self.config["num_sgd_iter"], + "timesteps_per_batch": self.config["timesteps_per_batch"], + "standardize_fields": ["advantages"]}, + self.local_evaluator, self.remote_evaluators) def _train(self): def postprocess_samples(batch): @@ -92,28 +100,15 @@ def postprocess_samples(batch): if not self.config["use_gae"]: batch.data["value_targets"] = dummy batch.data["vf_preds"] = dummy - extra_fetches = self.optimizer.step(postprocess_fn=postprocess_samples) - kl = np.array(extra_fetches["kl"]).mean(axis=1)[-1] - total_loss = np.array(extra_fetches["total_loss"]).mean(axis=1)[-1] - policy_loss = np.array(extra_fetches["policy_loss"]).mean(axis=1)[-1] - vf_loss = np.array(extra_fetches["vf_loss"]).mean(axis=1)[-1] - entropy = np.array(extra_fetches["entropy"]).mean(axis=1)[-1] - - newkl = self.local_evaluator.for_policy(lambda pi: pi.update_kl(kl)) - - info = { - "kl_divergence": kl, - "kl_coefficient": newkl, - "total_loss": total_loss, - "policy_loss": policy_loss, - "vf_loss": vf_loss, - "entropy": entropy, - } + fetches = self.optimizer.step() + newkl = self.local_evaluator.for_policy( + lambda pi: pi.update_kl(fetches["kl"])) FilterManager.synchronize( self.local_evaluator.filters, self.remote_evaluators) + res = collect_metrics(self.local_evaluator, self.remote_evaluators) - res = res._replace(info=info) + res = res._replace(info=fetches) return res def _stop(self): diff --git a/python/ray/rllib/agents/ppo/ppo_tf_policy.py b/python/ray/rllib/agents/ppo/ppo_tf_policy.py index 887357d9a2e5..b8420b69d7f0 100644 --- a/python/ray/rllib/agents/ppo/ppo_tf_policy.py +++ b/python/ray/rllib/agents/ppo/ppo_tf_policy.py @@ -11,7 +11,7 @@ class PPOLoss(object): def __init__( - self, action_space, value_targets, advantages, actions, logprobs, + self, action_space, value_targets, advantages, actions, logits, vf_preds, curr_action_dist, value_fn, cur_kl_coeff, entropy_coeff=0, clip_param=0.1, vf_loss_coeff=1.0, use_gae=True): """Constructs the loss for Proximal Policy Objective. @@ -24,7 +24,7 @@ def __init__( from previous model evaluation. advantages (Placeholder): Placeholder for calculated advantages from previous model evaluation. - logprobs (Placeholder): Placeholder for logits output from + logits (Placeholder): Placeholder for logits output from previous model evaluation. vf_preds (Placeholder): Placeholder for value function output from previous model evaluation. @@ -39,7 +39,7 @@ def __init__( use_gae (bool): If true, use the Generalized Advantage Estimator. """ dist_cls, _ = ModelCatalog.get_action_dist(action_space) - prev_dist = dist_cls(logprobs) + prev_dist = dist_cls(logits) # Make loss functions. logp_ratio = tf.exp( curr_action_dist.logp(actions) - prev_dist.logp(actions)) @@ -95,30 +95,31 @@ def __init__(self, observation_space, action_space, if existing_inputs: self.loss_in = existing_inputs obs_ph, value_targets_ph, adv_ph, act_ph, \ - logprobs_ph, vf_preds_ph = [ph for _, ph in existing_inputs] + logits_ph, vf_preds_ph = [ph for _, ph in existing_inputs] else: obs_ph = tf.placeholder( tf.float32, name="obs", shape=(None,)+observation_space.shape) - # Targets of the value function. - value_targets_ph = tf.placeholder( - tf.float32, name="value_targets", shape=(None,)) # Advantage values in the policy gradient estimator. adv_ph = tf.placeholder( tf.float32, name="advantages", shape=(None,)) act_ph = ModelCatalog.get_action_placeholder(action_space) # Log probabilities from the policy before the policy update. - logprobs_ph = tf.placeholder( - tf.float32, name="logprobs", shape=(None, logit_dim)) + logits_ph = tf.placeholder( + tf.float32, name="logits", shape=(None, logit_dim)) # Value function predictions before the policy update. vf_preds_ph = tf.placeholder( tf.float32, name="vf_preds", shape=(None,)) + # Targets of the value function + value_targets_ph = tf.placeholder( + tf.float32, name="value_targets", shape=(None,)) + self.loss_in = [ ("obs", obs_ph), ("value_targets", value_targets_ph), ("advantages", adv_ph), ("actions", act_ph), - ("logprobs", logprobs_ph), - ("vf_preds", vf_preds_ph) + ("logits", logits_ph), + ("vf_preds", vf_preds_ph), ] # TODO(ekl) feed RNN states in here @@ -146,7 +147,7 @@ def __init__(self, observation_space, action_space, self.loss_obj = PPOLoss( action_space, value_targets_ph, adv_ph, act_ph, - logprobs_ph, vf_preds_ph, + logits_ph, vf_preds_ph, curr_action_dist, self.value_function, self.kl_coeff, entropy_coeff=self.config["entropy_coeff"], clip_param=self.config["clip_param"], @@ -161,6 +162,8 @@ def __init__(self, observation_space, action_space, loss_inputs=self.loss_in, is_training=self.is_training) + self.sess.run(tf.global_variables_initializer()) + def copy(self, existing_inputs): """Creates a copy of self using existing input placeholders.""" return PPOTFPolicyGraph( @@ -168,9 +171,9 @@ def copy(self, existing_inputs): existing_inputs=existing_inputs) def extra_compute_action_fetches(self): - return {"vf_preds": self.value_function, "logprobs": self.logits} + return {"vf_preds": self.value_function, "logits": self.logits} - def extra_apply_grad_fetches(self): + def extra_compute_grad_fetches(self): return { "total_loss": self.loss_obj.loss, "policy_loss": self.loss_obj.mean_policy_loss, diff --git a/python/ray/rllib/evaluation/postprocessing.py b/python/ray/rllib/evaluation/postprocessing.py index 667d8eea468f..b8bbe8fc3a8c 100644 --- a/python/ray/rllib/evaluation/postprocessing.py +++ b/python/ray/rllib/evaluation/postprocessing.py @@ -44,6 +44,7 @@ def compute_advantages(rollout, last_r, gamma, lambda_=1.0, use_gae=True): rewards_plus_v = np.concatenate( [rollout["rewards"], np.array([last_r])]) traj["advantages"] = discount(rewards_plus_v, gamma)[:-1] + traj["value_targets"] = np.zeros_like(traj["advantages"]) traj["advantages"] = traj["advantages"].copy().astype(np.float32) diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py b/python/ray/rllib/optimizers/multi_gpu_optimizer.py index 1562b96eaa08..2193a2a16ad3 100644 --- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py +++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py @@ -31,7 +31,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer): """ def _init(self, sgd_batch_size=128, sgd_stepsize=5e-5, num_sgd_iter=10, - timesteps_per_batch=1024): + timesteps_per_batch=1024, standardize_fields=[]): self.batch_size = sgd_batch_size self.sgd_stepsize = sgd_stepsize self.num_sgd_iter = num_sgd_iter @@ -50,6 +50,7 @@ def _init(self, sgd_batch_size=128, sgd_stepsize=5e-5, num_sgd_iter=10, self.load_timer = TimerStat() self.grad_timer = TimerStat() self.update_weights_timer = TimerStat() + self.standardize_fields = standardize_fields print("LocalMultiGPUOptimizer devices", self.devices) print("LocalMultiGPUOptimizer batch size", self.batch_size) @@ -77,7 +78,7 @@ def _init(self, sgd_batch_size=128, sgd_stepsize=5e-5, num_sgd_iter=10, self.sess = self.local_evaluator.tf_sess self.sess.run(tf.global_variables_initializer()) - def step(self, postprocess_fn=None): + def step(self): with self.update_weights_timer: if self.remote_evaluators: weights = ray.put(self.local_evaluator.get_weights()) @@ -94,8 +95,11 @@ def step(self, postprocess_fn=None): samples = self.local_evaluator.sample() self._check_not_multiagent(samples) - if postprocess_fn: - postprocess_fn(samples) + for field in self.standardize_fields: + value = samples[field] + standardized = (value - value.mean()) / max(1e-4, value.std()) + samples[field] = standardized + samples.shuffle() with self.load_timer: tuples_per_device = self.par_opt.load_data( @@ -103,9 +107,9 @@ def step(self, postprocess_fn=None): samples.columns([key for key, _ in self.policy.loss_inputs()])) with self.grad_timer: - all_extra_fetches = defaultdict(list) num_batches = ( int(tuples_per_device) // int(self.per_device_batch_size)) + print("== sgd epochs ==") for i in range(self.num_sgd_iter): iter_extra_fetches = defaultdict(list) permutation = np.random.permutation(num_batches) @@ -116,13 +120,12 @@ def step(self, postprocess_fn=None): self.sess, permutation[batch_index] * self.per_device_batch_size) for k, v in batch_fetches.items(): - iter_extra_fetches[k] += [v] - for k, v in iter_extra_fetches.items(): - all_extra_fetches[k] += [v] + iter_extra_fetches[k].append(v) + print(i, _averaged(iter_extra_fetches)) self.num_steps_sampled += samples.count self.num_steps_trained += samples.count - return all_extra_fetches + return _averaged(iter_extra_fetches) def stats(self): return dict(PolicyOptimizer.stats(self), **{ @@ -131,3 +134,11 @@ def stats(self): "grad_time_ms": round(1000 * self.grad_timer.mean, 3), "update_time_ms": round(1000 * self.update_weights_timer.mean, 3), }) + + +def _averaged(kv): + out = {} + for k, v in kv.items(): + if v[0] is not None: + out[k] = np.mean(v) + return out diff --git a/python/ray/rllib/optimizers/policy_optimizer.py b/python/ray/rllib/optimizers/policy_optimizer.py index 4a30b75211c9..a65789fef11c 100644 --- a/python/ray/rllib/optimizers/policy_optimizer.py +++ b/python/ray/rllib/optimizers/policy_optimizer.py @@ -61,6 +61,9 @@ def step(self): This should run for long enough to minimize call overheads (i.e., at least a couple seconds), but short enough to return control periodically to callers (i.e., at most a few tens of seconds). + + Returns: + fetches (dict|None): Optional fetches from compute grads calls. """ raise NotImplementedError diff --git a/python/ray/rllib/optimizers/sync_samples_optimizer.py b/python/ray/rllib/optimizers/sync_samples_optimizer.py index c1c8e7c1aaca..baf777795aff 100644 --- a/python/ray/rllib/optimizers/sync_samples_optimizer.py +++ b/python/ray/rllib/optimizers/sync_samples_optimizer.py @@ -40,12 +40,13 @@ def step(self): samples = self.local_evaluator.sample() with self.grad_timer: - grad, _ = self.local_evaluator.compute_gradients(samples) + grad, fetches = self.local_evaluator.compute_gradients(samples) self.local_evaluator.apply_gradients(grad) self.grad_timer.push_units_processed(samples.count) self.num_steps_sampled += samples.count self.num_steps_trained += samples.count + return fetches def stats(self): return dict(PolicyOptimizer.stats(self), **{ diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index 86c325da162f..d462554ab588 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -25,6 +25,13 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ --stop '{"training_iteration": 2}' \ --config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "model": {"free_log_std": true}}' +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/train.py \ + --env CartPole-v1 \ + --run PPO \ + --stop '{"training_iteration": 2}' \ + --config '{"debug_use_simple_optimizer": true}' + docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v1 \ @@ -226,7 +233,7 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum.py docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ - python /ray/python/ray/rllib/examples/multiagent_cartpole.py + python /ray/python/ray/rllib/examples/multiagent_cartpole.py --num-iters=2 python $ROOT_DIR/multi_node_docker_test.py \ --docker-image=$DOCKER_SHA \ From 0297a8985230b67195fc9f5262c472560bbc81eb Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 7 Jul 2018 14:04:28 -0700 Subject: [PATCH 2/8] num envs --- doc/source/rllib-env.rst | 8 ++++---- python/ray/rllib/agents/agent.py | 4 ++-- python/ray/rllib/tuned_examples/pong-apex.yaml | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/doc/source/rllib-env.rst b/doc/source/rllib-env.rst index 20e6eed3b4a7..e7681a27931b 100644 --- a/doc/source/rllib-env.rst +++ b/doc/source/rllib-env.rst @@ -39,19 +39,19 @@ There are two ways to scale experience collection with Gym environments: 1. **Vectorization within a single process:** Though many envs can very achieve high frame rates per core, their throughput is limited in practice by policy evaluation between steps. For example, even small TensorFlow models incur a couple milliseconds of latency to evaluate. This can be worked around by creating multiple envs per process and batching policy evaluations across these envs. - You can configure ``{"num_envs": M}`` to have RLlib create ``M`` concurrent environments per worker. RLlib auto-vectorizes Gym environments via `VectorEnv.wrap() `__. + You can configure ``{"num_envs_per_worker": M}`` to have RLlib create ``M`` concurrent environments per worker. RLlib auto-vectorizes Gym environments via `VectorEnv.wrap() `__. 2. **Distribute across multiple processes:** You can also have RLlib create multiple processes (Ray actors) for experience collection. In most algorithms this can be controlled by setting the ``{"num_workers": N}`` config. .. image:: throughput.png -You can also combine vectorization and distributed execution, as shown in the above figure. Here we plot just the throughput of RLlib policy evaluation from 1 to 128 CPUs. PongNoFrameskip-v4 on GPU scales from 2.4k to ∼200k actions/s, and Pendulum-v0 on CPU from 15k to 1.5M actions/s. One machine was used for 1-16 workers, and a Ray cluster of four machines for 32-128 workers. Each worker was configured with ``num_envs=64``. +You can also combine vectorization and distributed execution, as shown in the above figure. Here we plot just the throughput of RLlib policy evaluation from 1 to 128 CPUs. PongNoFrameskip-v4 on GPU scales from 2.4k to ∼200k actions/s, and Pendulum-v0 on CPU from 15k to 1.5M actions/s. One machine was used for 1-16 workers, and a Ray cluster of four machines for 32-128 workers. Each worker was configured with ``num_envs_per_worker=64``. Vectorized ---------- -RLlib will auto-vectorize Gym envs for batch evaluation if the ``num_envs`` config is set, or you can define a custom environment class that subclasses `VectorEnv `__ to implement ``vector_step()`` and ``vector_reset()``. +RLlib will auto-vectorize Gym envs for batch evaluation if the ``num_envs_per_worker`` config is set, or you can define a custom environment class that subclasses `VectorEnv `__ to implement ``vector_step()`` and ``vector_reset()``. Multi-Agent ----------- @@ -114,7 +114,7 @@ RLlib will create three distinct policies and route agent decisions to its bound Here is a simple `example training script `__ in which you can vary the number of agents and policies in the environment. For more advanced usage, e.g., different classes of policies per agent, or more control over the training process, you can use the lower-level RLlib APIs directly to define custom policy graphs or algorithms. -To scale to hundreds of agents, MultiAgentEnv batches policy evaluations across multiple agents internally. It can also be auto-vectorized by setting ``num_envs > 1``. +To scale to hundreds of agents, MultiAgentEnv batches policy evaluations across multiple agents internally. It can also be auto-vectorized by setting ``num_envs_per_worker > 1``. Serving ------- diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index a01ff4e119cf..504884f64681 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -20,7 +20,7 @@ # Number of steps after which the rollout gets cut "horizon": None, # Number of environments to evaluate vectorwise per worker. - "num_envs": 1, + "num_envs_per_worker": 1, # Number of actors used for parallelism "num_workers": 2, # Default sample batch size @@ -149,7 +149,7 @@ def session_creator(): preprocessor_pref=config["preprocessor_pref"], sample_async=config["sample_async"], compress_observations=config["compress_observations"], - num_envs=config["num_envs"], + num_envs=config["num_envs_per_worker"], observation_filter=config["observation_filter"], env_config=config["env_config"], model_config=config["model"], diff --git a/python/ray/rllib/tuned_examples/pong-apex.yaml b/python/ray/rllib/tuned_examples/pong-apex.yaml index f2955da6cf73..e0eee0a6383a 100644 --- a/python/ray/rllib/tuned_examples/pong-apex.yaml +++ b/python/ray/rllib/tuned_examples/pong-apex.yaml @@ -8,6 +8,6 @@ pong-apex: target_network_update_freq: 50000 num_workers: 32 ## can also enable vectorization within processes - # num_envs: 4 + # num_envs_per_worker: 4 lr: .0001 gamma: 0.99 From 499e8a7008cc81691d0517f5fd5428c7afdee62a Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 7 Jul 2018 14:05:25 -0700 Subject: [PATCH 3/8] test --- test/jenkins_tests/run_multi_node_tests.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index d462554ab588..d454ba613704 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -147,7 +147,7 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ --env CartPole-v0 \ --run PG \ --stop '{"training_iteration": 2}' \ - --config '{"sample_batch_size": 500, "num_workers": 1, "num_envs": 10}' + --config '{"sample_batch_size": 500, "num_workers": 1, "num_envs_per_worker": 10}' docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ From f670ae2a14b6cac30c8aa566d5cd9579cf4c1268 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 7 Jul 2018 14:09:09 -0700 Subject: [PATCH 4/8] dummy vf --- python/ray/rllib/agents/ppo/ppo_tf_policy.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/python/ray/rllib/agents/ppo/ppo_tf_policy.py b/python/ray/rllib/agents/ppo/ppo_tf_policy.py index b8420b69d7f0..62e8c1ae9339 100644 --- a/python/ray/rllib/agents/ppo/ppo_tf_policy.py +++ b/python/ray/rllib/agents/ppo/ppo_tf_policy.py @@ -89,8 +89,7 @@ def __init__(self, observation_space, action_space, self.config = config self.kl_coeff_val = self.config["kl_coeff"] self.kl_target = self.config["kl_target"] - dist_cls, logit_dim = ModelCatalog.get_action_dist( - action_space) + dist_cls, logit_dim = ModelCatalog.get_action_dist(action_space) if existing_inputs: self.loss_in = existing_inputs @@ -99,17 +98,13 @@ def __init__(self, observation_space, action_space, else: obs_ph = tf.placeholder( tf.float32, name="obs", shape=(None,)+observation_space.shape) - # Advantage values in the policy gradient estimator. adv_ph = tf.placeholder( tf.float32, name="advantages", shape=(None,)) act_ph = ModelCatalog.get_action_placeholder(action_space) - # Log probabilities from the policy before the policy update. logits_ph = tf.placeholder( tf.float32, name="logits", shape=(None, logit_dim)) - # Value function predictions before the policy update. vf_preds_ph = tf.placeholder( tf.float32, name="vf_preds", shape=(None,)) - # Targets of the value function value_targets_ph = tf.placeholder( tf.float32, name="value_targets", shape=(None,)) @@ -143,7 +138,7 @@ def __init__(self, observation_space, action_space, obs_ph, 1, vf_config).outputs self.value_function = tf.reshape(self.value_function, [-1]) else: - self.value_function = tf.constant("NA") + self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1]) self.loss_obj = PPOLoss( action_space, value_targets_ph, adv_ph, act_ph, From 36989a656841f9ab49654e90cc56ba3bbf3e05ac Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 7 Jul 2018 14:44:54 -0700 Subject: [PATCH 5/8] lstm and multi-pass for simple opt --- python/ray/rllib/agents/agent.py | 4 --- python/ray/rllib/agents/ppo/ppo.py | 14 +++++----- python/ray/rllib/agents/ppo/ppo_tf_policy.py | 26 ++++++++++++++----- .../rllib/optimizers/multi_gpu_optimizer.py | 2 ++ .../optimizers/sync_samples_optimizer.py | 8 ++++-- test/jenkins_tests/run_multi_node_tests.sh | 2 +- 6 files changed, 35 insertions(+), 21 deletions(-) diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 504884f64681..d7471e8d0fd9 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -46,10 +46,6 @@ "gpu_options": { "allow_growth": True, }, -# "log_device_placement": True, -# "device_count": { -# "CPU": 2, # for debugging multi-gpu -# }, }, # Whether to LZ4 compress observations "compress_observations": False, diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index f28b6b2592f2..24d58834c1e0 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -3,7 +3,6 @@ from __future__ import print_function import os -import numpy as np import pickle import ray @@ -27,7 +26,7 @@ "num_sgd_iter": 30, # Stepsize of SGD "sgd_stepsize": 5e-5, - # Total SGD batch size across all devices for SGD + # Total SGD batch size across all devices for SGD (multi-gpu only) "sgd_batchsize": 128, # Coefficient of the value function loss "vf_loss_coeff": 1.0, @@ -47,8 +46,8 @@ "batch_mode": "complete_episodes", # Which observation filter to apply to the observation "observation_filter": "MeanStdFilter", - # Debug only: use the sync samples optimizer instead of the multi-gpu one - "debug_use_simple_optimizer": False, + # Use the sync samples optimizer instead of the multi-gpu one + "simple_optimizer": False, }) @@ -74,10 +73,10 @@ def _init(self): self.env_creator, PPOTFPolicyGraph, self.config["num_workers"], {"num_cpus": self.config["num_cpus_per_worker"], "num_gpus": self.config["num_gpus_per_worker"]}) - if self.config["debug_use_simple_optimizer"]: + if self.config["simple_optimizer"]: self.optimizer = SyncSamplesOptimizer( self.local_evaluator, self.remote_evaluators, - self.config["optimizer"]) + {"num_sgd_iter": self.config["num_sgd_iter"]}) else: self.optimizer = LocalMultiGPUOptimizer( self.local_evaluator, self.remote_evaluators, @@ -90,8 +89,7 @@ def _init(self): def _train(self): prev_steps = self.optimizer.num_steps_sampled fetches = self.optimizer.step() - newkl = self.local_evaluator.for_policy( - lambda pi: pi.update_kl(fetches["kl"])) + self.local_evaluator.for_policy(lambda pi: pi.update_kl(fetches["kl"])) FilterManager.synchronize( self.local_evaluator.filters, self.remote_evaluators) res = self.optimizer.collect_metrics() diff --git a/python/ray/rllib/agents/ppo/ppo_tf_policy.py b/python/ray/rllib/agents/ppo/ppo_tf_policy.py index 62e8c1ae9339..693dca6e1d6a 100644 --- a/python/ray/rllib/agents/ppo/ppo_tf_policy.py +++ b/python/ray/rllib/agents/ppo/ppo_tf_policy.py @@ -94,7 +94,7 @@ def __init__(self, observation_space, action_space, if existing_inputs: self.loss_in = existing_inputs obs_ph, value_targets_ph, adv_ph, act_ph, \ - logits_ph, vf_preds_ph = [ph for _, ph in existing_inputs] + logits_ph, vf_preds_ph, *h = [ph for _, ph in existing_inputs] else: obs_ph = tf.placeholder( tf.float32, name="obs", shape=(None,)+observation_space.shape) @@ -116,15 +116,21 @@ def __init__(self, observation_space, action_space, ("logits", logits_ph), ("vf_preds", vf_preds_ph), ] - # TODO(ekl) feed RNN states in here + + self.model = ModelCatalog.get_model( + obs_ph, logit_dim, self.config["model"]) + + # LSTM support + if not existing_inputs: + for i, ph in enumerate(self.model.state_in): + self.loss_in.append(("state_in_{}".format(i), ph)) # KL Coefficient self.kl_coeff = tf.get_variable( initializer=tf.constant_initializer(self.kl_coeff_val), name="kl_coeff", shape=(), trainable=False, dtype=tf.float32) - self.logits = ModelCatalog.get_model( - obs_ph, logit_dim, self.config["model"]).outputs + self.logits = self.model.outputs curr_action_dist = dist_cls(self.logits) self.sampler = curr_action_dist.sample() if self.config["use_gae"]: @@ -133,6 +139,7 @@ def __init__(self, observation_space, action_space, # mean parameters and standard deviation parameters and # do not make the standard deviations free variables. vf_config["free_log_std"] = False + vf_config["use_lstm"] = False with tf.variable_scope("value_function"): self.value_function = ModelCatalog.get_model( obs_ph, 1, vf_config).outputs @@ -154,8 +161,9 @@ def __init__(self, observation_space, action_space, self, observation_space, action_space, self.sess, obs_input=obs_ph, action_sampler=self.sampler, loss=self.loss_obj.loss, - loss_inputs=self.loss_in, - is_training=self.is_training) + loss_inputs=self.loss_in, is_training=self.is_training, + state_inputs=self.model.state_in, + state_outputs=self.model.state_out, seq_lens=self.model.seq_lens) self.sess.run(tf.global_variables_initializer()) @@ -192,6 +200,12 @@ def postprocess_trajectory(self, sample_batch, other_agent_batches=None): self.config["lambda"], use_gae=self.config["use_gae"]) return batch + def optimizer(self): + return tf.train.AdamOptimizer(self.config["sgd_stepsize"]) + def gradients(self, optimizer): return optimizer.compute_gradients( self._loss, colocate_gradients_with_ops=True) + + def get_initial_state(self): + return self.model.state_init diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py b/python/ray/rllib/optimizers/multi_gpu_optimizer.py index 828e8da9c6bf..0e1792e37509 100644 --- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py +++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py @@ -59,6 +59,8 @@ def _init(self, sgd_batch_size=128, sgd_stepsize=5e-5, num_sgd_iter=10, self.policy = self.local_evaluator.policy_map["default"] assert isinstance(self.policy, TFPolicyGraph), \ "Only TF policies are supported" + assert len(self.policy.get_initial_state()) == 0, \ + "No RNN support yet for multi-gpu. Try the simple optimizer." # per-GPU graph copies created below must share vars with the policy # reuse is set to AUTO_REUSE because Adam nodes are created after diff --git a/python/ray/rllib/optimizers/sync_samples_optimizer.py b/python/ray/rllib/optimizers/sync_samples_optimizer.py index f9c5aa55ff8f..6b4483fb1fcb 100644 --- a/python/ray/rllib/optimizers/sync_samples_optimizer.py +++ b/python/ray/rllib/optimizers/sync_samples_optimizer.py @@ -17,11 +17,12 @@ class SyncSamplesOptimizer(PolicyOptimizer): model weights are then broadcast to all remote evaluators. """ - def _init(self): + def _init(self, num_sgd_iter=1): self.update_weights_timer = TimerStat() self.sample_timer = TimerStat() self.grad_timer = TimerStat() self.throughput = RunningStat() + self.num_sgd_iter = num_sgd_iter def step(self): with self.update_weights_timer: @@ -39,7 +40,10 @@ def step(self): samples = self.local_evaluator.sample() with self.grad_timer: - fetches = self.local_evaluator.compute_apply(samples) + for i in range(self.num_sgd_iter): + fetches = self.local_evaluator.compute_apply(samples) + if self.num_sgd_iter > 1: + print(i, fetches) self.grad_timer.push_units_processed(samples.count) self.num_steps_sampled += samples.count diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index d454ba613704..5d5d5bacd3f3 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -30,7 +30,7 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ --env CartPole-v1 \ --run PPO \ --stop '{"training_iteration": 2}' \ - --config '{"debug_use_simple_optimizer": true}' + --config '{"simple_optimizer": true, "model": {"use_lstm": true}}' docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ From 17a21cd63c29cc69c12a4b81ea70d0fc4582ad32 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 8 Jul 2018 15:38:52 -0700 Subject: [PATCH 6/8] fix 2.7 --- python/ray/rllib/agents/ppo/ppo.py | 7 +++++++ python/ray/rllib/agents/ppo/ppo_tf_policy.py | 2 +- python/ray/rllib/optimizers/multi_gpu_optimizer.py | 2 -- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index 24d58834c1e0..a1a672f394ab 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -48,6 +48,13 @@ "observation_filter": "MeanStdFilter", # Use the sync samples optimizer instead of the multi-gpu one "simple_optimizer": False, + # Override model config + "model": { + # Use LSTM model (note: requires simple optimizer for now). + "use_lstm": False, + # Max seq length for LSTM training. + "max_seq_len": 20, + }, }) diff --git a/python/ray/rllib/agents/ppo/ppo_tf_policy.py b/python/ray/rllib/agents/ppo/ppo_tf_policy.py index 693dca6e1d6a..6f31d0638bfc 100644 --- a/python/ray/rllib/agents/ppo/ppo_tf_policy.py +++ b/python/ray/rllib/agents/ppo/ppo_tf_policy.py @@ -94,7 +94,7 @@ def __init__(self, observation_space, action_space, if existing_inputs: self.loss_in = existing_inputs obs_ph, value_targets_ph, adv_ph, act_ph, \ - logits_ph, vf_preds_ph, *h = [ph for _, ph in existing_inputs] + logits_ph, vf_preds_ph = [ph for _, ph in existing_inputs] else: obs_ph = tf.placeholder( tf.float32, name="obs", shape=(None,)+observation_space.shape) diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py b/python/ray/rllib/optimizers/multi_gpu_optimizer.py index 0e1792e37509..ee348e362285 100644 --- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py +++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py @@ -115,8 +115,6 @@ def step(self): iter_extra_fetches = defaultdict(list) permutation = np.random.permutation(num_batches) for batch_index in range(num_batches): - # TODO(ekl) support ppo's debugging features, e.g. - # printing the current loss and tracing batch_fetches = self.par_opt.optimize( self.sess, permutation[batch_index] * self.per_device_batch_size) From 9bfd1d5cceb6c8c82480840ff7f9fc7717999839 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 8 Jul 2018 15:46:38 -0700 Subject: [PATCH 7/8] fix names --- python/ray/rllib/agents/a3c/a3c.py | 4 ++-- .../agents/a3c/{a3c_tf_policy.py => a3c_tf_policy_graph.py} | 0 .../a3c/{a3c_torch_policy.py => a3c_torch_policy_graph.py} | 0 python/ray/rllib/agents/ppo/ppo.py | 6 +++--- .../agents/ppo/{ppo_tf_policy.py => ppo_policy_graph.py} | 6 +++--- 5 files changed, 8 insertions(+), 8 deletions(-) rename python/ray/rllib/agents/a3c/{a3c_tf_policy.py => a3c_tf_policy_graph.py} (100%) rename python/ray/rllib/agents/a3c/{a3c_torch_policy.py => a3c_torch_policy_graph.py} (100%) rename python/ray/rllib/agents/ppo/{ppo_tf_policy.py => ppo_policy_graph.py} (98%) diff --git a/python/ray/rllib/agents/a3c/a3c.py b/python/ray/rllib/agents/a3c/a3c.py index 264b70825b93..7326685aaa6f 100644 --- a/python/ray/rllib/agents/a3c/a3c.py +++ b/python/ray/rllib/agents/a3c/a3c.py @@ -80,11 +80,11 @@ def default_resource_request(cls, config): def _init(self): if self.config["use_pytorch"]: - from ray.rllib.agents.a3c.a3c_torch_policy import \ + from ray.rllib.agents.a3c.a3c_torch_policy_graph import \ A3CTorchPolicyGraph policy_cls = A3CTorchPolicyGraph else: - from ray.rllib.agents.a3c.a3c_tf_policy import A3CPolicyGraph + from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph policy_cls = A3CPolicyGraph self.local_evaluator = self.make_local_evaluator( diff --git a/python/ray/rllib/agents/a3c/a3c_tf_policy.py b/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py similarity index 100% rename from python/ray/rllib/agents/a3c/a3c_tf_policy.py rename to python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py diff --git a/python/ray/rllib/agents/a3c/a3c_torch_policy.py b/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py similarity index 100% rename from python/ray/rllib/agents/a3c/a3c_torch_policy.py rename to python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index a1a672f394ab..d1e3cde75519 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -7,7 +7,7 @@ import ray from ray.rllib.agents import Agent, with_common_config -from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicyGraph +from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph from ray.rllib.utils import FilterManager from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer from ray.tune.trial import Resources @@ -75,9 +75,9 @@ def default_resource_request(cls, config): def _init(self): self.local_evaluator = self.make_local_evaluator( - self.env_creator, PPOTFPolicyGraph) + self.env_creator, PPOPolicyGraph) self.remote_evaluators = self.make_remote_evaluators( - self.env_creator, PPOTFPolicyGraph, self.config["num_workers"], + self.env_creator, PPOPolicyGraph, self.config["num_workers"], {"num_cpus": self.config["num_cpus_per_worker"], "num_gpus": self.config["num_gpus_per_worker"]}) if self.config["simple_optimizer"]: diff --git a/python/ray/rllib/agents/ppo/ppo_tf_policy.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py similarity index 98% rename from python/ray/rllib/agents/ppo/ppo_tf_policy.py rename to python/ray/rllib/agents/ppo/ppo_policy_graph.py index 6f31d0638bfc..ecddb2f993ce 100644 --- a/python/ray/rllib/agents/ppo/ppo_tf_policy.py +++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py @@ -60,7 +60,7 @@ def __init__( vf_clipped = vf_preds + tf.clip_by_value( value_fn - vf_preds, -clip_param, clip_param) vf_loss2 = tf.square(vf_clipped - value_targets) - vf_loss = tf.minimum(vf_loss1, vf_loss2) + vf_loss = tf.maximum(vf_loss1, vf_loss2) self.mean_vf_loss = tf.reduce_mean(vf_loss) loss = tf.reduce_mean( -surrogate_loss + cur_kl_coeff*action_kl + @@ -73,7 +73,7 @@ def __init__( self.loss = loss -class PPOTFPolicyGraph(TFPolicyGraph): +class PPOPolicyGraph(TFPolicyGraph): def __init__(self, observation_space, action_space, config, existing_inputs=None): """ @@ -169,7 +169,7 @@ def __init__(self, observation_space, action_space, def copy(self, existing_inputs): """Creates a copy of self using existing input placeholders.""" - return PPOTFPolicyGraph( + return PPOPolicyGraph( None, self.action_space, self.config, existing_inputs=existing_inputs) From 7e919342fdcabe7ac6c6ca234317c4183913f13f Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 9 Jul 2018 12:31:51 -0700 Subject: [PATCH 8/8] comments --- doc/source/rllib-concepts.rst | 2 +- python/ray/rllib/evaluation/postprocessing.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/source/rllib-concepts.rst b/doc/source/rllib-concepts.rst index 9e9937c83683..dc8f5068878d 100644 --- a/doc/source/rllib-concepts.rst +++ b/doc/source/rllib-concepts.rst @@ -24,4 +24,4 @@ Policy Optimization Similar to how a `gradient-descent optimizer `__ can be used to improve a model, RLlib's `policy optimizers `__ implement different strategies for improving a policy graph. -For example, in A3C you'd want to compute gradient asynchronously on different workers, and apply them to a central policy graph replica. This strategy is implemented by the `AsyncGradientsOptimizer `__. Another alternative is to gather experiences synchronously in parallel and optimize the model centrally, as in `SyncSamplesOptimizer `__. Policy optimizers abstract these strategies away into reusable modules. +For example, in A3C you'd want to compute gradients asynchronously on different workers, and apply them to a central policy graph replica. This strategy is implemented by the `AsyncGradientsOptimizer `__. Another alternative is to gather experiences synchronously in parallel and optimize the model centrally, as in `SyncSamplesOptimizer `__. Policy optimizers abstract these strategies away into reusable modules. diff --git a/python/ray/rllib/evaluation/postprocessing.py b/python/ray/rllib/evaluation/postprocessing.py index b8bbe8fc3a8c..71cbcbe5ec19 100644 --- a/python/ray/rllib/evaluation/postprocessing.py +++ b/python/ray/rllib/evaluation/postprocessing.py @@ -44,6 +44,7 @@ def compute_advantages(rollout, last_r, gamma, lambda_=1.0, use_gae=True): rewards_plus_v = np.concatenate( [rollout["rewards"], np.array([last_r])]) traj["advantages"] = discount(rewards_plus_v, gamma)[:-1] + # TODO(ekl): support using a critic without GAE traj["value_targets"] = np.zeros_like(traj["advantages"]) traj["advantages"] = traj["advantages"].copy().astype(np.float32)