Skip to content

Commit

Permalink
Merge pull request #2 from thu-ml/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
ChenDRAG authored Jul 17, 2020
2 parents db0e2e5 + f8ad6df commit 979dc2d
Show file tree
Hide file tree
Showing 47 changed files with 1,295 additions and 993 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ on: [push, pull_request]
jobs:
build:
runs-on: ubuntu-latest
if: "!contains(github.event.head_commit.message, 'ci skip')"
strategy:
matrix:
python-version: [3.6, 3.7, 3.8]
Expand Down
18 changes: 2 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,26 +206,12 @@ test_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(test_num)])
Define the network:

```python
class Net(nn.Module):
    """Simple MLP that maps a flattened observation to action logits.

    The network has three hidden layers of 128 units, each followed by ReLU.

    :param state_shape: shape of one observation (an int or a tuple of ints);
        the product of its entries is the input width.
    :param action_shape: shape of the action space (an int or a tuple of
        ints); the product of its entries is the output width.
    """

    def __init__(self, state_shape, action_shape):
        super().__init__()
        # np.prod returns a NumPy scalar; cast to a built-in int so the
        # layer sizes are plain Python integers.
        in_dim = int(np.prod(state_shape))
        out_dim = int(np.prod(action_shape))
        self.model = nn.Sequential(
            nn.Linear(in_dim, 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, out_dim),
        )

    def forward(self, s, state=None, info=None):
        """Compute logits for a batch of observations.

        :param s: batch of observations, either a ``torch.Tensor`` or
            anything ``torch.tensor`` can convert (e.g. a ``numpy.ndarray``);
            the first axis is treated as the batch axis.
        :param state: hidden state, returned unchanged.
        :param info: extra information from the environment; unused here.
            Defaults to ``None`` instead of a shared mutable ``{}``.
        :return: a tuple ``(logits, state)``.
        """
        if not isinstance(s, torch.Tensor):
            s = torch.tensor(s, dtype=torch.float)
        batch = s.shape[0]
        # Flatten everything except the batch axis before the first layer.
        logits = self.model(s.view(batch, -1))
        return logits, state
from tianshou.utils.net.common import Net

env = gym.make(task)
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape, action_shape)
net = Net(layer_num=2, state_shape=state_shape, action_shape=action_shape)
optim = torch.optim.Adam(net.parameters(), lr=lr)
```

Expand Down
15 changes: 15 additions & 0 deletions docs/api/tianshou.utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,18 @@ tianshou.utils
:members:
:undoc-members:
:show-inheritance:

.. automodule:: tianshou.utils.net.common
:members:
:undoc-members:
:show-inheritance:

.. automodule:: tianshou.utils.net.discrete
:members:
:undoc-members:
:show-inheritance:

.. automodule:: tianshou.utils.net.continuous
:members:
:undoc-members:
:show-inheritance:
2 changes: 1 addition & 1 deletion docs/tutorials/dqn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ Tianshou supports any user-defined PyTorch networks and optimizers but with the
net = Net(state_shape, action_shape)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

The rules of self-defined networks are:
You can also try the pre-defined networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are:

1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment.
2. Output: some ``logits``, the next hidden state ``state``, and the intermediate result during the policy forwarding procedure ``policy``. The logits could be a tuple instead of a ``torch.Tensor``; it depends on how the policy processes the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy. The ``policy`` can be a Batch of torch.Tensor or other things, which will be stored in the replay buffer, and can be accessed in the policy update process (e.g. in ``policy.learn()``, the ``batch.policy`` is what you need).
Expand Down
17 changes: 8 additions & 9 deletions examples/ant_v2_ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv
from tianshou.exploration import GaussianNoise

from continuous_net import Actor, Critic
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic


def get_args():
Expand Down Expand Up @@ -57,14 +57,13 @@ def test_ddpg(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
actor = Actor(
args.layer_num, args.state_shape, args.action_shape,
args.max_action, args.device
).to(args.device)
net = Net(args.layer_num, args.state_shape, device=args.device)
actor = Actor(net, args.action_shape, args.max_action,
args.device).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
critic = Critic(
args.layer_num, args.state_shape, args.action_shape, args.device
).to(args.device)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic = Critic(net, args.device).to(args.device)
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
policy = DDPGPolicy(
actor, actor_optim, critic, critic_optim,
Expand Down
17 changes: 8 additions & 9 deletions examples/ant_v2_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv

from continuous_net import ActorProb, Critic
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic


def get_args():
Expand Down Expand Up @@ -58,18 +58,17 @@ def test_sac(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
actor = ActorProb(
args.layer_num, args.state_shape, args.action_shape,
net, args.action_shape,
args.max_action, args.device, unbounded=True
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
critic1 = Critic(
args.layer_num, args.state_shape, args.action_shape, args.device
).to(args.device)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net, args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(
args.layer_num, args.state_shape, args.action_shape, args.device
).to(args.device)
critic2 = Critic(net, args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
Expand Down
17 changes: 8 additions & 9 deletions examples/ant_v2_td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv
from tianshou.exploration import GaussianNoise

from continuous_net import Actor, Critic
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic


def get_args():
Expand Down Expand Up @@ -60,18 +60,17 @@ def test_td3(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
actor = Actor(
args.layer_num, args.state_shape, args.action_shape,
net, args.action_shape,
args.max_action, args.device
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
critic1 = Critic(
args.layer_num, args.state_shape, args.action_shape, args.device
).to(args.device)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net, args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(
args.layer_num, args.state_shape, args.action_shape, args.device
).to(args.device)
critic2 = Critic(net, args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = TD3Policy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
Expand Down
59 changes: 38 additions & 21 deletions tianshou/env/atari.py → examples/atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import gym
import numpy as np
from gym.spaces.box import Box
from tianshou.data import Batch

SIZE = 84
FRAME = 4


def create_atari_environment(name=None, sticky_actions=True,
Expand All @@ -14,6 +18,27 @@ def create_atari_environment(name=None, sticky_actions=True,
return env


def _resize_stacked_frames(frames):
    """Resize every frame in ``frames`` to ``SIZE x SIZE``.

    NOTE(review): assumes ``frames`` is laid out as
    ``(batch, FRAME, height, width)`` -- inferred from the reshapes below;
    confirm against the environment wrapper that produces it.

    :return: a ``uint8`` array of shape ``(batch, SIZE, SIZE, FRAME)``.
    """
    # Merge the two leading axes so every frame is one 2-D image.
    frames = np.reshape(frames, (-1, *frames.shape[2:]))
    # cv2.resize treats the trailing axis as channels, so put the image
    # index last and resize all frames in one call.
    frames = np.moveaxis(frames, 0, -1)
    frames = cv2.resize(frames, (SIZE, SIZE))
    frames = np.asanyarray(frames, dtype=np.uint8)
    # Bring the image index back to the front BEFORE regrouping; reshaping
    # the (SIZE, SIZE, n_images) array directly would interleave pixels of
    # different frames.
    frames = np.moveaxis(frames, -1, 0)
    frames = np.reshape(frames, (-1, FRAME, SIZE, SIZE))
    # Put the frame axis last: (batch, SIZE, SIZE, FRAME).
    return np.moveaxis(frames, 1, -1)


def preprocess_fn(obs=None, act=None, rew=None, done=None,
                  obs_next=None, info=None, policy=None):
    """Downscale raw stacked frames before they are stored.

    Only one observation field is processed per call: ``obs_next`` when it
    is given, otherwise ``obs`` when it is given. All other arguments are
    passed through untouched.

    :return: a ``Batch`` holding ``obs``, ``act``, ``rew``, ``done``,
        ``obs_next`` and ``info``. ``policy`` is accepted but not included
        in the returned ``Batch``.
    """
    if obs_next is not None:
        obs_next = _resize_stacked_frames(obs_next)
    elif obs is not None:
        obs = _resize_stacked_frames(obs)

    return Batch(obs=obs, act=act, rew=rew, done=done,
                 obs_next=obs_next, info=info)


class preprocessing(object):
def __init__(self, env, frame_skip=4, terminal_on_life_loss=False,
size=84, max_episode_steps=2000):
Expand All @@ -35,7 +60,8 @@ def __init__(self, env, frame_skip=4, terminal_on_life_loss=False,

@property
def observation_space(self):
return Box(low=0, high=255, shape=(self.size, self.size, 4),
return Box(low=0, high=255,
shape=(self.size, self.size, self.frame_skip),
dtype=np.uint8)

def action_space(self):
Expand All @@ -57,8 +83,8 @@ def reset(self):
self._grayscale_obs(self.screen_buffer[0])
self.screen_buffer[1].fill(0)

return np.stack([
self._pool_and_resize() for _ in range(self.frame_skip)], axis=-1)
return np.array([self._pool_and_resize()
for _ in range(self.frame_skip)])

def render(self, mode='human'):
return self.env.render(mode)
Expand All @@ -85,19 +111,15 @@ def step(self, action):
self._grayscale_obs(self.screen_buffer[t_])

observation.append(self._pool_and_resize())
while len(observation) > 0 and len(observation) < self.frame_skip:
if len(observation) == 0:
observation = [self._pool_and_resize()
for _ in range(self.frame_skip)]
while len(observation) > 0 and \
len(observation) < self.frame_skip:
observation.append(observation[-1])
if len(observation) > 0:
observation = np.stack(observation, axis=-1)
else:
observation = np.stack([
self._pool_and_resize() for _ in range(self.frame_skip)],
axis=-1)
if self.count >= self.max_episode_steps:
terminal = True
else:
terminal = False
return observation, total_reward, (terminal or is_terminal), info
terminal = self.count >= self.max_episode_steps
return np.array(observation), total_reward, \
(terminal or is_terminal), info

def _grayscale_obs(self, output):
self.env.ale.getScreenGrayscale(output)
Expand All @@ -108,9 +130,4 @@ def _pool_and_resize(self):
np.maximum(self.screen_buffer[0], self.screen_buffer[1],
out=self.screen_buffer[0])

transformed_image = cv2.resize(self.screen_buffer[0],
(self.size, self.size),
interpolation=cv2.INTER_AREA)
int_image = np.asarray(transformed_image, dtype=np.uint8)
# return np.expand_dims(int_image, axis=2)
return int_image
return self.screen_buffer[0]
81 changes: 0 additions & 81 deletions examples/continuous_net.py

This file was deleted.

17 changes: 8 additions & 9 deletions examples/halfcheetahBullet_v0_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import pybullet_envs
except ImportError:
pass

from continuous_net import ActorProb, Critic
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic


def get_args():
Expand Down Expand Up @@ -66,18 +66,17 @@ def test_sac(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
actor = ActorProb(
args.layer_num, args.state_shape, args.action_shape,
net, args.action_shape,
args.max_action, args.device, unbounded=True
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
critic1 = Critic(
args.layer_num, args.state_shape, args.action_shape, args.device
).to(args.device)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net, args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(
args.layer_num, args.state_shape, args.action_shape, args.device
).to(args.device)
critic2 = Critic(net, args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
Expand Down
Loading

0 comments on commit 979dc2d

Please sign in to comment.