Skip to content
This repository has been archived by the owner on Dec 24, 2024. It is now read-only.

[WIP] Support gym.spaces.Dict #219

Merged
merged 18 commits into from
Apr 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions machina/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from machina.envs.ac_in_ob_env import AcInObEnv
from machina.envs.rew_in_ob_env import RewInObEnv
from machina.envs.skill_env import SkillEnv
from machina.envs.env_utils import flatten_to_dict
18 changes: 18 additions & 0 deletions machina/envs/env_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import numpy as np
from collections import OrderedDict


def flatten_to_dict(flatten_obs, dict_space, dict_keys=None):
    """Split a flattened observation back into its per-key components.

    The last axis of ``flatten_obs`` is treated as the concatenation of the
    flattened sub-observations, laid out in the order of ``dict_keys``. All
    leading axes are preserved in every returned entry.

    Parameters
    ----------
    flatten_obs : array-like
        Object exposing ``.shape``, ``...``-style slicing and ``.reshape``.
    dict_space : object
        Object with a ``spaces`` mapping from key to a sub-space that exposes
        ``.shape`` (e.g. a ``gym.spaces.Dict``).
    dict_keys : iterable, optional
        Keys to extract, in layout order. Defaults to
        ``dict_space.spaces.keys()``.

    Returns
    -------
    OrderedDict
        Maps each key to the matching slice of ``flatten_obs`` reshaped to
        ``flatten_obs.shape[:-1] + sub_space.shape``.
    """
    if dict_keys is None:
        dict_keys = dict_space.spaces.keys()

    leading_shape = tuple(flatten_obs.shape[:-1])
    obs_dict = OrderedDict()
    begin_index = 0
    for key in dict_keys:
        origin_shape = tuple(dict_space.spaces[key].shape)
        # np.prod(()) is the float 1.0, which is not a valid slice bound;
        # casting to int keeps scalar-shaped sub-spaces working.
        end_index = begin_index + int(np.prod(origin_shape))
        obs_dict[key] = flatten_obs[..., begin_index:end_index].reshape(
            leading_shape + origin_shape)
        begin_index = end_index
    return obs_dict
107 changes: 107 additions & 0 deletions tests/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
This is a Dict space version of Pendulum.
"""
import gym
from gym import spaces
from gym.utils import seeding
import numpy as np
from os import path


class PendulumDictEnv(gym.Env):
    """Pendulum environment whose observation space is a ``spaces.Dict``.

    Observations are dictionaries with two entries:
    ``'angle'`` holds ``[cos(theta), sin(theta)]`` and
    ``'angular_velocity'`` holds ``[theta_dot]``.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 30
    }

    def __init__(self):
        """Set the physical limits, build the spaces and seed the RNG."""
        self.max_speed = 8
        self.max_torque = 2.
        self.dt = .05
        self.viewer = None

        angle_bound = np.array([1., 1.])
        velocity_bound = np.array([self.max_speed])

        self.action_space = spaces.Box(
            low=-self.max_torque,
            high=self.max_torque,
            shape=(1,),
            dtype=np.float32,
        )
        angle_space = spaces.Box(
            low=-angle_bound, high=angle_bound, dtype=np.float32)
        velocity_space = spaces.Box(
            low=-velocity_bound, high=velocity_bound, dtype=np.float32)
        self.observation_space = spaces.Dict(
            {
                'angle': angle_space,
                'angular_velocity': velocity_space,
            }
        )

        self.seed()

    def seed(self, seed=None):
        """(Re)create ``self.np_random`` and return the seed used, in a list."""
        rng, used_seed = seeding.np_random(seed)
        self.np_random = rng
        return [used_seed]

    def step(self, u):
        """Advance the pendulum by one time step of length ``self.dt``.

        ``u`` is clipped to ``[-max_torque, max_torque]`` and its first
        element is applied as the torque. Returns the tuple
        ``(observation, reward, done, info)`` where ``reward`` is the negated
        cost, ``done`` is always ``False`` and ``info`` is an empty dict.
        """
        theta, theta_dot = self.state

        gravity = 10.
        mass = 1.
        length = 1.
        dt = self.dt

        torque = np.clip(u, -self.max_torque, self.max_torque)[0]
        self.last_u = torque  # kept for rendering
        cost = angle_normalize(theta)**2 + .1*theta_dot**2 + .001*(torque**2)

        accel = -3*gravity/(2*length) * np.sin(theta + np.pi) + \
            3./(mass*length**2)*torque
        next_theta_dot = theta_dot + accel * dt
        # The angle is integrated with the unclipped velocity; the velocity
        # is clipped only afterwards.
        next_theta = theta + next_theta_dot*dt
        next_theta_dot = np.clip(
            next_theta_dot, -self.max_speed, self.max_speed)

        self.state = np.array([next_theta, next_theta_dot])
        return self._get_obs(), -cost, False, {}

    def reset(self):
        """Sample a new state uniformly in ``[-pi, pi] x [-1, 1]``.

        Clears the last applied torque and returns the first observation.
        """
        bound = np.array([np.pi, 1])
        self.state = self.np_random.uniform(low=-bound, high=bound)
        self.last_u = None
        return self._get_obs()

    def _get_obs(self):
        """Build the dictionary observation from the current state."""
        angle, velocity = self.state
        return {
            'angle': np.array([np.cos(angle), np.sin(angle)]),
            'angular_velocity': np.array([velocity])
        }

    def render(self, mode='human'):
        """Draw the pendulum, creating the viewer on first use.

        The result of the viewer's ``render`` call is returned; an RGB array
        is requested from it when ``mode == 'rgb_array'``.
        """
        if self.viewer is None:
            # Imported lazily so the rendering module is only needed when
            # render() is actually called.
            from gym.envs.classic_control import rendering
            self.viewer = rendering.Viewer(500, 500)
            self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2)

            rod = rendering.make_capsule(1, .2)
            rod.set_color(.8, .3, .3)
            self.pole_transform = rendering.Transform()
            rod.add_attr(self.pole_transform)
            self.viewer.add_geom(rod)

            axle = rendering.make_circle(.05)
            axle.set_color(0, 0, 0)
            self.viewer.add_geom(axle)

            image_path = path.join(
                path.dirname(__file__), "assets/clockwise.png")
            self.img = rendering.Image(image_path, 1., 1.)
            self.imgtrans = rendering.Transform()
            self.img.add_attr(self.imgtrans)

        self.viewer.add_onetime(self.img)
        self.pole_transform.set_rotation(self.state[0] + np.pi/2)
        if self.last_u:
            # Scale the torque image with the last applied torque.
            self.imgtrans.scale = (-self.last_u/2, np.abs(self.last_u)/2)

        want_array = mode == 'rgb_array'
        return self.viewer.render(return_rgb_array=want_array)

    def close(self):
        """Close the viewer, if one was created, and forget it."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None


def angle_normalize(x):
    """Wrap the angle ``x`` (radians) into the interval ``[-pi, pi)``."""
    full_turn = 2*np.pi
    shifted = x + np.pi
    wrapped = shifted % full_turn
    return wrapped - np.pi
128 changes: 128 additions & 0 deletions tests/simple_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch.nn.init import kaiming_uniform_, uniform_
import torch.nn.functional as F
import gym
from machina.envs import flatten_to_dict


def mini_weight_init(m):
Expand Down Expand Up @@ -325,3 +326,130 @@ def forward(self, ob):
feat = self.discrim_f(ob)
h = torch.relu(self.fc1(feat))
return self.output_layer(h)


class PolDictNet(nn.Module):
    """Feed-forward policy network for a Dict observation space.

    The flattened observation is split back into its ``'angle'`` and
    ``'angular_velocity'`` parts. The angle goes through ``fc1``; the
    angular velocity is concatenated to that hidden feature before ``fc2``.

    Parameters
    ----------
    ob_space : Dict-like space
        Must expose ``spaces['angle']`` and ``spaces['angular_velocity']``,
        each with a ``shape``.
    ac_space : gym space
        ``gym.spaces.Box`` selects the continuous head,
        ``gym.spaces.MultiDiscrete`` selects one softmax head per entry of
        ``nvec``, anything else selects a single softmax head of size ``n``.
    h1, h2 : int
        Hidden layer widths.
    deterministic : bool
        For continuous actions, when ``True`` no log-std parameter is created
        and ``forward`` returns only the mean.
    """

    def __init__(self, ob_space, ac_space, h1=200, h2=100, deterministic=False):
        super(PolDictNet, self).__init__()

        self.ob_space = ob_space
        self.deterministic = deterministic

        self.discrete = not isinstance(ac_space, gym.spaces.Box)
        if self.discrete:
            self.multi = isinstance(ac_space, gym.spaces.MultiDiscrete)

        self.fc1 = nn.Linear(ob_space.spaces['angle'].shape[0], h1)
        self.fc2 = nn.Linear(
            h1 + self.ob_space.spaces['angular_velocity'].shape[0], h2)
        self.fc1.apply(weight_init)
        self.fc2.apply(weight_init)

        if not self.discrete:
            self.mean_layer = nn.Linear(h2, ac_space.shape[0])
            if not self.deterministic:
                self.log_std_param = nn.Parameter(
                    torch.randn(ac_space.shape[0])*1e-10 - 1)
            self.mean_layer.apply(mini_weight_init)
        elif self.multi:
            self.output_layers = nn.ModuleList(
                [nn.Linear(h2, vec) for vec in ac_space.nvec])
            for layer in self.output_layers:
                layer.apply(mini_weight_init)
        else:
            self.output_layer = nn.Linear(h2, ac_space.n)
            self.output_layer.apply(mini_weight_init)

    def forward(self, ob):
        """Compute the policy output for flattened observations ``ob``.

        The last axis of ``ob`` is the flattened Dict observation; any
        leading axes are kept.

        Returns ``(mean, log_std)`` for a stochastic continuous policy,
        ``mean`` for a deterministic continuous policy, and softmax
        probabilities for discrete policies (stacked on a new second-to-last
        axis in the multi-discrete case).
        """
        dict_ob = flatten_to_dict(ob, self.ob_space)
        h = F.relu(self.fc1(dict_ob['angle']))
        # Concatenate along the feature axis; dim=-1 matches dim=1 for 2-D
        # input and also supports other numbers of leading axes.
        h = F.relu(
            self.fc2(torch.cat([h, dict_ob['angular_velocity']], dim=-1)))

        if not self.discrete:
            mean = torch.tanh(self.mean_layer(h))
            if self.deterministic:
                return mean
            log_std = self.log_std_param.expand_as(mean)
            return mean, log_std

        if self.multi:
            return torch.cat(
                [torch.softmax(ol(h), dim=-1).unsqueeze(-2)
                 for ol in self.output_layers],
                dim=-2)
        return torch.softmax(self.output_layer(h), dim=-1)


class PolNetDictLSTM(nn.Module):
    """Recurrent (LSTM-cell) policy network for a Dict observation space.

    The ``'angle'`` part of the observation goes through ``input_layer``;
    the ``'angular_velocity'`` part is concatenated to that feature before
    it is fed to the LSTM cell.

    Parameters
    ----------
    ob_space : Dict-like space
        Must expose ``spaces['angle']`` and ``spaces['angular_velocity']``,
        each with a ``shape``.
    ac_space : gym space
        ``gym.spaces.Box`` selects the continuous head,
        ``gym.spaces.MultiDiscrete`` selects one softmax head per entry of
        ``nvec``, anything else selects a single softmax head of size ``n``.
    h_size : int
        Width of the input feature layer.
    cell_size : int
        Hidden size of the LSTM cell.
    """

    def __init__(self, ob_space, ac_space, h_size=1024, cell_size=512):
        super(PolNetDictLSTM, self).__init__()
        self.h_size = h_size
        self.cell_size = cell_size
        self.rnn = True
        self.ob_space = ob_space

        self.discrete = not isinstance(ac_space, gym.spaces.Box)
        if self.discrete:
            self.multi = isinstance(ac_space, gym.spaces.MultiDiscrete)

        self.input_layer = nn.Linear(
            ob_space.spaces['angle'].shape[0], self.h_size)
        self.cell = nn.LSTMCell(
            self.h_size + ob_space.spaces['angular_velocity'].shape[0],
            hidden_size=self.cell_size)

        if not self.discrete:
            self.mean_layer = nn.Linear(self.cell_size, ac_space.shape[0])
            self.log_std_param = nn.Parameter(
                torch.randn(ac_space.shape[0])*1e-10 - 1)

            self.mean_layer.apply(mini_weight_init)
        elif self.multi:
            self.output_layers = nn.ModuleList(
                [nn.Linear(self.cell_size, vec) for vec in ac_space.nvec])
            for layer in self.output_layers:
                layer.apply(mini_weight_init)
        else:
            self.output_layer = nn.Linear(self.cell_size, ac_space.n)
            self.output_layer.apply(mini_weight_init)

    def init_hs(self, batch_size=1):
        """Return a zeroed ``(h, c)`` pair of shape ``(batch_size, cell_size)``.

        The tensors are created from one of the module's parameters, so they
        share its dtype and device.
        """
        ref = next(self.parameters())
        return (ref.new(batch_size, self.cell_size).zero_(),
                ref.new(batch_size, self.cell_size).zero_())

    def forward(self, xs, hs, h_masks):
        """Run the policy over a sequence of flattened observations.

        Parameters
        ----------
        xs : tensor
            The first two axes are taken as (time, batch); the last axis is
            the flattened Dict observation.
        hs : tuple of tensors
            ``(h, c)`` state, each reshaped to ``(batch, cell_size)``.
        h_masks : iterable of tensors
            Iterated in lockstep with the time axis of ``xs``; the state is
            multiplied by ``1 - mask`` before each step (presumably 1 marks
            an episode boundary where the state is reset; confirm with the
            caller).

        Returns
        -------
        ``(means, log_std, hs)`` for continuous actions, otherwise
        ``(probabilities, hs)``; ``hs`` is the state after the last step.
        """
        batch_size = xs.shape[1]

        hs = (hs[0].reshape(batch_size, self.cell_size),
              hs[1].reshape(batch_size, self.cell_size))

        dict_xs = flatten_to_dict(xs, self.ob_space)
        xs = torch.relu(self.input_layer(dict_xs['angle']))
        ang_vels = dict_xs['angular_velocity']

        hiddens = []
        for x, ang_vel, mask in zip(xs, ang_vels, h_masks):
            keep = 1 - mask
            hs = (hs[0] * keep, hs[1] * keep)
            hs = self.cell(torch.cat([x, ang_vel], dim=-1), hs)
            hiddens.append(hs[0])
        # Stack the per-step hidden states along a new leading time axis.
        hiddens = torch.stack(hiddens, dim=0)

        if not self.discrete:
            means = torch.tanh(self.mean_layer(hiddens))
            log_std = self.log_std_param.expand_as(means)
            return means, log_std, hs

        if self.multi:
            probs = torch.cat(
                [torch.softmax(ol(hiddens), dim=-1).unsqueeze(-2)
                 for ol in self.output_layers],
                dim=-2)
            return probs, hs
        return torch.softmax(self.output_layer(hiddens), dim=-1), hs
Loading