Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(nyz): add PPOF new interface support #567

Merged
merged 10 commits into from
Jan 3, 2023
1 change: 1 addition & 0 deletions ding/bonus/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .ppof import PPOF
75 changes: 75 additions & 0 deletions ding/bonus/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from easydict import EasyDict
import gym
from ding.envs import BaseEnv, DingEnvWrapper
from ding.policy import PPOFPolicy


def get_instance_config(env: str) -> EasyDict:
cfg = PPOFPolicy.default_config()
if env == 'lunarlander_discrete':
cfg.n_sample = 400
elif env == 'lunarlander_continuous':
cfg.action_space = 'continuous'
cfg.n_sample = 400
elif env == 'rocket_landing':
cfg.n_sample = 2048
cfg.adv_norm = False
cfg.model = dict(
encoder_hidden_size_list=[64, 64, 128],
actor_head_hidden_size=128,
critic_head_hidden_size=128,
)
elif env == 'drone_fly':
cfg.action_space = 'continuous'
cfg.adv_norm = False
cfg.epoch_per_collect = 5
cfg.learning_rate = 5e-5
cfg.n_sample = 640
elif env == 'hybrid_moving':
cfg.action_space = 'hybrid'
cfg.n_sample = 3200
cfg.entropy_weight = 0.03
cfg.batch_size = 320
cfg.adv_norm = False
cfg.model = dict(
encoder_hidden_size_list=[256, 128, 64, 64],
sigma_type='fixed',
fixed_sigma_value=0.3,
bound_type='tanh',
)
else:
raise KeyError("not supported env type: {}".format(env))
return cfg


def get_instance_env(env: str) -> BaseEnv:
if env == 'lunarlander_discrete':
return DingEnvWrapper(gym.make('LunarLander-v2'))
elif env == 'lunarlander_continuous':
return DingEnvWrapper(gym.make('LunarLander-v2', continuous=True))
elif env == 'rocket_landing':
from dizoo.rocket.envs import RocketEnv
cfg = EasyDict({
'task': 'landing',
'max_steps': 800,
})
return RocketEnv(cfg)
elif env == 'drone_fly':
from dizoo.gym_pybullet_drones.envs import GymPybulletDronesEnv
cfg = EasyDict({
'env_id': 'flythrugate-aviary-v0',
'action_type': 'VEL',
})
return GymPybulletDronesEnv(cfg)
elif env == 'hybrid_moving':
import gym_hybrid
return DingEnvWrapper(gym.make('Moving-v0'))
else:
raise KeyError("not supported env type: {}".format(env))


def get_hybrid_shape(action_space) -> EasyDict:
return EasyDict({
'action_type_shape': action_space[0].n,
'action_args_shape': action_space[1].shape,
})
230 changes: 230 additions & 0 deletions ding/bonus/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
from typing import Union, Optional
from easydict import EasyDict
import torch
import torch.nn as nn
import treetensor.torch as ttorch
from copy import deepcopy
from ding.utils import SequenceType, squeeze
from ding.model.common import ReparameterizationHead, RegressionHead, MultiHead, \
FCEncoder, ConvEncoder, IMPALAConvEncoder
from ding.torch_utils import MLP, fc_block


class DiscretePolicyHead(nn.Module):

def __init__(
self,
hidden_size: int,
output_size: int,
layer_num: int = 1,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
) -> None:
super(DiscretePolicyHead, self).__init__()
self.main = nn.Sequential(
MLP(
hidden_size,
hidden_size,
hidden_size,
layer_num,
layer_fn=nn.Linear,
activation=activation,
norm_type=norm_type
), fc_block(hidden_size, output_size)
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.main(x)


class PPOFModel(nn.Module):
mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

def __init__(
self,
obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType, EasyDict],
action_space: str = 'discrete',
share_encoder: bool = True,
encoder_hidden_size_list: SequenceType = [128, 128, 64],
actor_head_hidden_size: int = 64,
actor_head_layer_num: int = 1,
critic_head_hidden_size: int = 64,
critic_head_layer_num: int = 1,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
sigma_type: Optional[str] = 'independent',
fixed_sigma_value: Optional[int] = 0.3,
bound_type: Optional[str] = None,
encoder: Optional[torch.nn.Module] = None,
) -> None:
super(PPOFModel, self).__init__()
obs_shape = squeeze(obs_shape)
action_shape = squeeze(action_shape)
self.obs_shape, self.action_shape = obs_shape, action_shape
self.share_encoder = share_encoder

# Encoder Type
def new_encoder(outsize):
if isinstance(obs_shape, int) or len(obs_shape) == 1:
return FCEncoder(
obs_shape=obs_shape,
hidden_size_list=encoder_hidden_size_list,
activation=activation,
norm_type=norm_type
)
elif len(obs_shape) == 3:
return ConvEncoder(
obs_shape=obs_shape,
hidden_size_list=encoder_hidden_size_list,
activation=activation,
norm_type=norm_type
)
else:
raise RuntimeError(
"not support obs_shape for pre-defined encoder: {}, please customize your own encoder".
format(obs_shape)
)

if self.share_encoder:
assert actor_head_hidden_size == critic_head_hidden_size, \
"actor and critic network head should have same size."
if encoder:
if isinstance(encoder, torch.nn.Module):
self.encoder = encoder
else:
raise ValueError("illegal encoder instance.")
else:
self.encoder = new_encoder(actor_head_hidden_size)
else:
if encoder:
if isinstance(encoder, torch.nn.Module):
self.actor_encoder = encoder
self.critic_encoder = deepcopy(encoder)
else:
raise ValueError("illegal encoder instance.")
else:
self.actor_encoder = new_encoder(actor_head_hidden_size)
self.critic_encoder = new_encoder(critic_head_hidden_size)

# Head Type
self.critic_head = RegressionHead(
critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
)
self.action_space = action_space
assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space
if self.action_space == 'continuous':
self.multi_head = False
self.actor_head = ReparameterizationHead(
actor_head_hidden_size,
action_shape,
actor_head_layer_num,
sigma_type=sigma_type,
activation=activation,
norm_type=norm_type,
bound_type=bound_type
)
elif self.action_space == 'discrete':
actor_head_cls = DiscretePolicyHead
multi_head = not isinstance(action_shape, int)
self.multi_head = multi_head
if multi_head:
self.actor_head = MultiHead(
actor_head_cls,
actor_head_hidden_size,
action_shape,
layer_num=actor_head_layer_num,
activation=activation,
norm_type=norm_type
)
else:
self.actor_head = actor_head_cls(
actor_head_hidden_size,
action_shape,
actor_head_layer_num,
activation=activation,
norm_type=norm_type
)
elif self.action_space == 'hybrid': # HPPO
# hybrid action space: action_type(discrete) + action_args(continuous),
# such as {'action_type_shape': torch.LongTensor([0]), 'action_args_shape': torch.FloatTensor([0.1, -0.27])}
action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
actor_action_args = ReparameterizationHead(
actor_head_hidden_size,
action_shape.action_args_shape,
actor_head_layer_num,
sigma_type=sigma_type,
fixed_sigma_value=fixed_sigma_value,
activation=activation,
norm_type=norm_type,
bound_type=bound_type,
)
actor_action_type = DiscretePolicyHead(
actor_head_hidden_size,
action_shape.action_type_shape,
actor_head_layer_num,
activation=activation,
norm_type=norm_type,
)
self.actor_head = nn.ModuleList([actor_action_type, actor_action_args])

# must use list, not nn.ModuleList
if self.share_encoder:
self.actor = [self.encoder, self.actor_head]
self.critic = [self.encoder, self.critic_head]
else:
self.actor = [self.actor_encoder, self.actor_head]
self.critic = [self.critic_encoder, self.critic_head]
# Convenient for calling some apis (e.g. self.critic.parameters()),
# but may cause misunderstanding when `print(self)`
self.actor = nn.ModuleList(self.actor)
self.critic = nn.ModuleList(self.critic)

def forward(self, inputs: ttorch.Tensor, mode: str) -> ttorch.Tensor:
assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
return getattr(self, mode)(inputs)

def compute_actor(self, x: ttorch.Tensor) -> ttorch.Tensor:
if self.share_encoder:
x = self.encoder(x)
else:
x = self.actor_encoder(x)

if self.action_space == 'discrete':
return self.actor_head(x)
elif self.action_space == 'continuous':
x = self.actor_head(x) # mu, sigma
return ttorch.as_tensor(x)
elif self.action_space == 'hybrid':
action_type = self.actor_head[0](x)
action_args = self.actor_head[1](x)
return ttorch.as_tensor({'action_type': action_type, 'action_args': action_args})

def compute_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
if self.share_encoder:
x = self.encoder(x)
else:
x = self.critic_encoder(x)
x = self.critic_head(x)
return x['pred']

def compute_actor_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
if self.share_encoder:
actor_embedding = critic_embedding = self.encoder(x)
else:
actor_embedding = self.actor_encoder(x)
critic_embedding = self.critic_encoder(x)

value = self.critic_head(critic_embedding)['pred']

if self.action_space == 'discrete':
logit = self.actor_head(actor_embedding)
return ttorch.as_tensor({'logit': logit, 'value': value})
elif self.action_space == 'continuous':
x = self.actor_head(actor_embedding)
return ttorch.as_tensor({'logit': x, 'value': value})
elif self.action_space == 'hybrid':
action_type = self.actor_head[0](actor_embedding)
action_args = self.actor_head[1](actor_embedding)
return ttorch.as_tensor({'logit': {'action_type': action_type, 'action_args': action_args}, 'value': value})
Loading