Commit
Showing 16 changed files with 1,394 additions and 3 deletions.
@@ -0,0 +1,54 @@
# How to Define a Custom Reinforcement Learning Controller Externally

First, import the required dependencies:
```python
### Import the RL controller base class and the registration helper
from paddleslim.common.RL_controller.utils import RLCONTROLLER
from paddleslim.common.RL_controller import RLBaseController
### fluid is used below to build and run the controller programs
from paddle import fluid
```

Register the custom RL controller with PaddleSlim via the decorator, and, after inheriting from the base class, override its `next_tokens` and `update` methods. Note: this example only shows the essential steps and cannot be run as-is (a hypothetical driving loop is sketched after this code block); for the complete code, see [here]().

```python
### Note: the class name must be in all uppercase
@RLCONTROLLER.register
class LSTM(RLBaseController):
    def __init__(self, range_tables, use_gpu=False, **kwargs):
        ### range_tables is the value range of the tokens
        self.range_tables = range_tables
        ### use_gpu indicates whether to train the controller on GPU
        self.use_gpu = use_gpu
        ### Define the parameters needed by the RL algorithm
        ...
        ### Build the programs. _build_program constructs two programs,
        ### pred_program and learn_program, and initializes their parameters
        self._build_program()
        self.place = fluid.CUDAPlace(0) if self.use_gpu else fluid.CPUPlace()
        self.exe = fluid.Executor(self.place)
        self.exe.run(fluid.default_startup_program())

        ### Save the parameters in a dict that the server maintains and updates,
        ### because several clients may update the same parameters concurrently;
        ### this step is indispensable. Since pred_program and learn_program share
        ### the same parameters, only the parameters of learn_program need to go
        ### into the dict.
        self.param_dicts = {}
        self.param_dicts.update(
            {self.learn_program: self.get_params(self.learn_program)})

    def next_tokens(self, states, params_dict):
        ### Assign the parameter dict fetched from the server to the program used here
        self.set_params(self.pred_program, params_dict, self.place)
        ### Build the input from states
        self.num_archs = states
        feed_dict = self._create_input()
        ### Get the current tokens
        actions = self.exe.run(
            self.pred_program, feed=feed_dict, fetch_list=self.tokens)
        ...
        return actions

    def update(self, rewards, params_dict=None):
        ### Assign the parameter dict fetched from the server to the program used here
        self.set_params(self.learn_program, params_dict, self.place)
        ### Build the input from the states in `next_tokens` and the rewards in `update`
        feed_dict = self._create_input(is_test=False, actual_rewards=rewards)
        ### Compute the loss of the current step
        loss = self.exe.run(
            self.learn_program, feed=feed_dict, fetch_list=[self.loss])
        ### Get the parameters of the current program and return them; the client
        ### sends this round's parameters to the server for the parameter update
        params_dict = self.get_params(self.learn_program)
        return params_dict
```
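
In practice the controller is driven by PaddleSlim's NAS client/server machinery, but each search step conceptually follows the loop below. This is only a hand-written sketch against the class above, not PaddleSlim's actual client code: `my_range_tables`, `num_archs`, `max_steps`, and `evaluate_tokens` are made-up placeholders, and the server-side parameter handling is simplified to a single local dict.

```python
### Hypothetical driving loop for the controller defined above (sketch only;
### in practice the PaddleSlim client/server framework performs these steps).
controller = LSTM(range_tables=my_range_tables, use_gpu=False)

### In the real setup the server owns and merges this dict across clients;
### here one local copy stands in for it.
params_dict = controller.get_params(controller.learn_program)

for step in range(max_steps):
    ### Ask the controller for the next tokens given the current state.
    tokens = controller.next_tokens(num_archs, params_dict)
    ### evaluate_tokens is user-defined, e.g. train the architecture described
    ### by `tokens` and return its reward (accuracy, latency score, ...).
    reward = evaluate_tokens(tokens)
    ### Update the controller and get the parameters to send back to the server.
    params_dict = controller.update(reward, params_dict)
```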
@@ -0,0 +1,157 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import parl
from parl import layers
from paddle import fluid
from ..utils import RLCONTROLLER, action_mapping
from ...controller import RLBaseController
from .ddpg_model import DefaultDDPGModel as default_ddpg_model
from .noise import AdaptiveNoiseSpec as default_noise
from parl.utils import ReplayMemory

__all__ = ['DDPG']


### PARL agent that wraps the DDPG algorithm and holds the prediction/learning programs.
class DDPGAgent(parl.Agent):
    def __init__(self, algorithm, obs_dim, act_dim):
        assert isinstance(obs_dim, int)
        assert isinstance(act_dim, int)
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        super(DDPGAgent, self).__init__(algorithm)

        # Attention: In the beginning, sync target model totally.
        self.alg.sync_target(decay=0)

    def build_program(self):
        self.pred_program = fluid.Program()
        self.learn_program = fluid.Program()

        with fluid.program_guard(self.pred_program):
            obs = layers.data(
                name='obs', shape=[self.obs_dim], dtype='float32')
            self.pred_act = self.alg.predict(obs)

        with fluid.program_guard(self.learn_program):
            obs = layers.data(
                name='obs', shape=[self.obs_dim], dtype='float32')
            act = layers.data(
                name='act', shape=[self.act_dim], dtype='float32')
            reward = layers.data(name='reward', shape=[], dtype='float32')
            next_obs = layers.data(
                name='next_obs', shape=[self.obs_dim], dtype='float32')
            terminal = layers.data(name='terminal', shape=[], dtype='bool')
            _, self.critic_cost = self.alg.learn(obs, act, reward, next_obs,
                                                 terminal)

    def predict(self, obs):
        obs = np.expand_dims(obs, axis=0)
        act = self.fluid_executor.run(self.pred_program,
                                      feed={'obs': obs},
                                      fetch_list=[self.pred_act])[0]
        return act

    def learn(self, obs, act, reward, next_obs, terminal):
        feed = {
            'obs': obs,
            'act': act,
            'reward': reward,
            'next_obs': next_obs,
            'terminal': terminal
        }
        critic_cost = self.fluid_executor.run(self.learn_program,
                                              feed=feed,
                                              fetch_list=[self.critic_cost])[0]
        self.alg.sync_target()
        return critic_cost


### RL controller that uses DDPG to generate the next tokens and to learn from rewards.
@RLCONTROLLER.register
class DDPG(RLBaseController):
    def __init__(self, range_tables, use_gpu=False, **kwargs):
        self.use_gpu = use_gpu
        self.range_tables = range_tables - np.asarray(1)
        self.act_dim = len(self.range_tables)
        self.obs_dim = kwargs.get('obs_dim')
        self.model = kwargs.get(
            'model') if 'model' in kwargs else default_ddpg_model
        self.actor_lr = kwargs.get(
            'actor_lr') if 'actor_lr' in kwargs else 1e-4
        self.critic_lr = kwargs.get(
            'critic_lr') if 'critic_lr' in kwargs else 1e-3
        self.gamma = kwargs.get('gamma') if 'gamma' in kwargs else 0.99
        self.tau = kwargs.get('tau') if 'tau' in kwargs else 0.001
        self.memory_size = kwargs.get(
            'memory_size') if 'memory_size' in kwargs else 10
        self.reward_scale = kwargs.get(
            'reward_scale') if 'reward_scale' in kwargs else 0.1
        self.batch_size = kwargs.get(
            'controller_batch_size') if 'controller_batch_size' in kwargs else 1
        self.actions_noise = kwargs.get(
            'actions_noise') if 'actions_noise' in kwargs else default_noise
        self.action_dist = 0.0
        self.place = fluid.CUDAPlace(0) if self.use_gpu else fluid.CPUPlace()

        model = self.model(self.act_dim)

        if self.actions_noise:
            self.actions_noise = self.actions_noise()

        algorithm = parl.algorithms.DDPG(
            model,
            gamma=self.gamma,
            tau=self.tau,
            actor_lr=self.actor_lr,
            critic_lr=self.critic_lr)
        self.agent = DDPGAgent(algorithm, self.obs_dim, self.act_dim)
        self.rpm = ReplayMemory(self.memory_size, self.obs_dim, self.act_dim)

        self.pred_program = self.agent.pred_program
        self.learn_program = self.agent.learn_program
        self.param_dict = self.get_params(self.learn_program)

    def next_tokens(self, obs, params_dict, is_inference=False):
        batch_obs = np.expand_dims(obs, axis=0)
        self.set_params(self.pred_program, params_dict, self.place)
        actions = self.agent.predict(batch_obs.astype('float32'))
        ### add noise to action
        if self.actions_noise and not is_inference:
            actions_noise = np.clip(
                np.random.normal(
                    actions, scale=self.actions_noise.stdev_curr),
                -1.0,
                1.0)
            self.action_dist = np.mean(np.abs(actions_noise - actions))
        else:
            actions_noise = actions
        ### map the raw actions to the token value ranges
        actions_noise = action_mapping(actions_noise, self.range_tables)
        return actions_noise

    def _update_noise(self, actions_dist):
        self.actions_noise.update(actions_dist)

    def update(self, rewards, params_dict, obs, actions, obs_next, terminal):
        self.set_params(self.learn_program, params_dict, self.place)
        self.rpm.append(obs, actions, self.reward_scale * rewards, obs_next,
                        terminal)
        if self.actions_noise:
            self._update_noise(self.action_dist)
        if self.rpm.size() > self.memory_size:
            obs, actions, rewards, obs_next, terminal = self.rpm.sample_batch(
                self.batch_size)
            self.agent.learn(obs, actions, rewards, obs_next, terminal)
        params_dict = self.get_params(self.learn_program)
        return params_dict
@@ -0,0 +1,15 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .DDPGController import *