diff --git a/examples/Baselines/GridDispatch_competition/README.md b/examples/Baselines/GridDispatch_competition/README.md
new file mode 100644
index 000000000..c09abd632
--- /dev/null
+++ b/examples/Baselines/GridDispatch_competition/README.md
@@ -0,0 +1,7 @@
+## Baselines for grid dispatching competition
+
+Competition link: [国家电网调控AI创新大赛:电网运行组织智能安排](https://aistudio.baidu.com/aistudio/competition/detail/111)
+
+We provide a distributed SAC baseline based on PARL, implemented with paddlepaddle or torch:
+- [paddlepaddle baseline](paddle)
+- [torch baseline](torch)
diff --git a/examples/Baselines/GridDispatch_competition/paddle/README.md b/examples/Baselines/GridDispatch_competition/paddle/README.md
new file mode 100644
index 000000000..c1f37463c
--- /dev/null
+++ b/examples/Baselines/GridDispatch_competition/paddle/README.md
@@ -0,0 +1,61 @@
+## SAC baseline for grid dispatching competition
+
+In this example, we provide a distributed SAC baseline based on PARL and paddlepaddle for the [grid dispatching competition](https://aistudio.baidu.com/aistudio/competition/detail/111) task.
+
+### Dependencies
+* Linux
+* python3.6+
+* paddlepaddle >= 2.1.0
+* parl >= 2.0.0
+
+### Computing resource requirements
+* 1 GPU + 6 CPUs
+
+### Training
+
+1. Download the pretrained model (trained on the fixed first 288 timesteps of data) to the current directory (filename: `paddle_pretrain_model`).
+
+[Baidu Pan](https://pan.baidu.com/s/1R-4EWIgNr2YogbJnMXk4Cg) (password: hwkb)
+
+2. Copy all files of `gridsim` (the competition package) to the current directory.
+```bash
+# For example:
+cp -r /XXX/gridsim/* .
+```
+
+3. Update the data path for distributed training (use an absolute path).
+```bash
+export PWD=`pwd`
+python yml_creator.py --dataset_path $PWD/data
+```
+
+
+4. Set the environment variables for PARL and gridsim.
+```bash
+export PARL_BACKEND=paddle
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib64
+```
+
+5. Start the xparl cluster.
+
+```bash
+# You can change `cpu_num` below and `args.actor_num` in train.py based on the number of CPUs on your machine.
+# Note that you only need to start the cluster once.
+
+xparl start --port 8010 --cpu_num 6
+```
+
+6. Start training.
+
+```bash
+python train.py --actor_num 6
+```
+
+7. Visualize the training curve and other information.
+```bash
+tensorboard --logdir .
+```
+
+### Performance
+The result after training for one hour with 1 GPU and 6 CPUs.
+![learning curve](https://raw.githubusercontent.com/benchmarking-rl/PARL-experiments/master/Baselines/GridDispatch_competition/paddle/result.png)
diff --git a/examples/Baselines/GridDispatch_competition/paddle/env_wrapper.py b/examples/Baselines/GridDispatch_competition/paddle/env_wrapper.py
new file mode 100644
index 000000000..c44d2be0a
--- /dev/null
+++ b/examples/Baselines/GridDispatch_competition/paddle/env_wrapper.py
@@ -0,0 +1,153 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
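+# Wrapper chain built by get_env() below (summary of this module):
+#   MaxTimestepWrapper    -- ends an episode after 288 steps and sets info["timeout"]
+#   RewardShapingWrapper  -- returns a constant +1 survival reward and keeps the raw
+#                            environment reward in info["origin_reward"]
+#   ObsTransformerWrapper -- concatenates loads, generator outputs, line loading
+#                            (rho - 1), the next-step load forecast and the
+#                            adjust_gen_p bounds into one flat feature vector
+#   ActionWrapper         -- rescales actions from [-1, 1] to the adjust_gen_p bounds,
+#                            pinning the balanced generator to 0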
+ +import gym +import numpy as np +from parl.utils import logger +from Environment.base_env import Environment +from utilize.settings import settings +from utilize.form_action import * + + +class MaxTimestepWrapper(gym.Wrapper): + def __init__(self, env, max_timestep=288): + logger.info("[env type]:{}".format(type(env))) + self.max_timestep = max_timestep + env.observation_space = None + env.reward_range = None + env.metadata = None + gym.Wrapper.__init__(self, env) + + self.timestep = 0 + + def step(self, action, **kwargs): + self.timestep += 1 + obs, reward, done, info = self.env.step(action, **kwargs) + if self.timestep >= self.max_timestep: + done = True + info["timeout"] = True + else: + info["timeout"] = False + return obs, reward, done, info + + def reset(self, **kwargs): + self.timestep = 0 + return self.env.reset(**kwargs) + + +class ObsTransformerWrapper(gym.Wrapper): + def __init__(self, env): + logger.info("[env type]:{}".format(type(env))) + gym.Wrapper.__init__(self, env) + + def _get_obs(self, obs): + # loads + loads = [] + loads.append(obs.load_p) + loads.append(obs.load_q) + loads.append(obs.load_v) + loads = np.concatenate(loads) + + # prods + prods = [] + prods.append(obs.gen_p) + prods.append(obs.gen_q) + prods.append(obs.gen_v) + prods = np.concatenate(prods) + + # rho + rho = np.array(obs.rho) - 1.0 + + next_load = obs.nextstep_load_p + + # action_space + action_space_low = obs.action_space['adjust_gen_p'].low.tolist() + action_space_high = obs.action_space['adjust_gen_p'].high.tolist() + action_space_low[settings.balanced_id] = 0.0 + action_space_high[settings.balanced_id] = 0.0 + + features = np.concatenate([ + loads, prods, + rho.tolist(), next_load, action_space_low, action_space_high + ]) + + return features + + def step(self, action, **kwargs): + self.raw_obs, reward, done, info = self.env.step(action, **kwargs) + obs = self._get_obs(self.raw_obs) + return obs, reward, done, info + + def reset(self, **kwargs): + self.raw_obs = self.env.reset(**kwargs) + obs = self._get_obs(self.raw_obs) + return obs + + +class RewardShapingWrapper(gym.Wrapper): + def __init__(self, env): + logger.info("[env type]:{}".format(type(env))) + gym.Wrapper.__init__(self, env) + + def step(self, action, **kwargs): + obs, reward, done, info = self.env.step(action, **kwargs) + + shaping_reward = 1.0 + + info["origin_reward"] = reward + + return obs, shaping_reward, done, info + + def reset(self, **kwargs): + return self.env.reset(**kwargs) + + +class ActionWrapper(gym.Wrapper): + def __init__(self, env, raw_env): + logger.info("[env type]:{}".format(type(env))) + gym.Wrapper.__init__(self, env) + self.raw_env = raw_env + self.v_action = np.zeros(self.raw_env.settings.num_gen) + + def step(self, action, **kwargs): + N = len(action) + + gen_p_action_space = self.env.raw_obs.action_space['adjust_gen_p'] + + low_bound = gen_p_action_space.low + high_bound = gen_p_action_space.high + + mapped_action = low_bound + (action - (-1.0)) * ( + (high_bound - low_bound) / 2.0) + mapped_action[self.raw_env.settings.balanced_id] = 0.0 + mapped_action = np.clip(mapped_action, low_bound, high_bound) + + ret_action = form_action(mapped_action, self.v_action) + return self.env.step(ret_action, **kwargs) + + def reset(self, **kwargs): + return self.env.reset(**kwargs) + + +def get_env(): + env = Environment(settings, "EPRIReward") + env.action_space = None + raw_env = env + + env = MaxTimestepWrapper(env) + env = RewardShapingWrapper(env) + env = ObsTransformerWrapper(env) + env = ActionWrapper(env, raw_env) 
+ + return env diff --git a/examples/Baselines/GridDispatch_competition/paddle/grid_agent.py b/examples/Baselines/GridDispatch_competition/paddle/grid_agent.py new file mode 100644 index 000000000..fa4010244 --- /dev/null +++ b/examples/Baselines/GridDispatch_competition/paddle/grid_agent.py @@ -0,0 +1,49 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import parl +import paddle +import numpy as np + + +class GridAgent(parl.Agent): + def __init__(self, algorithm): + super(GridAgent, self).__init__(algorithm) + + self.alg.sync_target(decay=0) + + def predict(self, obs): + obs = paddle.to_tensor(obs.reshape(1, -1), dtype='float32') + action = self.alg.predict(obs) + action_numpy = action.cpu().numpy()[0] + return action_numpy + + def sample(self, obs): + obs = paddle.to_tensor(obs.reshape(1, -1), dtype='float32') + action, _ = self.alg.sample(obs) + action_numpy = action.cpu().numpy()[0] + return action_numpy + + def learn(self, obs, action, reward, next_obs, terminal): + terminal = np.expand_dims(terminal, -1) + reward = np.expand_dims(reward, -1) + + obs = paddle.to_tensor(obs, dtype='float32') + action = paddle.to_tensor(action, dtype='float32') + reward = paddle.to_tensor(reward, dtype='float32') + next_obs = paddle.to_tensor(next_obs, dtype='float32') + terminal = paddle.to_tensor(terminal, dtype='float32') + critic_loss, actor_loss = self.alg.learn(obs, action, reward, next_obs, + terminal) + return critic_loss.cpu().numpy()[0], actor_loss.cpu().numpy()[0] diff --git a/examples/Baselines/GridDispatch_competition/paddle/grid_model.py b/examples/Baselines/GridDispatch_competition/paddle/grid_model.py new file mode 100644 index 000000000..10cc4537d --- /dev/null +++ b/examples/Baselines/GridDispatch_competition/paddle/grid_model.py @@ -0,0 +1,89 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
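+# GridModel bundles the Gaussian policy (Actor: action mean plus a clamped log-std)
+# and the twin Q-networks (Critic) expected by the SAC algorithm.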
+ +import parl +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +# clamp bounds for Std of action_log +LOG_SIG_MAX = 2.0 +LOG_SIG_MIN = -20.0 + + +class GridModel(parl.Model): + def __init__(self, obs_dim, action_dim): + super(GridModel, self).__init__() + self.actor_model = Actor(obs_dim, action_dim) + self.critic_model = Critic(obs_dim, action_dim) + + def policy(self, obs): + return self.actor_model(obs) + + def value(self, obs, action): + return self.critic_model(obs, action) + + def get_actor_params(self): + return self.actor_model.parameters() + + def get_critic_params(self): + return self.critic_model.parameters() + + +class Actor(parl.Model): + def __init__(self, obs_dim, action_dim): + super(Actor, self).__init__() + + self.l1 = nn.Linear(obs_dim, 512) + self.l2 = nn.Linear(512, 256) + self.mean_linear = nn.Linear(256, action_dim) + self.std_linear = nn.Linear(256, action_dim) + + def forward(self, obs): + x = F.relu(self.l1(obs)) + x = F.relu(self.l2(x)) + + act_mean = self.mean_linear(x) + act_std = self.std_linear(x) + act_log_std = paddle.clip(act_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX) + return act_mean, act_log_std + + +class Critic(parl.Model): + def __init__(self, obs_dim, action_dim): + super(Critic, self).__init__() + + # Q1 network + self.l1 = nn.Linear(obs_dim + action_dim, 512) + self.l2 = nn.Linear(512, 256) + self.l3 = nn.Linear(256, 1) + + # Q2 network + self.l4 = nn.Linear(obs_dim + action_dim, 512) + self.l5 = nn.Linear(512, 256) + self.l6 = nn.Linear(256, 1) + + def forward(self, obs, action): + x = paddle.concat([obs, action], 1) + + # Q1 + q1 = F.relu(self.l1(x)) + q1 = F.relu(self.l2(q1)) + q1 = self.l3(q1) + + # Q2 + q2 = F.relu(self.l4(x)) + q2 = F.relu(self.l5(q2)) + q2 = self.l6(q2) + return q1, q2 diff --git a/examples/Baselines/GridDispatch_competition/paddle/train.py b/examples/Baselines/GridDispatch_competition/paddle/train.py new file mode 100644 index 000000000..763487c4d --- /dev/null +++ b/examples/Baselines/GridDispatch_competition/paddle/train.py @@ -0,0 +1,221 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
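+# Distributed training layout:
+#   - Learner (this process) holds the replay memory and the SAC agent, restores the
+#     pretrained model, and starts one sampling thread per actor (args.actor_num).
+#   - Actor is a @parl.remote_class worker running on the xparl cluster; each sampling
+#     thread sends it the latest weights, collects one episode, appends the transitions
+#     to the replay memory and, once WARMUP_STEPS transitions are stored, runs one SAC
+#     update per collected transition.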
+ +import os +os.environ['PARL_BACKEND'] = 'paddle' + +import numpy as np +import argparse +import threading +import time +import parl +from parl.utils import logger, tensorboard, ReplayMemory +from grid_model import GridModel +from grid_agent import GridAgent +from parl.algorithms import SAC +from env_wrapper import get_env + +WARMUP_STEPS = 1e4 +MEMORY_SIZE = int(1e6) +BATCH_SIZE = 256 +GAMMA = 0.99 +TAU = 0.005 +ACTOR_LR = 3e-4 +CRITIC_LR = 3e-4 +OBS_DIM = 819 +ACT_DIM = 54 + + +@parl.remote_class +class Actor(object): + def __init__(self, args): + self.env = get_env() + + obs_dim = OBS_DIM + action_dim = ACT_DIM + self.action_dim = action_dim + + # Initialize model, algorithm, agent, replay_memory + model = GridModel(obs_dim, action_dim) + algorithm = SAC( + model, + gamma=GAMMA, + tau=TAU, + alpha=args.alpha, + actor_lr=ACTOR_LR, + critic_lr=CRITIC_LR) + self.agent = GridAgent(algorithm) + + def sample(self, weights, random_action): + self.agent.set_weights(weights) + + obs = self.env.reset() + + done = False + episode_reward, episode_steps = 0, 0 + sample_data = [] + while not done: + # Select action randomly or according to policy + if random_action: + action = np.random.uniform(-1, 1, size=self.action_dim) + else: + action = self.agent.sample(obs) + + # Perform action + next_obs, reward, done, info = self.env.step(action) + terminal = done and not info['timeout'] + terminal = float(terminal) + + sample_data.append((obs, action, reward, next_obs, terminal)) + + obs = next_obs + episode_reward += info['origin_reward'] + episode_steps += 1 + + return sample_data, episode_steps, episode_reward + + +class Learner(object): + def __init__(self, args): + self.model_lock = threading.Lock() + self.rpm_lock = threading.Lock() + self.log_lock = threading.Lock() + + self.args = args + + obs_dim = OBS_DIM + action_dim = ACT_DIM + + # Initialize model, algorithm, agent, replay_memory + model = GridModel(obs_dim, action_dim) + algorithm = SAC( + model, + gamma=GAMMA, + tau=TAU, + alpha=args.alpha, + actor_lr=ACTOR_LR, + critic_lr=CRITIC_LR) + self.agent = GridAgent(algorithm) + + self.agent.restore("./paddle_pretrain_model") + + self.rpm = ReplayMemory( + max_size=MEMORY_SIZE, obs_dim=obs_dim, act_dim=action_dim) + + self.total_steps = 0 + self.total_MDP_steps = 0 + self.save_cnt = 0 + + parl.connect( + args.xparl_addr, + distributed_files=[ + 'lib64/*', + 'model_jm/*', + 'Agent/*', + 'Environment/*', + 'Observation/*', + 'Reward/*', + 'utilize/*', + ]) + for _ in range(args.actor_num): + th = threading.Thread(target=self.run_sampling) + th.setDaemon(True) + th.start() + + def run_sampling(self): + actor = Actor(self.args) + while True: + start = time.time() + weights = None + with self.model_lock: + weights = self.agent.get_weights() + + random_action = False + if self.rpm.size() < WARMUP_STEPS: + random_action = True + + sample_data, episode_steps, episode_reward = actor.sample( + weights, random_action) + + # Store data in replay memory + with self.rpm_lock: + for data in sample_data: + self.rpm.append(*data) + + sample_time = time.time() - start + start = time.time() + + critic_loss, actor_loss = None, None + # Train agent after collecting sufficient data + if self.rpm.size() >= WARMUP_STEPS: + for _ in range(len(sample_data)): + with self.rpm_lock: + batch_obs, batch_action, batch_reward, batch_next_obs, batch_terminal = self.rpm.sample_batch( + BATCH_SIZE) + with self.model_lock: + critic_loss, actor_loss = self.agent.learn( + batch_obs, batch_action, batch_reward, + batch_next_obs, 
batch_terminal)
+            learn_time = time.time() - start
+
+            mean_action = np.mean(
+                np.array([x[1] for x in sample_data]), axis=0)
+
+            with self.log_lock:
+                self.total_steps += episode_steps
+                self.total_MDP_steps += len(sample_data)
+                tensorboard.add_scalar('train/episode_reward', episode_reward,
+                                       self.total_steps)
+                tensorboard.add_scalar('train/episode_steps', episode_steps,
+                                       self.total_steps)
+                if critic_loss is not None:
+                    tensorboard.add_scalar('train/critic_loss', critic_loss,
+                                           self.total_steps)
+                    tensorboard.add_scalar('train/actor_loss', actor_loss,
+                                           self.total_steps)
+                logger.info('Total Steps: {} Reward: {} Steps: {}'.format(
+                    self.total_steps, episode_reward, episode_steps))
+
+            if self.total_steps // self.args.save_every_steps >= self.save_cnt:
+                while self.total_steps // self.args.save_every_steps >= self.save_cnt:
+                    self.save_cnt += 1
+                with self.model_lock:
+                    self.agent.save(
+                        os.path.join(self.args.save_dir,
+                                     "model-{}".format(self.total_steps)))
+
+
+def main():
+    learner = Learner(args)
+
+    while True:
+        time.sleep(1)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--save_every_steps', type=int, default=10000)
+    parser.add_argument(
+        "--alpha",
+        default=0.2,
+        type=float,
+        help=
+        'Determines the relative importance of entropy term against the reward'
+    )
+    parser.add_argument('--xparl_addr', type=str, default="localhost:8010")
+    parser.add_argument('--actor_num', type=int, default=1)
+    parser.add_argument('--save_dir', type=str, default="./saved_models")
+    args = parser.parse_args()
+
+    main()
diff --git a/examples/Baselines/GridDispatch_competition/torch/README.md b/examples/Baselines/GridDispatch_competition/torch/README.md
new file mode 100644
index 000000000..626983ae3
--- /dev/null
+++ b/examples/Baselines/GridDispatch_competition/torch/README.md
@@ -0,0 +1,61 @@
+## SAC baseline for grid dispatching competition
+
+In this example, we provide a distributed SAC baseline based on PARL and torch for the [grid dispatching competition](https://aistudio.baidu.com/aistudio/competition/detail/111) task.
+
+### Dependencies
+* Linux
+* python3.6+
+* torch == 1.6.0
+* parl >= 2.0.0
+
+### Computing resource requirements
+* 1 GPU + 6 CPUs
+
+### Training
+
+1. Download the pretrained model (trained on the fixed first 288 timesteps of data) to the current directory (filename: `torch_pretrain_model`).
+
+[Baidu Pan](https://pan.baidu.com/s/1Pqv9i9byOzqStcHdttOlRA) (password: n9qc)
+
+2. Copy all files of `gridsim` (the competition package) to the current directory.
+```bash
+# For example:
+cp -r /XXX/gridsim/* .
+```
+
+3. Update the data path for distributed training (use an absolute path).
+```bash
+export PWD=`pwd`
+python yml_creator.py --dataset_path $PWD/data
+```
+
+
+4. Set the environment variables for PARL and gridsim.
+```bash
+export PARL_BACKEND=torch
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib64
+```
+
+5. Start the xparl cluster.
+
+```bash
+# You can change `cpu_num` below and `args.actor_num` in train.py based on the number of CPUs on your machine.
+# Note that you only need to start the cluster once.
+
+xparl start --port 8010 --cpu_num 6
+```
+
+6. Start training.
+
+```bash
+python train.py --actor_num 6
+```
+
+7. Visualize the training curve and other information.
+```bash
+tensorboard --logdir .
+```
+
+### Performance
+The result after training for one hour with 1 GPU and 6 CPUs.
+![learning curve](https://raw.githubusercontent.com/benchmarking-rl/PARL-experiments/master/Baselines/GridDispatch_competition/torch/result.png) diff --git a/examples/Baselines/GridDispatch_competition/torch/env_wrapper.py b/examples/Baselines/GridDispatch_competition/torch/env_wrapper.py new file mode 100644 index 000000000..c44d2be0a --- /dev/null +++ b/examples/Baselines/GridDispatch_competition/torch/env_wrapper.py @@ -0,0 +1,153 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gym +import numpy as np +from parl.utils import logger +from Environment.base_env import Environment +from utilize.settings import settings +from utilize.form_action import * + + +class MaxTimestepWrapper(gym.Wrapper): + def __init__(self, env, max_timestep=288): + logger.info("[env type]:{}".format(type(env))) + self.max_timestep = max_timestep + env.observation_space = None + env.reward_range = None + env.metadata = None + gym.Wrapper.__init__(self, env) + + self.timestep = 0 + + def step(self, action, **kwargs): + self.timestep += 1 + obs, reward, done, info = self.env.step(action, **kwargs) + if self.timestep >= self.max_timestep: + done = True + info["timeout"] = True + else: + info["timeout"] = False + return obs, reward, done, info + + def reset(self, **kwargs): + self.timestep = 0 + return self.env.reset(**kwargs) + + +class ObsTransformerWrapper(gym.Wrapper): + def __init__(self, env): + logger.info("[env type]:{}".format(type(env))) + gym.Wrapper.__init__(self, env) + + def _get_obs(self, obs): + # loads + loads = [] + loads.append(obs.load_p) + loads.append(obs.load_q) + loads.append(obs.load_v) + loads = np.concatenate(loads) + + # prods + prods = [] + prods.append(obs.gen_p) + prods.append(obs.gen_q) + prods.append(obs.gen_v) + prods = np.concatenate(prods) + + # rho + rho = np.array(obs.rho) - 1.0 + + next_load = obs.nextstep_load_p + + # action_space + action_space_low = obs.action_space['adjust_gen_p'].low.tolist() + action_space_high = obs.action_space['adjust_gen_p'].high.tolist() + action_space_low[settings.balanced_id] = 0.0 + action_space_high[settings.balanced_id] = 0.0 + + features = np.concatenate([ + loads, prods, + rho.tolist(), next_load, action_space_low, action_space_high + ]) + + return features + + def step(self, action, **kwargs): + self.raw_obs, reward, done, info = self.env.step(action, **kwargs) + obs = self._get_obs(self.raw_obs) + return obs, reward, done, info + + def reset(self, **kwargs): + self.raw_obs = self.env.reset(**kwargs) + obs = self._get_obs(self.raw_obs) + return obs + + +class RewardShapingWrapper(gym.Wrapper): + def __init__(self, env): + logger.info("[env type]:{}".format(type(env))) + gym.Wrapper.__init__(self, env) + + def step(self, action, **kwargs): + obs, reward, done, info = self.env.step(action, **kwargs) + + shaping_reward = 1.0 + + info["origin_reward"] = reward + + return obs, shaping_reward, done, info + + def reset(self, **kwargs): + return self.env.reset(**kwargs) + + 
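+# ActionWrapper rescales each generator's policy action from [-1, 1] to the current
+# adjust_gen_p bounds, pins the balanced generator to 0, and submits a zero
+# adjust_gen_v via form_action.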
+class ActionWrapper(gym.Wrapper): + def __init__(self, env, raw_env): + logger.info("[env type]:{}".format(type(env))) + gym.Wrapper.__init__(self, env) + self.raw_env = raw_env + self.v_action = np.zeros(self.raw_env.settings.num_gen) + + def step(self, action, **kwargs): + N = len(action) + + gen_p_action_space = self.env.raw_obs.action_space['adjust_gen_p'] + + low_bound = gen_p_action_space.low + high_bound = gen_p_action_space.high + + mapped_action = low_bound + (action - (-1.0)) * ( + (high_bound - low_bound) / 2.0) + mapped_action[self.raw_env.settings.balanced_id] = 0.0 + mapped_action = np.clip(mapped_action, low_bound, high_bound) + + ret_action = form_action(mapped_action, self.v_action) + return self.env.step(ret_action, **kwargs) + + def reset(self, **kwargs): + return self.env.reset(**kwargs) + + +def get_env(): + env = Environment(settings, "EPRIReward") + env.action_space = None + raw_env = env + + env = MaxTimestepWrapper(env) + env = RewardShapingWrapper(env) + env = ObsTransformerWrapper(env) + env = ActionWrapper(env, raw_env) + + return env diff --git a/examples/Baselines/GridDispatch_competition/torch/grid_agent.py b/examples/Baselines/GridDispatch_competition/torch/grid_agent.py new file mode 100644 index 000000000..3da36f1d0 --- /dev/null +++ b/examples/Baselines/GridDispatch_competition/torch/grid_agent.py @@ -0,0 +1,52 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import parl +import torch +import numpy as np + + +class GridAgent(parl.Agent): + def __init__(self, algorithm): + super(GridAgent, self).__init__(algorithm) + + self.device = torch.device("cuda" if torch.cuda. + is_available() else "cpu") + + self.alg.sync_target(decay=0) + + def predict(self, obs): + obs = torch.FloatTensor(obs.reshape(1, -1)).to(self.device) + action = self.alg.predict(obs) + action_numpy = action.cpu().detach().numpy().flatten() + return action_numpy + + def sample(self, obs): + obs = torch.FloatTensor(obs.reshape(1, -1)).to(self.device) + action, _ = self.alg.sample(obs) + action_numpy = action.cpu().detach().numpy().flatten() + return action_numpy + + def learn(self, obs, action, reward, next_obs, terminal): + terminal = np.expand_dims(terminal, -1) + reward = np.expand_dims(reward, -1) + + obs = torch.FloatTensor(obs).to(self.device) + action = torch.FloatTensor(action).to(self.device) + reward = torch.FloatTensor(reward).to(self.device) + next_obs = torch.FloatTensor(next_obs).to(self.device) + terminal = torch.FloatTensor(terminal).to(self.device) + critic_loss, actor_loss = self.alg.learn(obs, action, reward, next_obs, + terminal) + return critic_loss, actor_loss diff --git a/examples/Baselines/GridDispatch_competition/torch/grid_model.py b/examples/Baselines/GridDispatch_competition/torch/grid_model.py new file mode 100644 index 000000000..144fc8b82 --- /dev/null +++ b/examples/Baselines/GridDispatch_competition/torch/grid_model.py @@ -0,0 +1,89 @@ +# Copyright (c) 2021 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import parl +import torch +import torch.nn as nn +import torch.nn.functional as F + +# clamp bounds for Std of action_log +LOG_SIG_MAX = 2.0 +LOG_SIG_MIN = -20.0 + + +class GridModel(parl.Model): + def __init__(self, obs_dim, action_dim): + super(GridModel, self).__init__() + self.actor_model = Actor(obs_dim, action_dim) + self.critic_model = Critic(obs_dim, action_dim) + + def policy(self, obs): + return self.actor_model(obs) + + def value(self, obs, action): + return self.critic_model(obs, action) + + def get_actor_params(self): + return self.actor_model.parameters() + + def get_critic_params(self): + return self.critic_model.parameters() + + +class Actor(parl.Model): + def __init__(self, obs_dim, action_dim): + super(Actor, self).__init__() + + self.l1 = nn.Linear(obs_dim, 512) + self.l2 = nn.Linear(512, 256) + self.mean_linear = nn.Linear(256, action_dim) + self.std_linear = nn.Linear(256, action_dim) + + def forward(self, obs): + x = F.relu(self.l1(obs)) + x = F.relu(self.l2(x)) + + act_mean = self.mean_linear(x) + act_std = self.std_linear(x) + act_log_std = torch.clamp(act_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX) + return act_mean, act_log_std + + +class Critic(parl.Model): + def __init__(self, obs_dim, action_dim): + super(Critic, self).__init__() + + # Q1 network + self.l1 = nn.Linear(obs_dim + action_dim, 512) + self.l2 = nn.Linear(512, 256) + self.l3 = nn.Linear(256, 1) + + # Q2 network + self.l4 = nn.Linear(obs_dim + action_dim, 512) + self.l5 = nn.Linear(512, 256) + self.l6 = nn.Linear(256, 1) + + def forward(self, obs, action): + x = torch.cat([obs, action], 1) + + # Q1 + q1 = F.relu(self.l1(x)) + q1 = F.relu(self.l2(q1)) + q1 = self.l3(q1) + + # Q2 + q2 = F.relu(self.l4(x)) + q2 = F.relu(self.l5(q2)) + q2 = self.l6(q2) + return q1, q2 diff --git a/examples/Baselines/GridDispatch_competition/torch/train.py b/examples/Baselines/GridDispatch_competition/torch/train.py new file mode 100644 index 000000000..6058997d4 --- /dev/null +++ b/examples/Baselines/GridDispatch_competition/torch/train.py @@ -0,0 +1,221 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +os.environ['PARL_BACKEND'] = 'torch' + +import numpy as np +import argparse +import threading +import time +import parl +from parl.utils import logger, tensorboard, ReplayMemory +from grid_model import GridModel +from grid_agent import GridAgent +from parl.algorithms import SAC +from env_wrapper import get_env + +WARMUP_STEPS = 1e4 +MEMORY_SIZE = int(1e6) +BATCH_SIZE = 256 +GAMMA = 0.99 +TAU = 0.005 +ACTOR_LR = 3e-4 +CRITIC_LR = 3e-4 +OBS_DIM = 819 +ACT_DIM = 54 + + +@parl.remote_class +class Actor(object): + def __init__(self, args): + self.env = get_env() + + obs_dim = OBS_DIM + action_dim = ACT_DIM + self.action_dim = action_dim + + # Initialize model, algorithm, agent, replay_memory + model = GridModel(obs_dim, action_dim) + algorithm = SAC( + model, + gamma=GAMMA, + tau=TAU, + alpha=args.alpha, + actor_lr=ACTOR_LR, + critic_lr=CRITIC_LR) + self.agent = GridAgent(algorithm) + + def sample(self, weights, random_action): + self.agent.set_weights(weights) + + obs = self.env.reset() + + done = False + episode_reward, episode_steps = 0, 0 + sample_data = [] + while not done: + # Select action randomly or according to policy + if random_action: + action = np.random.uniform(-1, 1, size=self.action_dim) + else: + action = self.agent.sample(obs) + + # Perform action + next_obs, reward, done, info = self.env.step(action) + terminal = done and not info['timeout'] + terminal = float(terminal) + + sample_data.append((obs, action, reward, next_obs, terminal)) + + obs = next_obs + episode_reward += info['origin_reward'] + episode_steps += 1 + + return sample_data, episode_steps, episode_reward + + +class Learner(object): + def __init__(self, args): + self.model_lock = threading.Lock() + self.rpm_lock = threading.Lock() + self.log_lock = threading.Lock() + + self.args = args + + obs_dim = OBS_DIM + action_dim = ACT_DIM + + # Initialize model, algorithm, agent, replay_memory + model = GridModel(obs_dim, action_dim) + algorithm = SAC( + model, + gamma=GAMMA, + tau=TAU, + alpha=args.alpha, + actor_lr=ACTOR_LR, + critic_lr=CRITIC_LR) + self.agent = GridAgent(algorithm) + + self.agent.restore("./torch_pretrain_model") + + self.rpm = ReplayMemory( + max_size=MEMORY_SIZE, obs_dim=obs_dim, act_dim=action_dim) + + self.total_steps = 0 + self.total_MDP_steps = 0 + self.save_cnt = 0 + + parl.connect( + args.xparl_addr, + distributed_files=[ + 'lib64/*', + 'model_jm/*', + 'Agent/*', + 'Environment/*', + 'Observation/*', + 'Reward/*', + 'utilize/*', + ]) + for _ in range(args.actor_num): + th = threading.Thread(target=self.run_sampling) + th.setDaemon(True) + th.start() + + def run_sampling(self): + actor = Actor(self.args) + while True: + start = time.time() + weights = None + with self.model_lock: + weights = self.agent.get_weights() + + random_action = False + if self.rpm.size() < WARMUP_STEPS: + random_action = True + + sample_data, episode_steps, episode_reward = actor.sample( + weights, random_action) + + # Store data in replay memory + with self.rpm_lock: + for data in sample_data: + self.rpm.append(*data) + + sample_time = time.time() - start + start = time.time() + + critic_loss, actor_loss = None, None + # Train agent after collecting sufficient data + if self.rpm.size() >= WARMUP_STEPS: + for _ in range(len(sample_data)): + with self.rpm_lock: + batch_obs, batch_action, batch_reward, batch_next_obs, batch_terminal = self.rpm.sample_batch( + BATCH_SIZE) + with self.model_lock: + critic_loss, actor_loss = self.agent.learn( + batch_obs, batch_action, batch_reward, + batch_next_obs, 
batch_terminal) + learn_time = time.time() - start + + mean_action = np.mean( + np.array([x[1] for x in sample_data]), axis=0) + + with self.log_lock: + self.total_steps += episode_steps + self.total_MDP_steps += len(sample_data) + tensorboard.add_scalar('train/episode_reward', episode_reward, + self.total_steps) + tensorboard.add_scalar('train/episode_steps', episode_steps, + self.total_steps) + if critic_loss is not None: + tensorboard.add_scalar('train/critic_loss', critic_loss, + self.total_steps) + tensorboard.add_scalar('train/actor_loss', actor_loss, + self.total_steps) + logger.info('Total Steps: {} Reward: {} Steps: {}'.format( + self.total_steps, episode_reward, episode_steps)) + + if self.total_steps // self.args.save_every_steps >= self.save_cnt: + while self.total_steps // self.args.save_every_steps >= self.save_cnt: + self.save_cnt += 1 + with self.model_lock: + self.agent.save( + os.path.join(self.args.save_dir, + "model-{}".format(self.total_steps))) + + +def main(): + learner = Learner(args) + + while True: + time.sleep(1) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--save_every_steps', type=int, default=10000) + parser.add_argument( + "--alpha", + default=0.2, + type=float, + help= + 'Determines the relative importance of entropy term against the reward' + ) + parser.add_argument('--xparl_addr', type=str, default="localhost:8010") + parser.add_argument('--actor_num', type=int, default=1) + parser.add_argument('--save_dir', type=str, default="./saved_models") + args = parser.parse_args() + + main()
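Not part of the diff above: for a quick local sanity check after training, a saved checkpoint can be rolled out with the modules this PR adds. The snippet below is a minimal sketch, assuming it is run from inside either the `paddle` or `torch` directory after completing the README setup (gridsim files copied, `PARL_BACKEND` and `LD_LIBRARY_PATH` exported). The script name and the checkpoint path `./saved_models/model-10000` are placeholders; the SAC hyperparameters simply mirror the constants in `train.py`.

```python
# evaluate.py -- minimal local rollout of a saved checkpoint (sketch, not part of the PR)
from parl.algorithms import SAC

from env_wrapper import get_env
from grid_agent import GridAgent
from grid_model import GridModel

OBS_DIM, ACT_DIM = 819, 54  # same dimensions as in train.py

model = GridModel(OBS_DIM, ACT_DIM)
algorithm = SAC(model, gamma=0.99, tau=0.005, alpha=0.2,
                actor_lr=3e-4, critic_lr=3e-4)
agent = GridAgent(algorithm)
agent.restore("./saved_models/model-10000")  # placeholder checkpoint name

env = get_env()
obs = env.reset()
done, episode_reward, steps = False, 0.0, 0
while not done:
    action = agent.predict(obs)  # deterministic action in [-1, 1]
    obs, reward, done, info = env.step(action)
    episode_reward += info["origin_reward"]  # raw EPRI reward, not the shaped +1
    steps += 1
print("steps: {}, episode reward: {}".format(steps, episode_reward))
```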