
Model Based RL: Introduce Rainbow. #1607

Merged
merged 1 commit into from Jun 19, 2019
9 changes: 9 additions & 0 deletions tensor2tensor/models/research/rl.py
@@ -377,6 +377,7 @@ def dqn_atari_base():
agent_epsilon_eval=0.001,
agent_epsilon_decay_period=250000, # agent steps
agent_generates_trainable_dones=True,
agent_type="VanillaDQN", # one of ["Rainbow", "VanillaDQN"]

optimizer_class="RMSProp",
optimizer_learning_rate=0.00025,
@@ -420,6 +421,14 @@ def dqn_guess1_params():
return hparams


@registry.register_hparams
def dqn_guess1_rainbow_params():
"""Guess 1 for DQN params."""
hparams = dqn_guess1_params()
hparams.set_hparam("agent_type", "Rainbow")
return hparams


@registry.register_hparams
def dqn_2m_replay_buffer_params():
"""Guess 1 for DQN params, 2 milions transitions in replay buffer."""
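Before the connector changes, a quick sanity check of the hparams wiring above may help. The sketch below is not part of this PR; it only assumes that the registered functions in rl.py are importable and callable as ordinary module-level functions.

# Minimal sketch (not from this PR): the new set just flips agent_type on top
# of dqn_guess1_params(); the dopamine connector dispatches on this value.
from tensor2tensor.models.research import rl

hparams = rl.dqn_guess1_rainbow_params()
assert hparams.agent_type == "Rainbow"  # dqn_atari_base() defaults to "VanillaDQN"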
279 changes: 258 additions & 21 deletions tensor2tensor/rl/dopamine_connector.py
@@ -24,14 +24,19 @@
import sys

from dopamine.agents.dqn import dqn_agent
from dopamine.agents.rainbow import rainbow_agent
from dopamine.replay_memory import circular_replay_buffer
from dopamine.replay_memory.circular_replay_buffer import OutOfGraphReplayBuffer
from dopamine.replay_memory.circular_replay_buffer import ReplayElement
from dopamine.replay_memory.circular_replay_buffer import \
OutOfGraphReplayBuffer, ReplayElement
from dopamine.replay_memory.prioritized_replay_buffer import \
OutOfGraphPrioritizedReplayBuffer, WrappedPrioritizedReplayBuffer
import numpy as np

from tensor2tensor.rl.policy_learner import PolicyLearner
import tensorflow as tf

# pylint: disable=g-import-not-at-top
# pylint: disable=ungrouped-imports
try:
import cv2
except ImportError:
@@ -41,7 +46,18 @@
except ImportError:
run_experiment = None
# pylint: enable=g-import-not-at-top

# pylint: enable=ungrouped-imports

# TODO: Vanilla DQN and Rainbow have a lot of common code. Most likely we want
# to remove Vanilla DQN and only keep Rainbow. To do so, one needs to remove
# the following:
# * _DQNAgent
# * BatchDQNAgent
# * _OutOfGraphReplayBuffer
# * the "if" clause in create_agent()
# * the "agent_type" parameter from dqn_atari_base() hparams and possibly
# other rlmb dqn hparams sets
# If we want to keep both Vanilla DQN and Rainbow, a larger refactor is required.

class _DQNAgent(dqn_agent.DQNAgent):
"""Modify dopamine DQNAgent to match our needs.
@@ -178,6 +194,201 @@ def choose_action(ix):
return np.array([choose_action(ix) for ix in range(self.env_batch_size)])


class _OutOfGraphReplayBuffer(OutOfGraphReplayBuffer):
"""Replay not sampling artificial_terminal transition.

Adds to stored tuples "artificial_done" field (as last ReplayElement).
When sampling, ignores tuples for which artificial_done is True.

When adding new attributes check if there are loaded from disk, when using
load() method.

Attributes:
are_terminal_valid: A boolean indicating if newly added terminal
transitions should be marked as artificially done. Replay data loaded
from disk will not be overridden.
"""

def __init__(self, artificial_done, **kwargs):
extra_storage_types = kwargs.pop("extra_storage_types", None) or []
extra_storage_types.append(ReplayElement("artificial_done", (), np.uint8))
super(_OutOfGraphReplayBuffer, self).__init__(
extra_storage_types=extra_storage_types, **kwargs)
self._artificial_done = artificial_done

def is_valid_transition(self, index):
valid = super(_OutOfGraphReplayBuffer, self).is_valid_transition(index)
valid &= not self.get_artificial_done_stack(index).any()
return valid

def get_artificial_done_stack(self, index):
return self.get_range(self._store["artificial_done"],
index - self._stack_size + 1, index + 1)

def add(self, observation, action, reward, terminal, *args):
"""Append artificial_done to *args and run parent method."""
# If this becomes a maintenance problem, we could probably override the
# DQNAgent.add() method instead.
artificial_done = self._artificial_done and terminal
args = list(args)
args.append(artificial_done)
return super(_OutOfGraphReplayBuffer, self).add(observation, action, reward,
terminal, *args)

def load(self, *args, **kwargs):
# Check that appropriate attributes are not overridden
are_terminal_valid = self._artificial_done
super(_OutOfGraphReplayBuffer, self).load(*args, **kwargs)
assert self._artificial_done == are_terminal_valid


class _WrappedPrioritizedReplayBuffer(WrappedPrioritizedReplayBuffer):
"""

Allows to pass out-of-graph-replay-buffer via wrapped_memory.
"""
def __init__(self, wrapped_memory, batch_size, use_staging):
self.batch_size = batch_size
self.memory = wrapped_memory
self.create_sampling_ops(use_staging)


class _RainbowAgent(rainbow_agent.RainbowAgent):
"""Modify dopamine DQNAgent to match our needs.

Allow passing batch_size and replay_capacity to ReplayBuffer, allow not using
(some of) terminal episode transitions in training.
"""

def __init__(self, replay_capacity, buffer_batch_size,
generates_trainable_dones, **kwargs):
self._replay_capacity = replay_capacity
self._buffer_batch_size = buffer_batch_size
self._generates_trainable_dones = generates_trainable_dones
super(_RainbowAgent, self).__init__(**kwargs)

def _build_replay_buffer(self, use_staging):
"""Build WrappedReplayBuffer with custom OutOfGraphReplayBuffer."""
replay_buffer_kwargs = dict(
observation_shape=dqn_agent.NATURE_DQN_OBSERVATION_SHAPE,
stack_size=dqn_agent.NATURE_DQN_STACK_SIZE,
replay_capacity=self._replay_capacity,
batch_size=self._buffer_batch_size,
update_horizon=self.update_horizon,
gamma=self.gamma,
extra_storage_types=None,
observation_dtype=np.uint8,
)

replay_memory = _OutOfGraphPrioritizedReplayBuffer(
artificial_done=not self._generates_trainable_dones,
**replay_buffer_kwargs)

return _WrappedPrioritizedReplayBuffer(
wrapped_memory=replay_memory,
use_staging=use_staging, batch_size=self._buffer_batch_size)


class BatchRainbowAgent(_RainbowAgent):
"""Batch agent for DQN.

Episodes are stored on done.

Assumes that all rollouts in batch would end at the same moment.
"""

def __init__(self, env_batch_size, *args, **kwargs):
super(BatchRainbowAgent, self).__init__(*args, **kwargs)
self.env_batch_size = env_batch_size
obs_size = dqn_agent.NATURE_DQN_OBSERVATION_SHAPE
state_shape = [self.env_batch_size, obs_size[0], obs_size[1],
dqn_agent.NATURE_DQN_STACK_SIZE]
self.state_batch = np.zeros(state_shape)
self.state = None # ensure it is not used
self._observation = None # ensure it is not used
self.reset_current_rollouts()

def reset_current_rollouts(self):
self._current_rollouts = [[] for _ in range(self.env_batch_size)]

def _record_observation(self, observation_batch):
# Set current observation. Represents a (batch_size x 84 x 84 x 1) image
# frame.
observation_batch = np.array(observation_batch)
self._observation_batch = observation_batch[:, :, :, 0]
# Swap out the oldest frames with the current frames.
self.state_batch = np.roll(self.state_batch, -1, axis=3)
self.state_batch[:, :, :, -1] = self._observation_batch

def _reset_state(self):
self.state_batch.fill(0)

def begin_episode(self, observation):
self._reset_state()
self._record_observation(observation)

if not self.eval_mode:
self._train_step()

self.action = self._select_action()
return self.action

def _update_current_rollouts(self, last_observation, action, reward,
are_terminal):
transitions = zip(last_observation, action, reward, are_terminal)
for transition, rollout in zip(transitions, self._current_rollouts):
rollout.append(transition)

def _store_current_rollouts(self):
for rollout in self._current_rollouts:
for transition in rollout:
self._store_transition(*transition)
self.reset_current_rollouts()

def step(self, reward, observation):
self._last_observation = self._observation_batch
self._record_observation(observation)

if not self.eval_mode:
self._update_current_rollouts(self._last_observation, self.action, reward,
[False] * self.env_batch_size)
# We want to keep the same train_step:env_step ratio regardless of the
# batch size.
for _ in range(self.env_batch_size):
self._train_step()

self.action = self._select_action()
return self.action

def end_episode(self, reward):
if not self.eval_mode:
self._update_current_rollouts(
self._observation_batch, self.action, reward,
[True] * self.env_batch_size)
self._store_current_rollouts()

def _select_action(self):
epsilon = self.epsilon_eval
if not self.eval_mode:
epsilon = self.epsilon_fn(
self.epsilon_decay_period,
self.training_steps,
self.min_replay_history,
self.epsilon_train)

def choose_action(ix):
if random.random() <= epsilon:
# Choose a random action with probability epsilon.
return random.randint(0, self.num_actions - 1)
else:
# Choose the action with highest Q-value at the current state.
return self._sess.run(self._q_argmax,
{self.state_ph: self.state_batch[ix:ix+1]})

return np.array([choose_action(ix) for ix in range(self.env_batch_size)])


class BatchRunner(run_experiment.Runner):
"""Run a batch of environments.

@@ -223,7 +434,7 @@ def close(self):
self._environment.close()


class _OutOfGraphReplayBuffer(OutOfGraphReplayBuffer):
class _OutOfGraphPrioritizedReplayBuffer(OutOfGraphPrioritizedReplayBuffer):
"""Replay not sampling artificial_terminal transition.

Adds to stored tuples "artificial_done" field (as last ReplayElement).
@@ -240,34 +451,47 @@ class _OutOfGraphReplayBuffer(OutOfGraphReplayBuffer):

def __init__(self, artificial_done, **kwargs):
extra_storage_types = kwargs.pop("extra_storage_types", None) or []
assert not extra_storage_types, "Other extra_storage_types are " \
"currently not supported for this " \
"class."
extra_storage_types.append(ReplayElement("artificial_done", (), np.uint8))
super(_OutOfGraphReplayBuffer, self).__init__(
super(_OutOfGraphPrioritizedReplayBuffer, self).__init__(
extra_storage_types=extra_storage_types, **kwargs)
self._artificial_done = artificial_done

def is_valid_transition(self, index):
valid = super(_OutOfGraphReplayBuffer, self).is_valid_transition(index)
valid &= not self.get_artificial_done_stack(index).any()
valid = super(_OutOfGraphPrioritizedReplayBuffer, self).\
is_valid_transition(index)
if valid:
valid = not self.get_artificial_done_stack(index).any()
return valid

def get_artificial_done_stack(self, index):
return self.get_range(self._store["artificial_done"],
index - self._stack_size + 1, index + 1)

def add(self, observation, action, reward, terminal, *args):
"""Append artificial_done to *args and run parent method."""
def add(self, observation, action, reward, terminal, priority):
"""Infer artificial_done and call parent method.

Note that OutOfGraphPrioritizedReplayBuffer (implicitly) assumes that
priority would be last argument in add. Here we write it explicitly.
Passing *args to this method is disabled on purpose, code start to gets to
convoluted with it.
"""
# If this becomes a maintenance problem, we could probably override the
# DQNAgent.add() method instead.
if not isinstance(priority, (float, np.floating)):
raise ValueError("priority should be float, got type {}"
.format(type(priority)))
artificial_done = self._artificial_done and terminal
args = list(args)
args.append(artificial_done)
return super(_OutOfGraphReplayBuffer, self).add(observation, action, reward,
terminal, *args)
return super(_OutOfGraphPrioritizedReplayBuffer, self).add(
observation, action, reward, terminal, artificial_done, priority
)

def load(self, *args, **kwargs):
# Check that appropriate attributes are not overridden
are_terminal_valid = self._artificial_done
super(_OutOfGraphReplayBuffer, self).load(*args, **kwargs)
super(_OutOfGraphPrioritizedReplayBuffer, self).load(*args, **kwargs)
assert self._artificial_done == are_terminal_valid


@@ -280,6 +504,8 @@ def get_create_agent(agent_kwargs):
Returns:
Function(sess, environment, summary_writer) -> BatchDQNAgent or
BatchRainbowAgent instance.
"""
agent_kwargs = copy.deepcopy(agent_kwargs)
agent_type = agent_kwargs.pop("type")

def create_agent(sess, environment, summary_writer=None):
"""Creates a DQN agent.
@@ -294,13 +520,24 @@ def create_agent(sess, environment, summary_writer=None):
Returns:
a DQN or Rainbow agent, depending on agent_type.
"""
return BatchDQNAgent(
env_batch_size=environment.batch_size,
sess=sess,
num_actions=environment.action_space.n,
summary_writer=summary_writer,
tf_device="/gpu:*",
**agent_kwargs)
if agent_type == "Rainbow":
return BatchRainbowAgent(
env_batch_size=environment.batch_size,
sess=sess,
num_actions=environment.action_space.n,
summary_writer=summary_writer,
tf_device="/gpu:*",
**agent_kwargs)
elif agent_type == "VanillaDQN":
return BatchDQNAgent(
env_batch_size=environment.batch_size,
sess=sess,
num_actions=environment.action_space.n,
summary_writer=summary_writer,
tf_device="/gpu:*",
**agent_kwargs)
else:
raise ValueError("Unknown agent_type {}".format(agent_type))

return create_agent

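Both _OutOfGraphReplayBuffer and _OutOfGraphPrioritizedReplayBuffer above rely on the same mechanism: store one extra uint8 flag per transition and refuse to sample any index whose frame stack touches an artificially terminated transition. The sketch below is a standalone illustration of that validity check, not code from the PR; the names are made up, and it ignores the circular wrap-around that the real dopamine buffer handles via get_range().

import numpy as np

def is_valid_index(artificial_done, index, stack_size=4):
  """Simplified stand-in for the extra check added in is_valid_transition()."""
  # Reject the sample if any frame in the stack ending at `index` belongs to
  # an artificially terminated transition.
  stack = artificial_done[index - stack_size + 1:index + 1]
  return not stack.any()

flags = np.zeros(10, dtype=np.uint8)
flags[5] = 1  # e.g. a rollout cut short rather than a real terminal state
print([is_valid_index(flags, i) for i in range(4, 10)])
# -> [True, False, False, False, False, True]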
8 changes: 8 additions & 0 deletions tensor2tensor/rl/trainer_model_based_params.py
@@ -221,6 +221,14 @@ def rlmb_dqn_guess1():
return hparams


@registry.register_hparams
def rlmb_dqn_guess1_rainbow():
"""rlmb_dqn guess1 params"""
hparams = rlmb_dqn_guess1()
hparams.set_hparam("base_algo_params", "dqn_guess1_rainbow_params")
return hparams


@registry.register_hparams
def rlmb_dqn_guess1_2m_replay_buffer():
"""DQN guess1 params, 2M replay buffer."""
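Finally, a minimal sketch (not part of the PR) of how the new rlmb set reaches the Rainbow agent end to end; the only assumption beyond the diff is that the registered functions are importable as plain module-level functions.

from tensor2tensor.rl import trainer_model_based_params as tmbp

loop_hparams = tmbp.rlmb_dqn_guess1_rainbow()
# base_algo_params names the per-agent hparams set added in rl.py above; its
# agent_type ("Rainbow") is later popped as "type" in get_create_agent() and
# dispatched to BatchRainbowAgent in dopamine_connector.py.
assert loop_hparams.base_algo_params == "dqn_guess1_rainbow_params"

With the usual model-based trainer this set would presumably be selected via --loop_hparams_set=rlmb_dqn_guess1_rainbow.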