forked from BrainJS/brain.js
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path lstm.js
132 lines (121 loc) · 3.27 KB
/
lstm.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import Matrix from './matrix';
import RandomMatrix from './matrix/random-matrix';
import RNN from './rnn';
export default class LSTM extends RNN {
  /**
   * Builds the parameter set for one LSTM hidden layer.
   *
   * Each of the four units (input gate, forget gate, output gate, cell write)
   * gets a weight applied to the layer input (`*Matrix`), a weight applied to
   * the previous result (`*Hidden`), and a bias (`*Bias`). Property names and
   * insertion order are unchanged because getEquation() reads them by name
   * and other code may iterate over them.
   *
   * @param {number} hiddenSize - first dimension of every parameter matrix.
   * @param {number} prevSize - second dimension of the input-facing `*Matrix` weights.
   * @returns {Object<string, Matrix>} the twelve parameter matrices, keyed by name.
   */
  getModel(hiddenSize, prevSize) {
    // Third argument handed to every RandomMatrix; presumably the scale of the
    // random initialisation — confirm against RandomMatrix.
    const randomScale = 0.08;
    // Presumably (rows, columns) — confirm against Matrix.
    const inputWeights = () => new RandomMatrix(hiddenSize, prevSize, randomScale);
    const hiddenWeights = () => new RandomMatrix(hiddenSize, hiddenSize, randomScale);
    // Initial contents depend on the Matrix constructor.
    const bias = () => new Matrix(hiddenSize, 1);

    return {
      // input gate (wix, wih, bi)
      inputMatrix: inputWeights(),
      inputHidden: hiddenWeights(),
      inputBias: bias(),
      // forget gate (wfx, wfh, bf)
      forgetMatrix: inputWeights(),
      forgetHidden: hiddenWeights(),
      forgetBias: bias(),
      // output gate (wox, woh, bo)
      outputMatrix: inputWeights(),
      outputHidden: hiddenWeights(),
      outputBias: bias(),
      // cell write (wcx, wch, bc)
      cellActivationMatrix: inputWeights(),
      cellActivationHidden: hiddenWeights(),
      cellActivationBias: bias(),
    };
  }

  /**
   * Builds one LSTM time step on `equation` and returns the new layer output.
   *
   * For each unit u in {input, forget, output, cellActivation}:
   *   pre_u = multiply(u.Matrix, inputMatrix) + multiply(u.Hidden, previousResult) + u.Bias
   * The gates are sigmoid(pre_u) and the cell candidate is tanh(pre_cellActivation). Then:
   *   cell   = forgetGate ⊙ previousResult + inputGate ⊙ candidate
   *   output = outputGate ⊙ tanh(cell)
   * where ⊙ is `equation.multiplyElement`.
   *
   * The `equation` methods are called in the same order as in the original
   * hand-expanded version: gate by gate, then the cell update. Keep that
   * order if editing, since `equation` may record the calls.
   *
   * NOTE(review): `previousResult` serves both as the previous hidden state
   * (in every gate) and as the previous cell state (in `retainCell`). A
   * textbook LSTM carries a separate cell state c_{t-1} there. Changing that
   * needs cell-state plumbing in the caller and would alter trained models.
   * Confirm whether this is intentional.
   *
   * @param {Equation} equation - supplies sigmoid/tanh/add/multiply/multiplyElement.
   * @param {Matrix} inputMatrix - input for this step.
   * @param {Matrix} previousResult - result of the previous step.
   * @param {Object} hiddenLayer - parameter set as built by getModel().
   * @returns {Matrix} outputGate ⊙ tanh(cell).
   */
  getEquation(equation, inputMatrix, previousResult, hiddenLayer) {
    const sigmoid = equation.sigmoid.bind(equation);
    const add = equation.add.bind(equation);
    const multiply = equation.multiply.bind(equation);
    const multiplyElement = equation.multiplyElement.bind(equation);
    const tanh = equation.tanh.bind(equation);

    // Shared affine pre-activation; argument evaluation order matches the
    // original expansion (input product, hidden product, inner add, bias add).
    const preActivation = (inputWeights, hiddenWeights, bias) =>
      add(
        add(
          multiply(inputWeights, inputMatrix),
          multiply(hiddenWeights, previousResult)
        ),
        bias
      );

    const inputGate = sigmoid(
      preActivation(hiddenLayer.inputMatrix, hiddenLayer.inputHidden, hiddenLayer.inputBias)
    );
    const forgetGate = sigmoid(
      preActivation(hiddenLayer.forgetMatrix, hiddenLayer.forgetHidden, hiddenLayer.forgetBias)
    );
    const outputGate = sigmoid(
      preActivation(hiddenLayer.outputMatrix, hiddenLayer.outputHidden, hiddenLayer.outputBias)
    );
    // candidate values to write into the cell
    const cellWrite = tanh(
      preActivation(
        hiddenLayer.cellActivationMatrix,
        hiddenLayer.cellActivationHidden,
        hiddenLayer.cellActivationBias
      )
    );

    // new cell contents = what is kept + what is written
    const retainCell = multiplyElement(forgetGate, previousResult);
    const writeCell = multiplyElement(inputGate, cellWrite);
    const cell = add(retainCell, writeCell);

    // layer output: output gate applied to the tanh-squashed cell
    return multiplyElement(outputGate, tanh(cell));
  }
}