-
Notifications
You must be signed in to change notification settings - Fork 169
/
trajGRU.py
180 lines (158 loc) · 7.45 KB
/
trajGRU.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import torch
from torch import nn
from nowcasting.config import cfg
from nowcasting.models.model import activation
import torch.nn.functional as F
# input: B, C, H, W
# flow: [B, 2, H, W]
def wrap(input, flow):
    """Backward-warp ``input`` with a dense per-pixel displacement field.

    Output pixel ``(y, x)`` is bilinearly sampled from ``input`` at
    ``(y + flow[:, 1, y, x], x + flow[:, 0, y, x])``. Samples that land
    outside the image read as zero (``grid_sample``'s default ``'zeros'``
    padding).

    Args:
        input: tensor of shape (B, C, H, W) to be sampled.
        flow: tensor of shape (B, 2, H, W) of offsets measured in pixels;
            channel 0 is the horizontal (x) offset, channel 1 the vertical (y).

    Returns:
        Tensor of shape (B, C, H, W).
    """
    B, C, H, W = input.size()
    # Base pixel coordinates, created on the same device/dtype as ``flow`` so
    # the function follows its arguments instead of a global device setting.
    xs = torch.arange(W, device=flow.device, dtype=flow.dtype).view(1, 1, W)
    ys = torch.arange(H, device=flow.device, dtype=flow.dtype).view(1, H, 1)
    # Displaced sampling positions, scaled to [-1, 1]. Pixel 0 maps to -1 and
    # pixel W-1 (resp. H-1) maps to +1; this is the ``align_corners=True``
    # convention. ``max(..., 1)`` avoids division by zero for 1-pixel axes.
    gx = 2.0 * (xs + flow[:, 0]) / max(W - 1, 1) - 1.0
    gy = 2.0 * (ys + flow[:, 1]) / max(H - 1, 1) - 1.0
    vgrid = torch.stack((gx, gy), dim=-1)  # (B, H, W, 2), last dim = (x, y)
    # align_corners must match the normalisation above. PyTorch >= 1.3
    # defaults to False, which would shift every sample by half a pixel.
    output = torch.nn.functional.grid_sample(input, vgrid, align_corners=True)
    return output
class BaseConvRNN(nn.Module):
    """Hyper-parameter bookkeeping shared by convolutional RNN cells.

    The constructor records the input-to-hidden (i2h) and hidden-to-hidden
    (h2h) convolution geometry and derives two quantities from it:

    * ``_h2h_pad``: per-axis padding ``dilate * (kernel - 1) // 2`` for the
      (odd-sized) h2h kernel.
    * ``_state_height`` / ``_state_width``: spatial size produced by the
      i2h convolution applied to a ``(height, width)`` input.

    No layers are created here; subclasses build their own modules.
    """

    def __init__(self, num_filter, b_h_w,
                 h2h_kernel=(3, 3), h2h_dilate=(1, 1),
                 i2h_kernel=(3, 3), i2h_stride=(1, 1),
                 i2h_pad=(1, 1), i2h_dilate=(1, 1),
                 act_type=torch.tanh,
                 prefix='BaseConvRNN'):
        """Store the cell configuration.

        Args:
            num_filter: number of hidden-state channels.
            b_h_w: ``(batch_size, height, width)`` of the input feature map.
            h2h_kernel: h2h kernel size per axis; both entries must be odd.
            h2h_dilate: h2h dilation per axis.
            i2h_kernel, i2h_stride, i2h_pad, i2h_dilate: i2h convolution
                geometry, each given per axis as ``(h, w)``.
            act_type: activation callable, stored for subclasses.
            prefix: name prefix, stored for subclasses.

        Raises:
            AssertionError: if an h2h kernel entry is even, or ``b_h_w`` does
                not have exactly three entries.
        """
        super().__init__()
        self._prefix = prefix
        self._num_filter = num_filter

        # Hidden-to-hidden geometry.
        self._h2h_kernel = h2h_kernel
        assert h2h_kernel[0] % 2 == 1 and h2h_kernel[1] % 2 == 1, \
            "Only support odd number, get h2h_kernel= %s" % str(h2h_kernel)
        pad_rows = h2h_dilate[0] * (h2h_kernel[0] - 1) // 2
        pad_cols = h2h_dilate[1] * (h2h_kernel[1] - 1) // 2
        self._h2h_pad = (pad_rows, pad_cols)
        self._h2h_dilate = h2h_dilate

        # Input-to-hidden geometry.
        self._i2h_kernel = i2h_kernel
        self._i2h_stride = i2h_stride
        self._i2h_pad = i2h_pad
        self._i2h_dilate = i2h_dilate
        self._act_type = act_type

        assert len(b_h_w) == 3
        self._batch_size, self._height, self._width = b_h_w

        def _conv_out(size, pad, kernel, dilate, stride):
            # Convolution output length along one axis, accounting for dilation.
            span = 1 + (kernel - 1) * dilate
            return (size + 2 * pad - span) // stride + 1

        self._state_height = _conv_out(self._height, i2h_pad[0], i2h_kernel[0],
                                       i2h_dilate[0], i2h_stride[0])
        self._state_width = _conv_out(self._width, i2h_pad[1], i2h_kernel[1],
                                      i2h_dilate[1], i2h_stride[1])

        # Runtime bookkeeping; initialised here and not used in this class.
        self._curr_states = None
        self._counter = 0
class TrajGRU(BaseConvRNN):
    """Trajectory GRU convolutional recurrent cell.

    At every step the cell predicts ``L`` flow fields from the current input
    (if any) and the previous hidden state. It warps the previous hidden
    state along each flow and mixes the ``L`` warped copies with a 1x1
    convolution; the result is the recurrent term of the GRU gates.
    """

    # b_h_w: input feature map size, (batch, height, width)
    def __init__(self, input_channel, num_filter, b_h_w, zoneout=0.0, L=5,
                 i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1),
                 h2h_kernel=(5, 5), h2h_dilate=(1, 1),
                 act_type=cfg.MODEL.RNN_ACT_TYPE):
        """Build the cell's convolutions.

        Args:
            input_channel: channels of each input frame.
            num_filter: channels of the hidden state.
            b_h_w: ``(batch, height, width)`` of the input feature map.
            zoneout: during training, probability that a (sample, channel)
                of the hidden state keeps its previous value; 0 disables it.
            L: number of flow fields, i.e. warped copies of the hidden state.
            i2h_kernel, i2h_stride, i2h_pad: input-to-hidden conv geometry.
            h2h_kernel, h2h_dilate: forwarded to ``BaseConvRNN``. They are not
                used by the layers built here; the recurrent mixing is 1x1.
            act_type: activation callable, applied to the flow features and to
                the candidate memory.
        """
        super(TrajGRU, self).__init__(num_filter=num_filter,
                                      b_h_w=b_h_w,
                                      h2h_kernel=h2h_kernel,
                                      h2h_dilate=h2h_dilate,
                                      i2h_kernel=i2h_kernel,
                                      i2h_pad=i2h_pad,
                                      i2h_stride=i2h_stride,
                                      act_type=act_type,
                                      prefix='TrajGRU')
        self._L = L
        self._zoneout = zoneout

        # Input-to-hidden weights (wxz, wxr, wxh in GRU notation). Output
        # channels are split, in order, into the reset-gate, update-gate and
        # new-memory contributions, each num_filter wide.
        self.i2h = nn.Conv2d(in_channels=input_channel,
                             out_channels=self._num_filter * 3,
                             kernel_size=self._i2h_kernel,
                             stride=self._i2h_stride,
                             padding=self._i2h_pad,
                             dilation=self._i2h_dilate)
        # Input -> flow features (5x5, size-preserving).
        self.i2f_conv1 = nn.Conv2d(in_channels=input_channel,
                                   out_channels=32,
                                   kernel_size=(5, 5),
                                   stride=1,
                                   padding=(2, 2),
                                   dilation=(1, 1))
        # Hidden state -> flow features (5x5, size-preserving).
        self.h2f_conv1 = nn.Conv2d(in_channels=self._num_filter,
                                   out_channels=32,
                                   kernel_size=(5, 5),
                                   stride=1,
                                   padding=(2, 2),
                                   dilation=(1, 1))
        # Flow features -> L flow fields with 2 channels (x, y) each.
        self.flows_conv = nn.Conv2d(in_channels=32,
                                    out_channels=self._L * 2,
                                    kernel_size=(5, 5),
                                    stride=1,
                                    padding=(2, 2))
        # Hidden-to-hidden weights (hh, hz, hr), a 1x1 convolution over the
        # L warped copies of the hidden state concatenated on channels.
        self.ret = nn.Conv2d(in_channels=self._num_filter * self._L,
                             out_channels=self._num_filter * 3,
                             kernel_size=(1, 1),
                             stride=1)

    # inputs: B*C*H*W
    def _flow_generator(self, inputs, states):
        """Predict the L flow fields for one time step.

        Args:
            inputs: one input frame (B, C, H, W), or None.
            states: previous hidden state (B, num_filter, H', W').

        Returns:
            Tuple of L tensors, each (B, 2, H', W').

        NOTE(review): the input and hidden feature maps are summed, so this
        assumes the i2h convolution preserves spatial size (H', W') == (H, W).
        """
        h2f_conv1 = self.h2f_conv1(states)
        if inputs is not None:
            f_conv1 = self.i2f_conv1(inputs) + h2f_conv1
        else:
            f_conv1 = h2f_conv1
        f_conv1 = self._act_type(f_conv1)
        flows = self.flows_conv(f_conv1)
        return torch.split(flows, 2, dim=1)

    # inputs and states must not both be None
    # inputs: S*B*C*H*W
    def forward(self, inputs=None, states=None, seq_len=cfg.HKO.BENCHMARK.IN_LEN):
        """Unroll the cell for ``seq_len`` steps.

        Args:
            inputs: (S, B, C, H, W) input sequence, or None to run the cell
                from its hidden state alone.
            states: initial hidden state (B, num_filter, state_height,
                state_width), or None to start from zeros.
            seq_len: number of steps to unroll; when ``inputs`` is given it
                must not exceed S.

        Returns:
            ``(outputs, next_h)``: the per-step hidden states stacked to
            (seq_len, B, num_filter, state_height, state_width), and the final
            hidden state.
        """
        if states is None:
            # Without inputs there is nothing to read the batch size from, so
            # fall back to the batch size declared in ``b_h_w``.
            batch = inputs.size(1) if inputs is not None else self._batch_size
            weight = self.i2h.weight
            # Allocate on the module's own device/dtype so the convolutions
            # below see consistent tensors.
            states = torch.zeros((batch, self._num_filter, self._state_height,
                                  self._state_width),
                                 dtype=weight.dtype, device=weight.device)
        if inputs is not None:
            S, B, C, H, W = inputs.size()
            # Run the i2h convolution on all time steps at once by folding
            # time into the batch dimension, then unfold again.
            i2h = self.i2h(torch.reshape(inputs, (-1, C, H, W)))
            i2h = torch.reshape(i2h, (S, B, i2h.size(1), i2h.size(2), i2h.size(3)))
            # [reset, update, new_mem] contributions, each (S, B, num_filter, ...).
            i2h_slice = torch.split(i2h, self._num_filter, dim=2)
        else:
            i2h_slice = None

        prev_h = states
        outputs = []
        for i in range(seq_len):
            frame = inputs[i, ...] if inputs is not None else None
            flows = self._flow_generator(frame, prev_h)
            # Warp the previous hidden state along each (negated) flow and
            # concatenate the L warped copies on the channel axis.
            wrapped_data = torch.cat([wrap(prev_h, -flow) for flow in flows], dim=1)
            h2h = self.ret(wrapped_data)
            h2h_slice = torch.split(h2h, self._num_filter, dim=1)
            if i2h_slice is not None:
                reset_gate = torch.sigmoid(i2h_slice[0][i, ...] + h2h_slice[0])
                update_gate = torch.sigmoid(i2h_slice[1][i, ...] + h2h_slice[1])
                new_mem = self._act_type(i2h_slice[2][i, ...] + reset_gate * h2h_slice[2])
            else:
                reset_gate = torch.sigmoid(h2h_slice[0])
                update_gate = torch.sigmoid(h2h_slice[1])
                new_mem = self._act_type(reset_gate * h2h_slice[2])
            next_h = update_gate * prev_h + (1 - update_gate) * new_mem
            if self._zoneout > 0.0:
                # Zoneout: dropout2d on an all-ones tensor zeroes whole
                # (sample, channel) maps with probability ``zoneout``; those
                # entries keep prev_h, all others take next_h. Outside training
                # mode dropout is the identity, so next_h is always used.
                take_new = F.dropout2d(torch.ones_like(prev_h), p=self._zoneout,
                                       training=self.training) != 0
                next_h = torch.where(take_new, next_h, prev_h)
            outputs.append(next_h)
            prev_h = next_h
        return torch.stack(outputs), next_h