-
-
Notifications
You must be signed in to change notification settings - Fork 95
/
nnp_gru.nim
506 lines (442 loc) · 19.9 KB
/
nnp_gru.nim
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
# Copyright (c) 2018 the Arraymancer contributors
# Distributed under the Apache v2 License (license terms are at http://www.apache.org/licenses/LICENSE-2.0).
# This file may not be copied, modified, or distributed except according to those terms.
import
../tensor,
private/p_activation, ./nnp_linear,
nnp_activation
# For compatibility with CuDNN, and to allow loading CPU/CUDA weights
# interchangeably, we use the following equations:
#
# - h is hidden state at t-1, h' at t
# - input == x, hidden == h
# - n = h~ (the candidate hidden state)
# - r is the reset gate
# - z is the update gate
# - h', the final output, is a linear interpolation
#
# r = σ(Wr * x + bWr + Ur * h + bUr)
# z = σ(Wz * x + bWz + Uz * h + bUz)
# n = tanh(W * x + bW + r *. (U * h + bU ))
# h' = (1 - z) *. n + z *. h
#
# These differ from the original paper for n and h':
# - The pointwise multiplication by r happens after the matrix multiplication
#   (r *. (U * h + bU) instead of U * (r *. h))
# - The terms of the linear interpolation are switched
# TODO: apart from the 2 "linear" calls in forward prop and the linear
# backward in backprop, everything is elementwise, so we could use a
# giant loop-fusion to avoid intermediate tensors.
#
# Note that the CPU prefetcher might not work as well, because
# between the use of U3h.unsafe_raw_buf[i] and U3h.unsafe_raw_buf[i+1]
# there will be a lot of intermediate computation.
#
# See here for a counterargument: https://software.intel.com/en-us/forums/intel-moderncode-for-parallel-architectures/topic/635075
# (Intel CPU prefetchers can maintain 32 streams.)
proc gru_cell_inference*[T: SomeFloat](
  input: Tensor[T],
  W3, U3,
  bW3, bU3: Tensor[T],
  hidden: var Tensor[T]) =
  ## Runs a single GRU time step forward, without saving any intermediate
  ## values. Use this when backpropagation is not needed.
  ##
  ## Inputs:
  ##   - `input`: tensor of shape [batch_size, features]
  ##   - `W3`: input weights of shape [3 * hidden_size, features]
  ##   - `U3`: recurrent weights of shape [3 * hidden_size, hidden_size]
  ##   - `bW3`, `bU3`: input and hidden biases of shape [1, 3 * hidden_size]
  ##
  ## The gate axis (3 * hidden_size) is laid out as [r | z | n].
  ##
  ## ⚠️ In-place input/output:
  ##   - `hidden`: h(t) of shape [batch_size, hidden_size] on entry, replaced
  ##     by h'(t) on exit. The GRU output and the next hidden state are the
  ##     same tensor.
  let hsz = hidden.shape[1]
  let
    rCols = (0 ..< hsz)|1
    zCols = (hsz ..< 2*hsz)|1
    rzCols = (0 ..< 2*hsz)|1
    nCols = (2*hsz ..< 3*hsz)|1

  # W·x + bW and U·h + bU, both of shape [batch_size, 3 * hidden_size].
  # TODO, pass those as parameter to allow buffer reuse
  var gatesX, gatesH: Tensor[T]
  linear(input, W3, bW3, gatesX)
  linear(hidden, U3, bU3, gatesH)

  # r and z gates: computed in place over the first 2*H columns of `gatesX`
  # (buffer reuse), so `gatesX[_, zCols]` holds z from here on.
  var rz = gatesX[_, rzCols]
  apply2_inline(rz, gatesH[_, rzCols]):
    sigmoid(x + y)

  # Candidate state n = tanh(Wx + bW + r *. (Uh + bU)), written over the
  # last H columns of `gatesX`.
  var candidate = gatesX[_, nCols]
  apply3_inline(candidate, rz[_, rCols], gatesH[_, nCols]):
    tanh(x + y * z)

  # h' = (1 - z) *. n + z *. h
  apply3_inline(hidden, gatesX[_, zCols], candidate):
    (1 - y) * z + y * x
proc gru_cell_forward*[T: SomeFloat](
  input,
  W3, U3,
  bW3, bU3: Tensor[T],
  r, z, n, Uh,
  hidden: var Tensor[T]
) =
  ## Runs a single GRU time step forward and records the intermediate values
  ## that `gru_cell_backward` needs.
  ##
  ## Inputs:
  ##   - `input`: tensor of shape [batch_size, features]
  ##   - `hidden`: h(t), shape [batch_size, hidden_size]
  ##   - `W3`: gate input weights of shape [3 * hidden_size, features]
  ##   - `U3`: recurrent weights of shape [3 * hidden_size, hidden_size]
  ##   - `bW3`, `bU3`: input and hidden biases of shape [1, 3 * hidden_size]
  ##
  ## Outputs, all of shape [batch_size, hidden_size]:
  ##   - `r`, `z`, `Uh`: written element-wise, so they must be presized by
  ##     the caller (they may be views into a larger cache).
  ##   - `n`: reassigned to a freshly computed tensor. A view passed in is
  ##     NOT written through.
  ##
  ## ⚠️ In-place input/output:
  ##   - `hidden`: h(t) on entry, h'(t) on exit. The GRU output and the next
  ##     hidden state are the same tensor.
  let hsz = hidden.shape[1]
  let
    rCols = (0 ..< hsz)|1
    zCols = (hsz ..< 2*hsz)|1
    nCols = (2*hsz ..< 3*hsz)|1

  # Gate pre-activations, shape [batch_size, 3 * hidden_size], laid out as
  # [r | z | n].
  # TODO, pass those as parameter to allow buffer reuse
  var gatesX, gatesH: Tensor[T]
  linear(input, W3, bW3, gatesX)
  linear(hidden, U3, bU3, gatesH)

  # Save U·h + bU of the candidate gate. Backprop needs it for dr.
  apply2_inline(Uh, gatesH[_, nCols]):
    y

  # Reset gate r, then update gate z.
  apply3_inline(r, gatesX[_, rCols], gatesH[_, rCols]):
    sigmoid(y + z)
  apply3_inline(z, gatesX[_, zCols], gatesH[_, zCols]):
    sigmoid(y + z)

  # Candidate state n = tanh(Wx + bW + r *. (Uh + bU)).
  # TODO: need apply4 / loopfusion for efficient
  # buffer passing in Stacked GRU implementation
  n = map3_inline(gatesX[_, nCols], r, gatesH[_, nCols]):
    tanh(x + y * z)

  # h' = (1 - z) *. n + z *. h
  apply3_inline(hidden, z, n):
    (1 - y) * z + y * x
proc gru_cell_backward*[T: SomeFloat](
  dx, dh, dW3, dU3,          # input and weights gradients
  dbW3, dbU3: var Tensor[T], # bias gradient
  dnext: Tensor[T],          # gradient flowing back from the next hidden state
  x, h, W3, U3: Tensor[T],   # input parameters saved from forward
  r, z, n, Uh: Tensor[T]     # Intermediate tensors saved from forward
) =
  ## Backpropagates through a single GRU time step, as computed by
  ## `gru_cell_forward`:
  ##   h' = (1 - z) *. n + z *. h
  ##   n  = tanh(W·x + bW + r *. Uh)    (`Uh` already includes bU)
  ##
  ## Outputs (computed by this proc):
  ##   - `dx`: gradient of `x`, the forward input, shape [batch_size, features]
  ##   - `dh`: gradient of `h`, the forward hidden state, shape [batch_size, hidden_size]
  ##   - `dW3`: gradient of the gate input weights, shape [3 * hidden_size, features]
  ##   - `dU3`: gradient of the recurrent weights, shape [3 * hidden_size, hidden_size]
  ##   - `dbW3`, `dbU3`: gradients of the biases of W3 and U3
  ##
  ## Inputs:
  ##   - `dnext`: gradient flowing back into h'(t)
  ##   - `x`, `h`, `W3`, `U3`: forward input, hidden state and weights
  ##   - `r`, `z`, `n`, `Uh`: intermediate results saved by the forward pass,
  ##     shape [batch_size, hidden_size]

  # h' = (1 - z) *. n + z *. h: split dnext between the z and n branches.
  let
    dGateZ = (h - n) *. dnext
    dCand = (1.0.T -. z) *. dnext

  # n = tanh(a) with a = W·x + bW + r *. Uh
  let dPreN = tanh_backward(dCand, n) # gradient of a
  let dRGate = Uh *. dPreN             # d(r *. Uh)/dr
  let dUhN = r *. dPreN                # d(r *. Uh)/dUh

  # z = σ(Wz·x + bWz + Uz·h + bUz): both linear branches share the gradient.
  let dPreZ = sigmoid_backward(dGateZ, z)

  # r = σ(Wr·x + bWr + Ur·h + bUr): both linear branches share the gradient.
  let dPreR = sigmoid_backward(dRGate, r)

  # Reassemble in the [r | z | n] column layout used by W3 and U3.
  let
    dGatesX = concat(dPreR, dPreZ, dPreN, axis = 1)
    dGatesH = concat(dPreR, dPreZ, dUhN, axis = 1)

  # Through the two linear layers.
  # TODO this detaches gradients if they are slices
  linear_backward(x, W3, dGatesX, dx, dW3, dbW3)
  linear_backward(h, U3, dGatesH, dh, dU3, dbU3)

  # Direct path of h into h' (the z *. h term).
  apply3_inline(dh, dnext, z):
    x + y * z
proc gru_inference*[T: SomeFloat](
  input: Tensor[T],
  W3s0, W3sN: Tensor[T],
  U3s, bW3s, bU3s: Tensor[T],
  output, hidden: var Tensor[T]
) =
  ## Runs a stacked GRU forward over a whole sequence, without keeping
  ## anything for backpropagation. Bidirectional support is not implemented.
  ##
  ## Inputs:
  ##   - `input`: tensor of shape [sequence/timesteps, batch, features]
  ##   - input weights:
  ##     - `W3s0`: [3 * hidden_size, features] for the first layer
  ##     - `W3sN`: [num_stacked_layers - 1, 3 * hidden_size, num_directions * hidden_size]
  ##       for the following layers
  ##   - `U3s`: recurrent weights, [num_stacked_layers, 3 * hidden_size, hidden_size]
  ##   - `bW3s`, `bU3s`: input/hidden biases, [num_stacked_layers, 1, 3 * hidden_size]
  ##
  ## Outputs:
  ##   - `output`: reallocated to [sequence/timesteps, batch, num_directions * hidden_size].
  ##     Holds the hidden state of the last layer at every timestep.
  ##   - `hidden`: [num_stacked_layers * num_directions, batch, hidden_size].
  ##     Holds the hidden state of each layer after the last timestep.
  ##
  ## ⚠️ `hidden` is both an input (h(0)) and an output (h(T)). It is updated in place.

  # Metadata and shape validation. The number of layers is inferred from W3sN.
  let
    seq_len = input.shape[0]
    batch_size = input.shape[1]
    num_features = input.shape[2]
    hidden_size = hidden.shape[2]
    num_stacked_layers = W3sN.shape[0] + 1
    num_directions = hidden.shape[0] div num_stacked_layers # Always 1 at the moment
  doAssert hidden.shape == [num_stacked_layers * num_directions, batch_size, hidden_size]
  doAssert W3s0.shape == [3 * hidden_size, num_features]
  if num_stacked_layers > 1:
    doAssert W3sN.shape == [num_stacked_layers - 1, 3 * hidden_size, num_directions * hidden_size]
  doAssert U3s.shape == [num_stacked_layers, 3 * hidden_size, hidden_size]
  doAssert bW3s.shape == [num_stacked_layers, 1, 3 * hidden_size]
  doAssert bU3s.shape == bW3s.shape

  output = newTensorUninit[T](seq_len, batch_size, num_directions * hidden_size)

  # Layer 0 consumes `input`. Every following layer consumes, timestep by
  # timestep, what the layer below left in `output`, then overwrites it.
  for layer in 0 ..< num_stacked_layers:
    let
      wInput = (if layer == 0: W3s0 else: W3sN[layer - 1, _, _].squeeze(0))
      wRecur = U3s[layer, _, _].squeeze(0)
      bInput = bW3s[layer, _, _].squeeze(0)
      bRecur = bU3s[layer, _, _].squeeze(0)
    var hLayer = hidden[layer * num_directions, _, _].squeeze(0)

    for ts in 0 ..< seq_len:
      # TODO: reuse more than the output buffer
      let xts =
        if layer == 0: input[ts, _, _].squeeze(0)
        else: output[ts, _, _].squeeze(0)
      gru_cell_inference(xts, wInput, wRecur, bInput, bRecur, hLayer)
      output[ts, _, _] = hLayer.unsqueeze(0)
proc gru_forward*[T: SomeFloat](
  input: Tensor[T],
  W3s0, W3sN: Tensor[T],
  U3s, bW3s, bU3s: Tensor[T],
  rs, zs, ns, Uhs: var Tensor[T],
  output, hidden: var Tensor[T],
  cached_inputs: var seq[Tensor[T]],
  cached_hiddens: var seq[seq[Tensor[T]]]
) =
  ## ⚠️ API subject to change to match CuDNNs
  ##
  ## Runs a stacked GRU forward over a whole sequence and saves everything
  ## `gru_backward` needs.
  ##
  ## Bidirectional support is not implemented: `hidden` must hold exactly one
  ## direction per stacked layer (checked with `doAssert`).
  ##
  ## Inputs:
  ##   - `input`: Input tensor of shape [sequence/timesteps, batch, features]
  ##   - Input weights `W3s` of shapes:
  ##     - W3s0: [3 * hidden_size, features] for the first layer
  ##     - W3sN: [num_stacked_layers - 1, 3 * hidden_size, num_directions * hidden_size] for the following layers
  ##   - A series of hidden state weights `U3s` of shape [num_stacked_layers, 3 * hidden_size, hidden_size]
  ##   - A series of biases for input and hidden state weights of shape [num_stacked_layers, 1, 3 * hidden_size]
  ##
  ## Outputs:
  ##   - `rs`, `zs`, `ns`, `Uhs`: intermediate tensors saved for backpropagation.
  ##     Shape [num_stacked_layers, timesteps, batch_size, hidden_size]. They must be
  ##     preallocated (uninitialized buffers are fine). They are written in place.
  ##   - `output` of shape [sequence/timesteps, batch, num_directions * hidden_size],
  ##     reallocated by this proc. It contains the output features `hiddenT` of the
  ##     last layer for each T (timesteps).
  ##   - `hidden` of shape [num_stacked_layers * num_directions, batch, hidden_size].
  ##     It contains the hidden state for timestep T == sequence/timesteps length of `input`.
  ##   - `cached_inputs`: must be preallocated with length num_stacked_layers (its length
  ##     defines the number of stacked layers). It is filled with:
  ##     - the first layer input of shape [sequence/timesteps, batch, features]
  ##     - the following layer inputs of shape [sequence/timesteps, batch, num_directions * hidden_size]
  ##   - `cached_hiddens`: must be preallocated as [num_stacked_layers][sequence/timesteps].
  ##     It is filled with the hidden state entering each timestep of each stacked layer,
  ##     as tensors of shape [batch, hidden_size].
  ##
  ## ⚠️ Input/Output updated in-place:
  ##   - h(t) -> h'(t), the hidden state of shape [num_stacked_layers * num_directions, batch, hidden_size]
  ##     is both an input and output

  # 0. Retrieve the metadata and validate it.
  # `cached_inputs.len` defines the number of layers and is used as a divisor below.
  doAssert cached_inputs.len > 0,
    "cached_inputs must be preallocated with one entry per stacked layer"
  let
    seq_len = input.shape[0]
    batch_size = input.shape[1]
    num_features = input.shape[2]
    hidden_size = hidden.shape[2]
    num_stacked_layers = cached_inputs.len
    num_directions = hidden.shape[0] div num_stacked_layers
  doAssert hidden.shape == [num_stacked_layers * num_directions, batch_size, hidden_size]
  # Only the forward direction is computed. Reject a bidirectional-shaped
  # `hidden` instead of silently ignoring its second half.
  doAssert num_directions == 1, "Bidirectional GRU is not implemented"
  doAssert W3s0.shape == [3 * hidden_size, num_features]
  if num_stacked_layers > 1:
    doAssert W3sN.shape == [num_stacked_layers - 1, 3 * hidden_size, num_directions * hidden_size]
  doAssert U3s.shape == [num_stacked_layers, 3 * hidden_size, hidden_size]
  doAssert bW3s.shape == [num_stacked_layers, 1, 3 * hidden_size]
  doAssert bU3s.shape == bW3s.shape
  doAssert rs.shape == [num_stacked_layers, seq_len, batch_size, hidden_size]
  doAssert zs.shape == [num_stacked_layers, seq_len, batch_size, hidden_size]
  doAssert ns.shape == [num_stacked_layers, seq_len, batch_size, hidden_size]
  doAssert Uhs.shape == [num_stacked_layers, seq_len, batch_size, hidden_size]
  doAssert cached_hiddens.len == num_stacked_layers
  for x in cached_hiddens:
    doAssert x.len == seq_len

  # Initialize output
  output = newTensorUninit[T](seq_len, batch_size, num_directions * hidden_size)

  for layer in 0 ..< num_stacked_layers:
    # Layer 0 reads `input`. Deeper layers read the previous layer's `output`,
    # cloned because `output` is overwritten during this layer.
    if layer == 0:
      cached_inputs[0] = input
    else:
      cached_inputs[layer] = output.clone()
    let
      W3l = if layer == 0: W3s0 else: W3sN[layer - 1, _, _].squeeze(0)
      U3l = U3s[layer, _, _].squeeze(0)
      bW3l = bW3s[layer, _, _].squeeze(0)
      bU3l = bU3s[layer, _, _].squeeze(0)
    var hiddenl = hidden[layer * num_directions, _, _].squeeze(0)

    for timestep in 0 ..< seq_len:
      # Hidden state entering this timestep, needed by backprop.
      cached_hiddens[layer][timestep] = hiddenl.clone()
      var # Cache for backprop, squeeze the first 2 dims
        r_lts = rs[layer, timestep, _, _].squeeze(0).squeeze(0)
        z_lts = zs[layer, timestep, _, _].squeeze(0).squeeze(0)
        n_lts = ns[layer, timestep, _, _].squeeze(0).squeeze(0)
        Uh_lts = Uhs[layer, timestep, _, _].squeeze(0).squeeze(0)
      # TODO: gru_cell_forward reassigns (detaches) its `n` argument due to a
      # missing apply4/loop-fusion operation. Compute into a temporary and
      # copy it back into the `ns` cache below.
      var n_tmp = n_lts
      let input_ts = block:
        if layer == 0:
          input[timestep, _, _].squeeze(0)
        else:
          output[timestep, _, _].squeeze(0)
      # TODO: reuse buffers
      gru_cell_forward(
        input_ts,
        W3l, U3l, bW3l, bU3l,
        r_lts, z_lts, n_tmp, Uh_lts,
        hiddenl
      )
      output[timestep, _, _] = hiddenl.unsqueeze(0)
      # TODO: apply/loop-fusion
      # copy n_tmp back into the ns cache
      apply2_inline(n_lts, n_tmp):
        y
proc gru_backward*[T: SomeFloat](
  dInput, dHidden0,                    # Input and starting hidden state gradient
  dW3s0, dW3sN,                        # Weight tensor
  dU3s, dbW3s, dbU3s: var Tensor[T],   # Weights & biases gradients
  dOutput, dHiddenN: Tensor[T],        # Gradient flowing back from the output/next hidden state
  cached_inputs: seq[Tensor[T]],       # Input params saved from forward
  cached_hiddens: seq[seq[Tensor[T]]], # Input params saved from forward
  W3s0, W3sN, U3s,                     # Input params saved from forward
  rs, zs, ns, Uhs: Tensor[T]           # Intermediate tensors saved from forward
) =
  ## ⚠️ API subject to change to match CuDNNs
  ##
  ## Backpropagation through time for a stacked, unidirectional GRU whose
  ## forward pass was computed by `gru_forward`.
  ##
  ## Outputs (all are (re)allocated by this proc):
  ##   - `dInput`, `dHidden0`, `dW3s0`, `dW3sN`, `dU3s`: gradients of, respectively:
  ##     - `input`: input tensor during the forward pass, shape [sequence/timesteps, batch, features]
  ##     - `hidden`: initial hidden states during the forward pass,
  ##       shape [num_stacked_layers * num_directions, batch, hidden_size]
  ##     - the input weights `W3s`:
  ##       - W3s0: [3 * hidden_size, features] for the first layer
  ##       - W3sN: [num_stacked_layers - 1, 3 * hidden_size, num_directions * hidden_size]
  ##         for the following layers (only allocated when num_stacked_layers > 1)
  ##     - `U3s`: hidden state weights, shape [num_stacked_layers, 3 * hidden_size, hidden_size]
  ##   - `dbW3s`, `dbU3s`: gradients of the biases, shape [num_stacked_layers, 1, 3 * hidden_size]
  ##
  ## Inputs:
  ##   - `dOutput`: gradient flowing back from the next layer.
  ##     Shape: [sequence/timesteps, batch, num_directions * hidden_size]
  ##   - `dHiddenN`: gradient flowing back from the last hidden state of each layer.
  ##     Shape: [num_stacked_layers * num_directions, batch, hidden_size]
  ##   - `cached_inputs`, `cached_hiddens`, `W3s0`, `W3sN`, `U3s`: saved from the forward pass
  ##   - `rs`, `zs`, `ns`, `Uhs`: intermediate results saved from the forward pass.
  ##     Shape [num_stacked_layers, sequence/timesteps, batch_size, hidden_size]

  # 0. Retrieve the metadata and validate it
  let
    seq_len = cached_inputs[0].shape[0]
    batch_size = cached_inputs[0].shape[1]
    num_features = cached_inputs[0].shape[2]
    hidden_size = cached_hiddens[0][0].shape[1]
    num_stacked_layers = cached_inputs.len
    num_directions = 1 # stub
  doAssert W3s0.shape == [3 * hidden_size, num_features]
  if num_stacked_layers > 1:
    doAssert W3sN.shape == [num_stacked_layers - 1, 3 * hidden_size, num_directions * hidden_size]
  doAssert U3s.shape == [num_stacked_layers, 3 * hidden_size, hidden_size]
  doAssert rs.shape == [num_stacked_layers, seq_len, batch_size, hidden_size]
  doAssert zs.shape == [num_stacked_layers, seq_len, batch_size, hidden_size]
  doAssert ns.shape == [num_stacked_layers, seq_len, batch_size, hidden_size]
  doAssert Uhs.shape == [num_stacked_layers, seq_len, batch_size, hidden_size]
  doAssert dOutput.shape == [seq_len, batch_size, num_directions * hidden_size]
  doAssert dHiddenN.shape == [num_stacked_layers * num_directions, batch_size, hidden_size]
  doAssert cached_hiddens.len == num_stacked_layers
  for x in cached_hiddens:
    doAssert x.len == seq_len

  # 1. Preallocate the results (TODO: separate alloc from compute so that users can pass buffers)
  dHidden0 = newTensorUninit[T](num_stacked_layers, batch_size, hidden_size)
  dW3s0 = zeros_like(W3s0)
  if num_stacked_layers > 1:
    dW3sN = zeros_like(W3sN)
  dU3s = zeros_like(U3s)
  dbW3s = zeros[T]([num_stacked_layers, 1, 3 * hidden_size])
  dbU3s = zeros[T]([num_stacked_layers, 1, 3 * hidden_size])

  # 2. Proceed from last layer to initial layer.
  # `gFlowBack` starts as the output gradient. After processing a layer it
  # holds the gradient w.r.t. that layer's input, i.e. the output gradient
  # of the layer below.
  var gFlowBack = dOutput.clone()
  dInput = newTensorUninit[T](seq_len, batch_size, num_features)

  for layer in countdown(num_stacked_layers - 1, 0):
    let
      W3l = if layer == 0: W3s0 else: W3sN[layer - 1, _, _].squeeze(0)
      U3l = U3s[layer, _, _].squeeze(0)
      inputl = cached_inputs[layer]
    var dht1 = dHiddenN[layer, _, _].squeeze(0).clone() # Start from the gradient of the hidden state

    for timestep in countdown(seq_len - 1, 0):
      let
        input_lts = inputl[timestep, _, _].squeeze(0)
        hidden_lts = cached_hiddens[layer][timestep]
        r_lts = rs[layer, timestep, _, _].squeeze(0).squeeze(0)
        z_lts = zs[layer, timestep, _, _].squeeze(0).squeeze(0)
        n_lts = ns[layer, timestep, _, _].squeeze(0).squeeze(0)
        Uh_lts = Uhs[layer, timestep, _, _].squeeze(0).squeeze(0)
      var gFlowBack_ts = gFlowBack[timestep, _, _].squeeze(0)

      # gradients of hidden state and hidden state (t+1)
      var dht: Tensor[T]
      var dx: Tensor[T]
      dht1 += gFlowBack_ts # Add the gradient of the output at this time step (copy during forward = addition in backward)

      # Contribution of weights for this timestep
      var dW3s_lts, dU3s_lts, dbW3s_lts, dbU3s_lts: Tensor[T]
      gru_cell_backward(
        dx, dht, dW3s_lts, dU3s_lts,
        dbW3s_lts, dbU3s_lts,
        dht1,
        input_lts, hidden_lts, W3l, U3l,
        r_lts, z_lts, n_lts, Uh_lts
      )

      # Update gradient flowing back at timestep to pass to next layer
      if layer != 0:
        gFlowBack_ts.copyFrom dx
      else:
        dInput[timestep, _, _] = dx.unsqueeze(0)
      if timestep != 0:
        dht1 = dht
      else:
        dHidden0[layer, _, _] = dht.unsqueeze(0)

      # Accumulate the contribution of weights.
      # Slicing a 3D gradient with `[layer, _, _]` keeps the leading dimension
      # (shape [1, ...]), so each 2D per-timestep gradient is unsqueezed to
      # match. This includes the dW3sN branch, which previously added a 2D
      # tensor to a 3D slice.
      if layer == 0:
        dW3s0 += dW3s_lts
      else:
        var tmp = dW3sN[layer - 1, _, _]
        tmp += dW3s_lts.unsqueeze(0)
      var tmp = dU3s[layer, _, _]
      tmp += dU3s_lts.unsqueeze(0)
      tmp = dbW3s[layer, _, _]
      tmp +.= dbW3s_lts.unsqueeze(0)
      tmp = dbU3s[layer, _, _]
      tmp +.= dbU3s_lts.unsqueeze(0)