fix(luyd): fix new pipeline impala in Lunarlander and Atari env (#713)
* fix(nyz): fix buffer group sample empty bug

* feature(nyz): add new pipeline impala demo

* Fix impala in lunarlander

* Rename lunarlander config

* Reformat

* Fix impala in lunarlander

* Shift impala config in pong

* Add impala algo of new pipeline

* Reformat

* Fix according to comment

* Reformat

* Fix typo
AltmanD authored Sep 15, 2023
1 parent a37981e commit 9299826
Showing 11 changed files with 262 additions and 56 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -15,6 +15,7 @@

pkg/
src/
impala_log/

### CVS template
/CVS/*
2 changes: 2 additions & 0 deletions ding/config/config.py
@@ -452,6 +452,8 @@ def compile_config(
if len(world_model_config) > 0:
default_config['world_model'] = world_model_config
cfg = deep_merge_dicts(default_config, cfg)
if 'unroll_len' in cfg.policy:
cfg.policy.collect.unroll_len = cfg.policy.unroll_len
cfg.seed = seed
# check important key in config
if evaluator in [InteractionSerialEvaluator, BattleInteractionSerialEvaluator]: # env interaction evaluation
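The two lines added to compile_config copy a top-level policy.unroll_len into policy.collect.unroll_len, so the collector and the buffer's grouped sampling agree on the trajectory length. A minimal sketch of just that override with a toy EasyDict config (the values here are illustrative and the rest of compile_config is omitted):

```python
from easydict import EasyDict

# Toy config: a top-level policy.unroll_len and a stale collect-level value.
cfg = EasyDict(dict(policy=dict(unroll_len=32, collect=dict(unroll_len=1, n_sample=16))))

# The override added in compile_config: the top-level value wins when present.
if 'unroll_len' in cfg.policy:
    cfg.policy.collect.unroll_len = cfg.policy.unroll_len

assert cfg.policy.collect.unroll_len == 32
```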
26 changes: 21 additions & 5 deletions ding/data/buffer/deque_buffer.py
@@ -54,16 +54,18 @@ class DequeBuffer(Buffer):
A buffer implementation based on the deque structure.
"""

def __init__(self, size: int) -> None:
def __init__(self, size: int, sliced: bool = False) -> None:
"""
Overview:
The initialization method of DequeBuffer.
Arguments:
- size (:obj:`int`): The maximum number of objects that the buffer can hold.
- sliced (:obj:`bool`): Whether to slice the sampled data into unroll_len-sized pieces when sampling by group.
"""
super().__init__(size=size)
self.storage = deque(maxlen=size)
self.indices = BufferIndex(maxlen=size)
self.sliced = sliced
# Meta index is a dict which uses deque as values
self.meta_index = {}

@@ -142,7 +144,7 @@ def sample(
sampled_data = [hashed_data[index] for index in indices]
elif groupby:
sampled_data = self._sample_by_group(
size=size, groupby=groupby, replace=replace, unroll_len=unroll_len, storage=storage
size=size, groupby=groupby, replace=replace, unroll_len=unroll_len, storage=storage, sliced=self.sliced
)
else:
if replace:
@@ -301,7 +303,8 @@ def _sample_by_group(
groupby: str,
replace: bool = False,
unroll_len: Optional[int] = None,
storage: deque = None
storage: deque = None,
sliced: bool = False
) -> List[List[BufferedData]]:
"""
Overview:
@@ -324,6 +327,8 @@ def filter_by_unroll_len():

if unroll_len and unroll_len > 1:
group_names = filter_by_unroll_len()
if len(group_names) == 0:
return []
else:
group_names = list(set(self.meta_index[groupby]))

@@ -348,8 +353,19 @@ def filter_by_unroll_len():
seq_data = sampled_data[group]
# Filter records by unroll_len
if unroll_len:
start_indice = random.choice(range(max(1, len(seq_data) - unroll_len)))
seq_data = seq_data[start_indice:start_indice + unroll_len]
# Slice by unroll_len. Without slicing, duplicate data is more likely to be sampled,
# and the training will easily crash.
if sliced:
start_indice = random.choice(range(max(1, len(seq_data))))
start_indice = start_indice // unroll_len
if start_indice == (len(seq_data) - 1) // unroll_len:
seq_data = seq_data[-unroll_len:]
else:
seq_data = seq_data[start_indice * unroll_len:start_indice * unroll_len + unroll_len]
else:
start_indice = random.choice(range(max(1, len(seq_data) - unroll_len)))
seq_data = seq_data[start_indice:start_indice + unroll_len]

final_sampled_data.append(seq_data)

return final_sampled_data
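The new sliced branch snaps the sampled window to an unroll_len-aligned block (falling back to the trailing window for the last, possibly shorter block) instead of drawing an arbitrary start offset; per the inline comment, arbitrary offsets make overlapping, duplicate samples more likely. A standalone sketch of just that index arithmetic, where pick_window is a hypothetical helper and seq_data stands in for one group's buffered records:

```python
import random

def pick_window(seq_data, unroll_len, sliced):
    """Select one length-unroll_len window from a single group's trajectory."""
    if sliced:
        # Snap a random position to its unroll_len-aligned block ...
        start = random.choice(range(max(1, len(seq_data)))) // unroll_len
        # ... and use the trailing window when the position lands in the last block.
        if start == (len(seq_data) - 1) // unroll_len:
            return seq_data[-unroll_len:]
        return seq_data[start * unroll_len:start * unroll_len + unroll_len]
    # Previous behaviour: any start offset, so windows from repeated samples can overlap.
    start = random.choice(range(max(1, len(seq_data) - unroll_len)))
    return seq_data[start:start + unroll_len]

# With 10 records and unroll_len=4, sliced sampling only returns
# [0..3], [4..7], or the trailing [6..9] window.
print(pick_window(list(range(10)), unroll_len=4, sliced=True))
```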
24 changes: 24 additions & 0 deletions ding/data/buffer/tests/test_buffer.py
@@ -326,3 +326,27 @@ def test_insufficient_unroll_len_in_group():
# Ensure samples in each group are continuous
result = functools.reduce(lambda a, b: a and a.data + 1 == b.data and b, grouped_data)
assert isinstance(result, BufferedData), "Not continuous"


@pytest.mark.unittest
def test_slice_unroll_len_in_group():
buffer = DequeBuffer(size=100, sliced=True)
data_len = 10
unroll_len = 4
start_index = list(range(0, data_len, unroll_len)) + [data_len - unroll_len]
for i in range(data_len):
for env_id in list("ABC"):
buffer.push(i, {"env": env_id})

sampled_data = buffer.sample(3, groupby="env", unroll_len=unroll_len)
assert len(sampled_data) == 3
for grouped_data in sampled_data:
assert len(grouped_data) == 4
# Ensure each group has the same env
env_ids = set(map(lambda sample: sample.meta["env"], grouped_data))
assert len(env_ids) == 1
# Ensure samples in each group are continuous
result = functools.reduce(lambda a, b: a and a.data + 1 == b.data and b, grouped_data)
assert isinstance(result, BufferedData), "Not continuous"
# Ensure the sliced data starts from a correct index
assert grouped_data[0].data in start_index
47 changes: 47 additions & 0 deletions ding/example/impala.py
@@ -0,0 +1,47 @@
import gym
from ditk import logging
from ding.model import VAC
from ding.policy import IMPALAPolicy
from ding.envs import SubprocessEnvManagerV2
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task, ding_init
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
CkptSaver, online_logger, termination_checker
from ding.utils import set_pkg_seed
from dizoo.box2d.lunarlander.config.lunarlander_impala_config import main_config, create_config
from dizoo.box2d.lunarlander.envs import LunarLanderEnv


def main():
logging.getLogger().setLevel(logging.INFO)
cfg = compile_config(main_config, create_cfg=create_config, auto=True)
ding_init(cfg)
with task.start(async_mode=False, ctx=OnlineRLContext()):
collector_env = SubprocessEnvManagerV2(
env_fn=[lambda: LunarLanderEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
)
evaluator_env = SubprocessEnvManagerV2(
env_fn=[lambda: LunarLanderEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

model = VAC(**cfg.policy.model)
buffer_ = DequeBuffer(
size=cfg.policy.other.replay_buffer.replay_buffer_size, sliced=cfg.policy.other.replay_buffer.sliced
)
policy = IMPALAPolicy(cfg.policy, model=model)

task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=1024))
task.use(data_pusher(cfg, buffer_, group_by_env=True))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
task.use(online_logger(train_show_freq=300))
task.use(CkptSaver(policy, cfg.exp_name, train_freq=10000))
task.use(termination_checker(max_env_step=2e6))
task.run()


if __name__ == "__main__":
main()
44 changes: 18 additions & 26 deletions ding/policy/impala.py
@@ -2,12 +2,13 @@
from typing import List, Dict, Any, Tuple

import torch
import treetensor.torch as ttorch

from ding.model import model_wrap
from ding.rl_utils import vtrace_data, vtrace_error_discrete_action, vtrace_error_continuous_action, get_train_sample
from ding.torch_utils import Adam, RMSprop, to_device
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from ding.utils.data import default_collate, default_decollate, ttorch_collate
from ding.policy.base_policy import Policy


@@ -55,7 +56,6 @@ class IMPALAPolicy(Policy):
# (bool) Whether policy data is needed when processing transitions
transition_with_policy_data=True,
learn=dict(

# (int) collect n_sample data, train model update_per_collect times
# here we follow ppo serial pipeline
update_per_collect=4,
@@ -158,7 +158,13 @@ def _data_preprocess_learn(self, data: List[Dict[str, Any]]):
- done (:obj:`torch.FloatTensor`): :math:`(T, B)`
- weight (:obj:`torch.FloatTensor`): :math:`(T, B)`
"""
data = default_collate(data)
elem = data[0]
if isinstance(elem, dict): # old pipeline
data = default_collate(data)
elif isinstance(elem, list): # new task pipeline
data = default_collate(default_collate(data))
else:
raise TypeError("not support element type ({}) in IMPALA".format(type(elem)))
if self._cuda:
data = to_device(data, self._device)
if self._priority_IS_weight:
@@ -167,27 +173,11 @@ def _data_preprocess_learn(self, data: List[Dict[str, Any]]):
data['weight'] = data['IS']
else:
data['weight'] = data.get('weight', None)
data['obs_plus_1'] = torch.cat((data['obs'] + data['next_obs'][-1:]), dim=0) # shape (T+1)*B,env_obs_shape
if self._action_space == 'continuous':
data['logit']['mu'] = torch.cat(
data['logit']['mu'], dim=0
).reshape(self._unroll_len, -1, self._action_shape) # shape T,B,env_action_shape
data['logit']['sigma'] = torch.cat(
data['logit']['sigma'], dim=0
).reshape(self._unroll_len, -1, self._action_shape) # shape T,B,env_action_shape
data['action'] = torch.cat(
data['action'], dim=0
).reshape(self._unroll_len, -1, self._action_shape) # shape T,B,env_action_shape
elif self._action_space == 'discrete':
data['logit'] = torch.cat(
data['logit'], dim=0
).reshape(self._unroll_len, -1, self._action_shape) # shape T,B,env_action_shape
data['action'] = torch.cat(data['action'], dim=0).reshape(self._unroll_len, -1) # shape T,B,
data['done'] = torch.cat(data['done'], dim=0).reshape(self._unroll_len, -1).float() # shape T,B,
data['reward'] = torch.cat(data['reward'], dim=0).reshape(self._unroll_len, -1) # shape T,B,
data['weight'] = torch.cat(
data['weight'], dim=0
).reshape(self._unroll_len, -1) if data['weight'] else None # shape T,B
if isinstance(elem, dict): # old pipeline
for k in data:
if isinstance(data[k], list):
data[k] = default_collate(data[k])
data['obs_plus_1'] = torch.cat([data['obs'], data['next_obs'][-1:]], dim=0) # shape (T+1)*B,env_obs_shape
return data

def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
@@ -213,7 +203,9 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
# IMPALA forward
# ====================
self._learn_model.train()
output = self._learn_model.forward(data['obs_plus_1'], mode='compute_actor_critic')
output = self._learn_model.forward(
data['obs_plus_1'].view((-1, ) + data['obs_plus_1'].shape[2:]), mode='compute_actor_critic'
)
target_logit, behaviour_logit, actions, values, rewards, weights = self._reshape_data(output, data)
# Calculate vtrace error
data = vtrace_data(target_logit, behaviour_logit, actions, values, rewards, weights)
@@ -276,7 +268,7 @@ def _reshape_data(self, output: Dict[str, Any], data: Dict[str, Any]) -> Tuple[A
actions = data['action'] # shape T,B for discrete # shape T,B,env_action_shape for continuous
values = output['value'].reshape(self._unroll_len + 1, -1) # shape T+1,B,env_action_shape
rewards = data['reward'] # shape T,B
weights_ = 1 - data['done'] # shape T,B
weights_ = 1 - data['done'].float() # shape T,B
weights = torch.ones_like(rewards) # shape T,B
values[1:] = values[1:] * weights_
weights[1:] = weights_[:-1]
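On the learner side, the per-key torch.cat reshaping is replaced by collating once and then flattening time and batch with a view right before the actor-critic forward. A small shape walkthrough under assumed sizes (T = unroll_len, B = batch size, a toy flat observation), mirroring the obs_plus_1 construction and the value reshape above:

```python
import torch

T, B, obs_dim = 8, 4, 17  # assumed sizes: unroll_len, batch size, toy flat observation
obs = torch.randn(T, B, obs_dim)
next_obs = torch.randn(T, B, obs_dim)

# obs_plus_1 appends the final next_obs along the time axis: (T+1, B, obs_dim).
obs_plus_1 = torch.cat([obs, next_obs[-1:]], dim=0)
assert obs_plus_1.shape == (T + 1, B, obs_dim)

# _forward_learn flattens time and batch before the compute_actor_critic forward ...
flat = obs_plus_1.view((-1, ) + obs_plus_1.shape[2:])
assert flat.shape == ((T + 1) * B, obs_dim)

# ... and the value output is folded back to (T+1, B) in _reshape_data.
value = torch.randn((T + 1) * B)
assert value.reshape(T + 1, -1).shape == (T + 1, B)
```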
24 changes: 13 additions & 11 deletions dizoo/atari/config/serial/pong/pong_impala_config.py
@@ -1,12 +1,12 @@
from easydict import EasyDict

pong_impala_config = dict(
exp_name='pong_impala_seed0',
exp_name='impala_log/pong_impala_seed0',
env=dict(
collector_env_num=8,
collector_env_num=12,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=20,
stop_value=21,
env_id='PongNoFrameskip-v4',
# 'ALE/Pong-v5' is also available, but it needs special settings after gym.make().
frame_stack=4,
@@ -19,27 +19,29 @@
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
critic_head_hidden_size=512,
encoder_hidden_size_list=[64, 128, 256],
critic_head_hidden_size=256,
critic_head_layer_num=2,
actor_head_hidden_size=512,
actor_head_hidden_size=256,
actor_head_layer_num=2,
# impala_cnn_encoder=True,
),
learn=dict(
# (int) collect n_sample data, train model update_per_collect times
# here we follow impala serial pipeline
update_per_collect=10,
update_per_collect=2,
# (int) the number of data for a train iteration
batch_size=128,
# optim_type='rmsprop',
grad_clip_type='clip_norm',
clip_value=0.5,
learning_rate=0.0003,
learning_rate=0.0006,
# (float) loss weight of the value network, the weight of policy network is set to 1
value_weight=0.5,
# (float) loss weight of the entropy regularization, the weight of policy network is set to 1
entropy_weight=0.01,
# (float) discount factor for future reward, defaults in [0, 1]
discount_factor=0.9,
discount_factor=0.99,
# (float) additional discounting parameter
lambda_=0.95,
# (float) clip ratio of importance weights
@@ -54,8 +56,8 @@
n_sample=16,
collector=dict(collect_print_freq=1000, ),
),
eval=dict(evaluator=dict(eval_freq=5000, )),
other=dict(replay_buffer=dict(replay_buffer_size=10000, ), ),
eval=dict(evaluator=dict(eval_freq=2000, )),
other=dict(replay_buffer=dict(replay_buffer_size=10000, sliced=False), ),
),
)
main_config = EasyDict(pong_impala_config)
dizoo/atari/config/serial/spaceinvaders/spaceinvaders_impala_config.py
@@ -1,8 +1,7 @@
from copy import deepcopy
from easydict import EasyDict

spaceinvaders_impala_config = dict(
exp_name='spaceinvaders_impala_seed0',
exp_name='impala_log/spaceinvaders_impala_seed0',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
@@ -11,7 +10,7 @@
env_id='SpaceInvadersNoFrameskip-v4',
# 'ALE/SpaceInvaders-v5' is also available, but it needs special settings after gym.make().
frame_stack=4,
manager=dict(shared_memory=False, )
# manager=dict(shared_memory=False, )
),
policy=dict(
cuda=True,
@@ -21,21 +20,21 @@
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 256, 512],
critic_head_hidden_size=512,
encoder_hidden_size_list=[128, 128, 256, 256],
critic_head_hidden_size=256,
critic_head_layer_num=3,
actor_head_hidden_size=512,
actor_head_hidden_size=256,
actor_head_layer_num=3,
),
learn=dict(
# (int) collect n_sample data, train model update_per_collect times
# here we follow impala serial pipeline
update_per_collect=3, # update_per_collect should be in [1, 10]
update_per_collect=2, # update_per_collect should be in [1, 10]
# (int) the number of data for a train iteration
batch_size=128,
grad_clip_type='clip_norm',
clip_value=5,
learning_rate=0.0003,
learning_rate=0.0006,
# (float) loss weight of the value network, the weight of policy network is set to 1
value_weight=0.5,
# (float) loss weight of the entropy regularization, the weight of policy network is set to 1
@@ -56,8 +55,8 @@
n_sample=16,
collector=dict(collect_print_freq=1000, ),
),
eval=dict(evaluator=dict(eval_freq=5000, )),
other=dict(replay_buffer=dict(replay_buffer_size=10000, ), ),
eval=dict(evaluator=dict(eval_freq=500, )),
other=dict(replay_buffer=dict(replay_buffer_size=100000, sliced=True), ),
),
)
spaceinvaders_impala_config = EasyDict(spaceinvaders_impala_config)