diff --git a/lib/nn.js b/lib/nn.js index c6e3ad2..5cce746 100644 --- a/lib/nn.js +++ b/lib/nn.js @@ -1,13 +1,21 @@ // Other techniques for learning -function sigmoid(x) { - return 1 / (1 + Math.exp(-x)); +class ActivationFunction{ + constructor(func, dfunc){ + this.func = func; + this.dfunc = dfunc; + } } -function dsigmoid(y) { - // return sigmoid(x) * (1 - sigmoid(x)); - return y * (1 - y); -} +let sigmoid = new ActivationFunction( + x => 1 / (1 + Math.exp(-x)), + y => y * (1- y) +); + +let tanh = new ActivationFunction( + x => Math.tanh(x), + y => 1-(y*y) +); class NeuralNetwork { @@ -26,10 +34,9 @@ class NeuralNetwork { this.bias_h.randomize(); this.bias_o.randomize(); this.setLearningRate(); - + this.setActivationFunction(); - this.setDActivationFunction(); - + } predict(input_array) { @@ -39,41 +46,37 @@ class NeuralNetwork { let hidden = Matrix.multiply(this.weights_ih, inputs); hidden.add(this.bias_h); // activation function! - hidden.map(this.activation_function); + hidden.map(this.activation_function.func); // Generating the output's output! let output = Matrix.multiply(this.weights_ho, hidden); output.add(this.bias_o); - output.map(this.activation_function); + output.map(this.activation_function.func); // Sending back to the caller! return output.toArray(); } - + setLearningRate(learning_rate = 0.1) { - this.learning_rate = learning_rate; + this.learning_rate = learning_rate; } - - setActivationFunction(Fun = sigmoid) { - this.activation_function = Fun; - } - - setDActivationFunction(dFun = dsigmoid) { - this.d_activation_function = dFun; + + setActivationFunction(func = sigmoid) { + this.activation_function = func; } - train(input_array, target_array) { + train(input_array, target_array) { // Generating the Hidden Outputs let inputs = Matrix.fromArray(input_array); let hidden = Matrix.multiply(this.weights_ih, inputs); hidden.add(this.bias_h); // activation function! - hidden.map(this.activation_function); + hidden.map(this.activation_function.func); // Generating the output's output! let outputs = Matrix.multiply(this.weights_ho, hidden); outputs.add(this.bias_o); - outputs.map(this.activation_function); + outputs.map(this.activation_function.func); // Convert array to matrix object let targets = Matrix.fromArray(target_array); @@ -84,7 +87,7 @@ class NeuralNetwork { // let gradient = outputs * (1 - outputs); // Calculate gradient - let gradients = Matrix.map(outputs, this.d_activation_function); + let gradients = Matrix.map(outputs, this.activation_function.dfunc); gradients.multiply(output_errors); gradients.multiply(this.learning_rate); @@ -103,7 +106,7 @@ class NeuralNetwork { let hidden_errors = Matrix.multiply(who_t, output_errors); // Calculate hidden gradient - let hidden_gradient = Matrix.map(hidden, this.d_activation_function); + let hidden_gradient = Matrix.map(hidden, this.activation_function.dfunc); hidden_gradient.multiply(hidden_errors); hidden_gradient.multiply(this.learning_rate);