Skip to content

Improvements thanks to #1 feedback #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 54 additions & 19 deletions 01_XOR/neuralnetwork.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
// Data helper classes

// class ml5.Data

// ml5.Data?
class Data {
// Need to deal with shape
constructor(data) {
this.xs = tf.tensor2d(data.inputs);
this.ys = tf.tensor2d(data.targets);
}
}


// Helper class for a Batch of Data: ml5.Batch
class Batch {
constructor() {
// Need to deal with shape
// this.shape = ??;
this.data = [];
}
Expand All @@ -12,7 +28,7 @@ class Batch {

class NeuralNetwork {

constructor(inputs, hidden, outputs) {
constructor(inputs, hidden, outputs, lr) {
this.model = tf.sequential();
const hiddenLayer = tf.layers.dense({
units: hidden,
Expand All @@ -21,13 +37,14 @@ class NeuralNetwork {
});
const outputLayer = tf.layers.dense({
units: outputs,
inputShape: [hidden],
// inferred
// inputShape: [hidden],
activation: 'sigmoid'
});
this.model.add(hiddenLayer);
this.model.add(outputLayer);

const LEARNING_RATE = 0.5;
const LEARNING_RATE = lr || 0.5;
const optimizer = tf.train.sgd(LEARNING_RATE);

this.model.compile({
Expand All @@ -38,28 +55,46 @@ class NeuralNetwork {
}

predict(inputs) {
if (inputs instanceof Batch) {
return tf.tidy(() => {
const xs = tf.tensor2d(inputs.data);
return this.model.predict(xs).dataSync();
});
return tf.tidy(() => {
let data;
if (inputs instanceof Batch) {
data = inputs.data;
} else {
data = [inputs];
}
const xs = tf.tensor2d(data);
return this.model.predict(xs).dataSync();
});
}

setTrainingData(data) {
if (data instanceof Data) {
this.trainingData = data;
} else {
return tf.tidy(() => {
const xs = tf.tensor2d([inputs]);
return this.model.predict(xs).dataSync();
});
this.trainingData = new Data(data);
}
}

async train(data, epochs, callback) {
const xs = tf.tensor2d(data.inputs);
const ys = tf.tensor2d(data.targets);
async train(callback, epochs, data) {
let xs, ys;
if (data) {
xs = tf.tensor2d(data.inputs);
ys = tf.tensor2d(data.targets);
} else if (this.trainingData) {
xs = this.trainingData.xs;
ys = this.trainingData.ys;
} else {
console.log("I have no data!");
return;
}
await this.model.fit(xs, ys, {
epochs: epochs,
epochs: epochs || 1,
shuffle: true
});
xs.dispose();
ys.dispose();
if (data) {
xs.dispose();
ys.dispose();
}
callback();
}
}
}
7 changes: 4 additions & 3 deletions 01_XOR/sketch.js
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,21 @@ let counter = 0;
let training = true;

function train() {
nn.train(data, 10, finished);
nn.train(finished);
}

function finished() {
counter++;
statusP.html('training pass: ' + counter + '<br>framerate: ' + floor(frameRate()));
setTimeout(train, 10);
train();
}

let statusP;

function setup() {
createCanvas(400, 400);
nn = new NeuralNetwork(2, 2, 1);
nn.setTrainingData(data);
train();
statusP = createP('0');
}
Expand Down Expand Up @@ -66,4 +67,4 @@ function draw() {
}
}
// }
}
}