-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathmodel.py
129 lines (114 loc) · 5.19 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch
from torch import nn
from torch.nn import functional as F
from torch.distributions import Normal
class Encoder(nn.Module):
    """
    Convolutional encoder that embeds an image observation of shape
    (3, 64, 64) into a flat feature vector of shape (1024,).

    Four stride-2 convolutions halve the spatial resolution at each step
    (64 -> 31 -> 14 -> 6 -> 2); the final 256x2x2 feature map is flattened
    to 1024 features per batch element.
    """

    def __init__(self):
        super().__init__()
        # Channel progression 3 -> 32 -> 64 -> 128 -> 256, all 4x4 kernels, stride 2.
        self.cv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2)
        self.cv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.cv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2)
        self.cv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2)

    def forward(self, obs):
        """Return the (batch, 1024) embedding of a (batch, 3, 64, 64) image."""
        x = obs
        for conv in (self.cv1, self.cv2, self.cv3, self.cv4):
            x = F.relu(conv(x))
        # Flatten everything but the batch dimension: (B, 256, 2, 2) -> (B, 1024).
        return x.reshape(x.size(0), -1)
class RecurrentStateSpaceModel(nn.Module):
    """
    Recurrent State-Space Model (RSSM) bundling several components:

        Deterministic state model:      h_t+1 = f(h_t, s_t, a_t)
        Stochastic state model (prior): p(s_t+1 | h_t+1)
        State posterior:                q(s_t | h_t, o_t)

    NOTE: this class consumes observations already embedded by the Encoder
    (1024-dim vectors), not raw images.
    ``min_stddev`` is added to every predicted stddev, and the default
    activation is ``F.relu`` — both matching the original implementation.
    """

    def __init__(self, state_dim, action_dim, rnn_hidden_dim,
                 hidden_dim=200, min_stddev=0.1, act=F.relu):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.rnn_hidden_dim = rnn_hidden_dim
        # Prior branch: (state, action) -> GRU input; GRU hidden -> mean/stddev.
        self.fc_state_action = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc_rnn_hidden = nn.Linear(rnn_hidden_dim, hidden_dim)
        self.fc_state_mean_prior = nn.Linear(hidden_dim, state_dim)
        self.fc_state_stddev_prior = nn.Linear(hidden_dim, state_dim)
        # Posterior branch: (GRU hidden, 1024-dim embedded obs) -> mean/stddev.
        self.fc_rnn_hidden_embedded_obs = nn.Linear(rnn_hidden_dim + 1024, hidden_dim)
        self.fc_state_mean_posterior = nn.Linear(hidden_dim, state_dim)
        self.fc_state_stddev_posterior = nn.Linear(hidden_dim, state_dim)
        self.rnn = nn.GRUCell(hidden_dim, rnn_hidden_dim)
        self._min_stddev = min_stddev
        self.act = act

    def forward(self, state, action, rnn_hidden, embedded_next_obs):
        """
        Advance the deterministic state h_t+1 = f(h_t, s_t, a_t) and return
        both the prior p(s_t+1 | h_t+1) and the posterior
        q(s_t+1 | h_t+1, o_t+1), as needed for model training.
        """
        next_state_prior, rnn_hidden = self.prior(state, action, rnn_hidden)
        next_state_posterior = self.posterior(rnn_hidden, embedded_next_obs)
        return next_state_prior, next_state_posterior, rnn_hidden

    def prior(self, state, action, rnn_hidden):
        """
        Step the GRU with (s_t, a_t) to get h_t+1 = f(h_t, s_t, a_t) and
        return the prior p(s_t+1 | h_t+1) together with the new hidden state.
        """
        sa = torch.cat([state, action], dim=1)
        gru_input = self.act(self.fc_state_action(sa))
        rnn_hidden = self.rnn(gru_input, rnn_hidden)
        feat = self.act(self.fc_rnn_hidden(rnn_hidden))
        mean = self.fc_state_mean_prior(feat)
        # softplus keeps stddev positive; the floor prevents collapse to zero.
        stddev = F.softplus(self.fc_state_stddev_prior(feat)) + self._min_stddev
        return Normal(mean, stddev), rnn_hidden

    def posterior(self, rnn_hidden, embedded_obs):
        """Compute the posterior q(s_t | h_t, o_t) from h_t and the embedded o_t."""
        joint = torch.cat([rnn_hidden, embedded_obs], dim=1)
        feat = self.act(self.fc_rnn_hidden_embedded_obs(joint))
        mean = self.fc_state_mean_posterior(feat)
        stddev = F.softplus(self.fc_state_stddev_posterior(feat)) + self._min_stddev
        return Normal(mean, stddev)
class ObservationModel(nn.Module):
    """
    Observation model p(o_t | s_t, h_t).

    Reconstructs a (3, 64, 64) image from the stochastic state and the RNN
    hidden state: a linear layer lifts the concatenated input to 1024
    features, which are reshaped to a (1024, 1, 1) map and upsampled by four
    transposed convolutions (1 -> 5 -> 13 -> 30 -> 64 spatially).
    """

    def __init__(self, state_dim, rnn_hidden_dim):
        super().__init__()
        self.fc = nn.Linear(state_dim + rnn_hidden_dim, 1024)
        self.dc1 = nn.ConvTranspose2d(1024, 128, kernel_size=5, stride=2)
        self.dc2 = nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2)
        self.dc3 = nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2)
        self.dc4 = nn.ConvTranspose2d(32, 3, kernel_size=6, stride=2)

    def forward(self, state, rnn_hidden):
        """Decode (state, rnn_hidden) into a (batch, 3, 64, 64) image."""
        x = self.fc(torch.cat([state, rnn_hidden], dim=1))
        # Treat the 1024 features as a 1x1 spatial map with 1024 channels.
        x = x.view(x.size(0), 1024, 1, 1)
        for deconv in (self.dc1, self.dc2, self.dc3):
            x = F.relu(deconv(x))
        # No activation on the final layer: output is the raw reconstruction.
        return self.dc4(x)
class RewardModel(nn.Module):
    """
    Reward model p(r_t | s_t, h_t).

    A four-layer MLP that predicts a scalar reward from the concatenation of
    the stochastic state and the RNN hidden state. The activation (default
    ``F.relu``) is applied after every layer except the last.
    """

    def __init__(self, state_dim, rnn_hidden_dim, hidden_dim=300, act=F.relu):
        super().__init__()
        self.fc1 = nn.Linear(state_dim + rnn_hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, 1)
        self.act = act

    def forward(self, state, rnn_hidden):
        """Return the predicted reward with shape (batch, 1)."""
        x = torch.cat([state, rnn_hidden], dim=1)
        for layer in (self.fc1, self.fc2, self.fc3):
            x = self.act(layer(x))
        # Final layer is linear: rewards may be negative.
        return self.fc4(x)