diff --git a/deploy.sh b/deploy.sh
index b96f8f19e..5ffb8766f 100755
--- a/deploy.sh
+++ b/deploy.sh
@@ -1,5 +1,5 @@
#!/usr/bin/env bash
-# Copyright 2018 Google Inc. All Rights Reserved.
+# Copyright 2018 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/iris/build-resources.sh b/iris/build-resources.sh
new file mode 100755
index 000000000..708383b23
--- /dev/null
+++ b/iris/build-resources.sh
@@ -0,0 +1,61 @@
+#!/usr/bin/env bash
+
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+# Builds resources for the Iris demo.
+# Note this is not necessary to run the demo, because we already provide hosted
+# pre-built resources.
+# Usage example: do this from the 'iris' directory:
+# ./build-resources.sh
+
+set -e
+
+DEMO_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+TRAIN_EPOCHS=100
+while true; do
+ if [[ "$1" == "--epochs" ]]; then
+ TRAIN_EPOCHS=$2
+ shift 2
+ elif [[ -z "$1" ]]; then
+ break
+ else
+ echo "ERROR: Unrecognized argument: $1"
+ exit 1
+ fi
+done
+
+RESOURCES_ROOT="${DEMO_DIR}/dist/resources"
+rm -rf "${RESOURCES_ROOT}"
+mkdir -p "${RESOURCES_ROOT}"
+
+# Run Python script to generate the pretrained model and weights files.
+# Make sure you install the tensorflowjs pip package first.
+
+python "${DEMO_DIR}/python/iris.py" \
+ --epochs "${TRAIN_EPOCHS}" \
+ --artifacts_dir "${RESOURCES_ROOT}"
+
+cd ${DEMO_DIR}
+yarn
+yarn build
+
+echo
+echo "-----------------------------------------------------------"
+echo "Resources written to ${RESOURCES_ROOT}."
+echo "You can now run the demo with 'yarn watch'."
+echo "-----------------------------------------------------------"
+echo
diff --git a/iris/index.html b/iris/index.html
index 47ab5a3af..338b9f745 100644
--- a/iris/index.html
+++ b/iris/index.html
@@ -117,7 +117,8 @@
TensorFlow.js Layers: Iris Demo
-
+
+
diff --git a/iris/index.js b/iris/index.js
index 9ded1c69e..ba9750562 100644
--- a/iris/index.js
+++ b/iris/index.js
@@ -17,28 +17,12 @@
import * as tf from '@tensorflow/tfjs';
-import {getIrisData, IRIS_CLASSES, IRIS_NUM_CLASSES} from './data';
-import {clearEvaluateTable, getManualInputData, loadTrainParametersFromUI, plotAccuracies, plotLosses, renderEvaluateTable, renderLogitsForManualInput, setManualInputWinnerMessage, status, wireUpEvaluateTableCallbacks} from './ui';
+import * as data from './data';
+import * as loader from './loader';
+import * as ui from './ui';
let model;
-/**
- * Load pretrained model stored at a remote URL.
- *
- * @return An instance of `tf.Model` with model topology and weights loaded.
- */
-async function loadHostedPretrainedModel() {
- const HOSTED_MODEL_JSON_URL =
- 'https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json';
- status('Loading pretrained model from ' + HOSTED_MODEL_JSON_URL);
- try {
- model = await tf.loadModel(HOSTED_MODEL_JSON_URL);
- status('Done loading pretrained model.');
- } catch (err) {
- status('Loading pretrained model failed.');
- }
-}
-
/**
* Train a `tf.Model` to recognize Iris flower type.
*
@@ -53,9 +37,9 @@ async function loadHostedPretrainedModel() {
* @returns The trained `tf.Model` instance.
*/
async function trainModel(xTrain, yTrain, xTest, yTest) {
- status('Training model... Please wait.');
+ ui.status('Training model... Please wait.');
- const params = loadTrainParametersFromUI();
+ const params = ui.loadTrainParametersFromUI();
// Define the topology of the model: two dense layers.
const model = tf.sequential();
@@ -79,8 +63,8 @@ async function trainModel(xTrain, yTrain, xTest, yTest) {
callbacks: {
onEpochEnd: async (epoch, logs) => {
// Plot the loss and accuracy values at the end of every training epoch.
- plotLosses(lossValues, epoch, logs.loss, logs.val_loss);
- plotAccuracies(accuracyValues, epoch, logs.acc, logs.val_acc);
+ ui.plotLosses(lossValues, epoch, logs.loss, logs.val_loss);
+ ui.plotAccuracies(accuracyValues, epoch, logs.acc, logs.val_acc);
// Await web page DOM to refresh for the most recently plotted values.
await tf.nextFrame();
@@ -88,7 +72,7 @@ async function trainModel(xTrain, yTrain, xTest, yTest) {
}
});
- status('Model training complete.');
+ ui.status('Model training complete.');
return model;
}
@@ -99,7 +83,7 @@ async function trainModel(xTrain, yTrain, xTest, yTest) {
*/
async function predictOnManualInput(model) {
if (model == null) {
- setManualInputWinnerMessage('ERROR: Please load or train model first.');
+ ui.setManualInputWinnerMessage('ERROR: Please load or train model first.');
return;
}
@@ -107,7 +91,7 @@ async function predictOnManualInput(model) {
// `predict` call is released at the end.
tf.tidy(() => {
// Prepare input data as a 2D `tf.Tensor`.
- const inputData = getManualInputData();
+ const inputData = ui.getManualInputData();
const input = tf.tensor2d([inputData], [1, 4]);
// Call `model.predict` to get the prediction output as probabilities for
@@ -115,9 +99,9 @@ async function predictOnManualInput(model) {
const predictOut = model.predict(input);
const logits = Array.from(predictOut.dataSync());
- const winner = IRIS_CLASSES[predictOut.argMax(-1).dataSync()[0]];
- setManualInputWinnerMessage(winner);
- renderLogitsForManualInput(logits);
+ const winner = data.IRIS_CLASSES[predictOut.argMax(-1).dataSync()[0]];
+ ui.setManualInputWinnerMessage(winner);
+ ui.renderLogitsForManualInput(logits);
});
}
@@ -130,24 +114,29 @@ async function predictOnManualInput(model) {
* [numTestExamples, 3].
*/
async function evaluateModelOnTestData(model, xTest, yTest) {
- clearEvaluateTable();
+ ui.clearEvaluateTable();
tf.tidy(() => {
const xData = xTest.dataSync();
const yTrue = yTest.argMax(-1).dataSync();
const predictOut = model.predict(xTest);
const yPred = predictOut.argMax(-1);
- renderEvaluateTable(xData, yTrue, yPred.dataSync(), predictOut.dataSync());
+ ui.renderEvaluateTable(
+ xData, yTrue, yPred.dataSync(), predictOut.dataSync());
});
predictOnManualInput(model);
}
+const LOCAL_MODEL_JSON_URL = 'http://localhost:1235/resources/model.json';
+const HOSTED_MODEL_JSON_URL =
+ 'https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json';
+
/**
* The main function of the Iris demo.
*/
async function iris() {
- const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15);
+ const [xTrain, yTrain, xTest, yTest] = data.getIrisData(0.15);
document.getElementById('train-from-scratch')
.addEventListener('click', async () => {
@@ -155,14 +144,32 @@ async function iris() {
evaluateModelOnTestData(model, xTest, yTest);
});
- document.getElementById('load-pretrained')
- .addEventListener('click', async () => {
- clearEvaluateTable();
- await loadHostedPretrainedModel();
- predictOnManualInput(model);
- });
+ if (await loader.urlExists(HOSTED_MODEL_JSON_URL)) {
+ ui.status('Model available: ' + HOSTED_MODEL_JSON_URL);
+ const button = document.getElementById('load-pretrained-remote');
+ button.addEventListener('click', async () => {
+ ui.clearEvaluateTable();
+ model = await loader.loadHostedPretrainedModel(HOSTED_MODEL_JSON_URL);
+ predictOnManualInput(model);
+ });
+ // button.style.visibility = 'visible';
+ button.style.display = 'inline-block';
+ }
+
+ if (await loader.urlExists(LOCAL_MODEL_JSON_URL)) {
+ ui.status('Model available: ' + LOCAL_MODEL_JSON_URL);
+ const button = document.getElementById('load-pretrained-local');
+ button.addEventListener('click', async () => {
+ ui.clearEvaluateTable();
+ model = await loader.loadHostedPretrainedModel(LOCAL_MODEL_JSON_URL);
+ predictOnManualInput(model);
+ });
+ // button.style.visibility = 'visible';
+ button.style.display = 'inline-block';
+ }
- wireUpEvaluateTableCallbacks(() => predictOnManualInput(model));
+ ui.status('Standing by.');
+ ui.wireUpEvaluateTableCallbacks(() => predictOnManualInput(model));
}
iris();
diff --git a/iris/loader.js b/iris/loader.js
new file mode 100644
index 000000000..51296e4e9
--- /dev/null
+++ b/iris/loader.js
@@ -0,0 +1,53 @@
+/**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import * as tf from '@tensorflow/tfjs';
+import * as ui from './ui';
+
+/**
+ * Test whether a given URL is retrievable.
+ */
+export async function urlExists(url) {
+ ui.status('Testing url ' + url);
+ try {
+ const response = await fetch(url, {method: 'HEAD'});
+ return response.ok;
+ } catch (err) {
+ return false;
+ }
+}
+
+/**
+ * Load pretrained model stored at a remote URL.
+ *
+ * @return An instance of `tf.Model` with model topology and weights loaded.
+ */
+export async function loadHostedPretrainedModel(url) {
+ ui.status('Loading pretrained model from ' + url);
+ try {
+ const model = await tf.loadModel(url);
+ ui.status('Done loading pretrained model.');
+ // We can't load a model twice due to
+ // https://github.com/tensorflow/tfjs/issues/34
+ // Therefore we remove the load buttons to avoid user confusion.
+ ui.disableLoadModelButtons();
+ return model;
+ } catch (err) {
+ console.error(err);
+ ui.status('Loading pretrained model failed.');
+ }
+}
diff --git a/iris/package.json b/iris/package.json
index a0d3c4a43..855c3384f 100644
--- a/iris/package.json
+++ b/iris/package.json
@@ -13,15 +13,16 @@
"vega-embed": "^3.0.0"
},
"scripts": {
- "watch": "NODE_ENV=development parcel --no-hmr --open index.html ",
- "build": "NODE_ENV=production parcel build index.html --no-minify --public-url ./"
+ "watch": "./serve.sh",
+ "build": "NODE_ENV=production parcel build index.html --no-minify --public-url /"
},
"devDependencies": {
"babel-plugin-transform-runtime": "~6.23.0",
"babel-polyfill": "~6.26.0",
"babel-preset-env": "~1.6.1",
"clang-format": "~1.2.2",
- "parcel-bundler": "~1.6.2"
+ "http-server": "~0.10.0",
+ "parcel-bundler": "~1.7.0"
},
"babel": {
"presets": [
diff --git a/iris/python/__init__.py b/iris/python/__init__.py
new file mode 100644
index 000000000..636b70f0d
--- /dev/null
+++ b/iris/python/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/iris/python/iris.py b/iris/python/iris.py
new file mode 100644
index 000000000..a8f882349
--- /dev/null
+++ b/iris/python/iris.py
@@ -0,0 +1,102 @@
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Train a simple model for the Iris dataset; Export the model and weights."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+
+import keras
+import numpy as np
+import tensorflowjs as tfjs
+
+import iris_data
+
+
+def train(epochs,
+ artifacts_dir,
+ sequential=False):
+ """Train a Keras model for Iris data classification and save result as JSON.
+
+ Args:
+ epochs: Number of epochs to traing the Keras model for.
+ artifacts_dir: Directory to save the model artifacts (model topology JSON,
+ weights and weight manifest) in.
+ sequential: Whether to use a Keras Sequential model, instead of the default
+ functional model.
+
+ Returns:
+ Final classification accuracy on the training set.
+ """
+ data_x, data_y = iris_data.load()
+
+ if sequential:
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(
+ 10, input_shape=[data_x.shape[1]], use_bias=True, activation='sigmoid',
+ name='Dense1'))
+ model.add(keras.layers.Dense(
+ 3, use_bias=True, activation='softmax', name='Dense2'))
+ else:
+ iris_x = keras.layers.Input((4,))
+ dense1 = keras.layers.Dense(
+ 10, use_bias=True, name='Dense1', activation='sigmoid')(iris_x)
+ dense2 = keras.layers.Dense(
+ 3, use_bias=True, name='Dense2', activation='softmax')(dense1)
+ # pylint:disable=redefined-variable-type
+ model = keras.models.Model(inputs=[iris_x], outputs=[dense2])
+ # pylint:enable=redefined-variable-type
+ model.compile(loss='categorical_crossentropy', optimizer='adam')
+
+ model.fit(data_x, data_y, batch_size=8, epochs=epochs)
+
+ # Run prediction on the training set.
+ pred_ys = np.argmax(model.predict(data_x), axis=1)
+ true_ys = np.argmax(data_y, axis=1)
+ final_train_accuracy = np.mean((pred_ys == true_ys).astype(np.float32))
+ print('Accuracy on the training set: %g' % final_train_accuracy)
+
+ tfjs.converters.save_keras_model(model, artifacts_dir)
+
+ return final_train_accuracy
+
+
+def main():
+ train(FLAGS.epochs, FLAGS.artifacts_dir, sequential=FLAGS.sequential)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser('Iris model training and serialization')
+ parser.add_argument(
+ '--sequential',
+ action='store_true',
+ help='Use a Keras Sequential model, instead of the default functional '
+ 'model.')
+ parser.add_argument(
+ '--epochs',
+ type=int,
+ default=100,
+ help='Number of epochs to train the Keras model for.')
+ parser.add_argument(
+ '--artifacts_dir',
+ type=str,
+ default='/tmp/iris.keras',
+ help='Local path for saving the TensorFlow.js artifacts.')
+
+ FLAGS, _ = parser.parse_known_args()
+ main()
diff --git a/iris/python/iris_data.py b/iris/python/iris_data.py
new file mode 100644
index 000000000..c72e39f71
--- /dev/null
+++ b/iris/python/iris_data.py
@@ -0,0 +1,217 @@
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Iris dataset (see https://en.wikipedia.org/wiki/Iris_flower_data_set)."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+IRIS_CLASSES = ('setosa', 'versicolor', 'virginica')
+IRIS_DATA = [
+ '5.1,3.5,1.4,0.2,Iris-setosa',
+ '4.9,3.0,1.4,0.2,Iris-setosa',
+ '4.7,3.2,1.3,0.2,Iris-setosa',
+ '4.6,3.1,1.5,0.2,Iris-setosa',
+ '5.0,3.6,1.4,0.2,Iris-setosa',
+ '5.4,3.9,1.7,0.4,Iris-setosa',
+ '4.6,3.4,1.4,0.3,Iris-setosa',
+ '5.0,3.4,1.5,0.2,Iris-setosa',
+ '4.4,2.9,1.4,0.2,Iris-setosa',
+ '4.9,3.1,1.5,0.1,Iris-setosa',
+ '5.4,3.7,1.5,0.2,Iris-setosa',
+ '4.8,3.4,1.6,0.2,Iris-setosa',
+ '4.8,3.0,1.4,0.1,Iris-setosa',
+ '4.3,3.0,1.1,0.1,Iris-setosa',
+ '5.8,4.0,1.2,0.2,Iris-setosa',
+ '5.7,4.4,1.5,0.4,Iris-setosa',
+ '5.4,3.9,1.3,0.4,Iris-setosa',
+ '5.1,3.5,1.4,0.3,Iris-setosa',
+ '5.7,3.8,1.7,0.3,Iris-setosa',
+ '5.1,3.8,1.5,0.3,Iris-setosa',
+ '5.4,3.4,1.7,0.2,Iris-setosa',
+ '5.1,3.7,1.5,0.4,Iris-setosa',
+ '4.6,3.6,1.0,0.2,Iris-setosa',
+ '5.1,3.3,1.7,0.5,Iris-setosa',
+ '4.8,3.4,1.9,0.2,Iris-setosa',
+ '5.0,3.0,1.6,0.2,Iris-setosa',
+ '5.0,3.4,1.6,0.4,Iris-setosa',
+ '5.2,3.5,1.5,0.2,Iris-setosa',
+ '5.2,3.4,1.4,0.2,Iris-setosa',
+ '4.7,3.2,1.6,0.2,Iris-setosa',
+ '4.8,3.1,1.6,0.2,Iris-setosa',
+ '5.4,3.4,1.5,0.4,Iris-setosa',
+ '5.2,4.1,1.5,0.1,Iris-setosa',
+ '5.5,4.2,1.4,0.2,Iris-setosa',
+ '4.9,3.1,1.5,0.1,Iris-setosa',
+ '5.0,3.2,1.2,0.2,Iris-setosa',
+ '5.5,3.5,1.3,0.2,Iris-setosa',
+ '4.9,3.1,1.5,0.1,Iris-setosa',
+ '4.4,3.0,1.3,0.2,Iris-setosa',
+ '5.1,3.4,1.5,0.2,Iris-setosa',
+ '5.0,3.5,1.3,0.3,Iris-setosa',
+ '4.5,2.3,1.3,0.3,Iris-setosa',
+ '4.4,3.2,1.3,0.2,Iris-setosa',
+ '5.0,3.5,1.6,0.6,Iris-setosa',
+ '5.1,3.8,1.9,0.4,Iris-setosa',
+ '4.8,3.0,1.4,0.3,Iris-setosa',
+ '5.1,3.8,1.6,0.2,Iris-setosa',
+ '4.6,3.2,1.4,0.2,Iris-setosa',
+ '5.3,3.7,1.5,0.2,Iris-setosa',
+ '5.0,3.3,1.4,0.2,Iris-setosa',
+ '7.0,3.2,4.7,1.4,Iris-versicolor',
+ '6.4,3.2,4.5,1.5,Iris-versicolor',
+ '6.9,3.1,4.9,1.5,Iris-versicolor',
+ '5.5,2.3,4.0,1.3,Iris-versicolor',
+ '6.5,2.8,4.6,1.5,Iris-versicolor',
+ '5.7,2.8,4.5,1.3,Iris-versicolor',
+ '6.3,3.3,4.7,1.6,Iris-versicolor',
+ '4.9,2.4,3.3,1.0,Iris-versicolor',
+ '6.6,2.9,4.6,1.3,Iris-versicolor',
+ '5.2,2.7,3.9,1.4,Iris-versicolor',
+ '5.0,2.0,3.5,1.0,Iris-versicolor',
+ '5.9,3.0,4.2,1.5,Iris-versicolor',
+ '6.0,2.2,4.0,1.0,Iris-versicolor',
+ '6.1,2.9,4.7,1.4,Iris-versicolor',
+ '5.6,2.9,3.6,1.3,Iris-versicolor',
+ '6.7,3.1,4.4,1.4,Iris-versicolor',
+ '5.6,3.0,4.5,1.5,Iris-versicolor',
+ '5.8,2.7,4.1,1.0,Iris-versicolor',
+ '6.2,2.2,4.5,1.5,Iris-versicolor',
+ '5.6,2.5,3.9,1.1,Iris-versicolor',
+ '5.9,3.2,4.8,1.8,Iris-versicolor',
+ '6.1,2.8,4.0,1.3,Iris-versicolor',
+ '6.3,2.5,4.9,1.5,Iris-versicolor',
+ '6.1,2.8,4.7,1.2,Iris-versicolor',
+ '6.4,2.9,4.3,1.3,Iris-versicolor',
+ '6.6,3.0,4.4,1.4,Iris-versicolor',
+ '6.8,2.8,4.8,1.4,Iris-versicolor',
+ '6.7,3.0,5.0,1.7,Iris-versicolor',
+ '6.0,2.9,4.5,1.5,Iris-versicolor',
+ '5.7,2.6,3.5,1.0,Iris-versicolor',
+ '5.5,2.4,3.8,1.1,Iris-versicolor',
+ '5.5,2.4,3.7,1.0,Iris-versicolor',
+ '5.8,2.7,3.9,1.2,Iris-versicolor',
+ '6.0,2.7,5.1,1.6,Iris-versicolor',
+ '5.4,3.0,4.5,1.5,Iris-versicolor',
+ '6.0,3.4,4.5,1.6,Iris-versicolor',
+ '6.7,3.1,4.7,1.5,Iris-versicolor',
+ '6.3,2.3,4.4,1.3,Iris-versicolor',
+ '5.6,3.0,4.1,1.3,Iris-versicolor',
+ '5.5,2.5,4.0,1.3,Iris-versicolor',
+ '5.5,2.6,4.4,1.2,Iris-versicolor',
+ '6.1,3.0,4.6,1.4,Iris-versicolor',
+ '5.8,2.6,4.0,1.2,Iris-versicolor',
+ '5.0,2.3,3.3,1.0,Iris-versicolor',
+ '5.6,2.7,4.2,1.3,Iris-versicolor',
+ '5.7,3.0,4.2,1.2,Iris-versicolor',
+ '5.7,2.9,4.2,1.3,Iris-versicolor',
+ '6.2,2.9,4.3,1.3,Iris-versicolor',
+ '5.1,2.5,3.0,1.1,Iris-versicolor',
+ '5.7,2.8,4.1,1.3,Iris-versicolor',
+ '6.3,3.3,6.0,2.5,Iris-virginica',
+ '5.8,2.7,5.1,1.9,Iris-virginica',
+ '7.1,3.0,5.9,2.1,Iris-virginica',
+ '6.3,2.9,5.6,1.8,Iris-virginica',
+ '6.5,3.0,5.8,2.2,Iris-virginica',
+ '7.6,3.0,6.6,2.1,Iris-virginica',
+ '4.9,2.5,4.5,1.7,Iris-virginica',
+ '7.3,2.9,6.3,1.8,Iris-virginica',
+ '6.7,2.5,5.8,1.8,Iris-virginica',
+ '7.2,3.6,6.1,2.5,Iris-virginica',
+ '6.5,3.2,5.1,2.0,Iris-virginica',
+ '6.4,2.7,5.3,1.9,Iris-virginica',
+ '6.8,3.0,5.5,2.1,Iris-virginica',
+ '5.7,2.5,5.0,2.0,Iris-virginica',
+ '5.8,2.8,5.1,2.4,Iris-virginica',
+ '6.4,3.2,5.3,2.3,Iris-virginica',
+ '6.5,3.0,5.5,1.8,Iris-virginica',
+ '7.7,3.8,6.7,2.2,Iris-virginica',
+ '7.7,2.6,6.9,2.3,Iris-virginica',
+ '6.0,2.2,5.0,1.5,Iris-virginica',
+ '6.9,3.2,5.7,2.3,Iris-virginica',
+ '5.6,2.8,4.9,2.0,Iris-virginica',
+ '7.7,2.8,6.7,2.0,Iris-virginica',
+ '6.3,2.7,4.9,1.8,Iris-virginica',
+ '6.7,3.3,5.7,2.1,Iris-virginica',
+ '7.2,3.2,6.0,1.8,Iris-virginica',
+ '6.2,2.8,4.8,1.8,Iris-virginica',
+ '6.1,3.0,4.9,1.8,Iris-virginica',
+ '6.4,2.8,5.6,2.1,Iris-virginica',
+ '7.2,3.0,5.8,1.6,Iris-virginica',
+ '7.4,2.8,6.1,1.9,Iris-virginica',
+ '7.9,3.8,6.4,2.0,Iris-virginica',
+ '6.4,2.8,5.6,2.2,Iris-virginica',
+ '6.3,2.8,5.1,1.5,Iris-virginica',
+ '6.1,2.6,5.6,1.4,Iris-virginica',
+ '7.7,3.0,6.1,2.3,Iris-virginica',
+ '6.3,3.4,5.6,2.4,Iris-virginica',
+ '6.4,3.1,5.5,1.8,Iris-virginica',
+ '6.0,3.0,4.8,1.8,Iris-virginica',
+ '6.9,3.1,5.4,2.1,Iris-virginica',
+ '6.7,3.1,5.6,2.4,Iris-virginica',
+ '6.9,3.1,5.1,2.3,Iris-virginica',
+ '5.8,2.7,5.1,1.9,Iris-virginica',
+ '6.8,3.2,5.9,2.3,Iris-virginica',
+ '6.7,3.3,5.7,2.5,Iris-virginica',
+ '6.7,3.0,5.2,2.3,Iris-virginica',
+ '6.3,2.5,5.0,1.9,Iris-virginica',
+ '6.5,3.0,5.2,2.0,Iris-virginica',
+ '6.2,3.4,5.4,2.3,Iris-virginica',
+ '5.9,3.0,5.1,1.8,Iris-virginica']
+
+
+def load():
+ """Load Iris data.
+
+ Returns:
+ Iris data as a numpy ndarray of size [n, 4] and dtype `float32`, n being the
+ number of available samples.
+ Iris classification target as a numpy ndarray of [n, 3] and dtype `float32`.
+ The order of the data is randomly shuffled.
+ """
+ iris_x = []
+ iris_y = []
+ for line in IRIS_DATA:
+ items = line.split(',')
+ xs = [float(x) for x in items[:4]]
+ iris_x.append(xs)
+ assert items[-1].startswith('Iris-')
+ iris_y.append(IRIS_CLASSES.index(items[-1].replace('Iris-', '')))
+
+ # Randomly shuffle the data.
+ iris_xy = list(zip(iris_x, iris_y))
+ np.random.shuffle(iris_xy)
+ iris_x, iris_y = zip(*iris_xy)
+ return (np.array(iris_x, dtype=np.float32),
+ _to_one_hot(iris_y, len(IRIS_CLASSES)))
+
+
+def _to_one_hot(indices, num_classes):
+ """Convert indices to one-hot encoding.
+
+ Args:
+ indices: A list of `int` indices with length `n`, each eleemnt of which is
+ assumed to be a zero-based class index >= 0 and < `num_classes`.
+ num_classes: Total number of possible classes as an `int`.
+
+ Returns:
+ A numpy ndarray of shape [n, num_classes] and dtype `float32`.
+ """
+ one_hot = np.zeros([len(indices), num_classes], dtype=np.float32)
+ one_hot[np.arange(len(indices)), indices] = 1
+ return one_hot
diff --git a/iris/python/iris_data_test.py b/iris/python/iris_data_test.py
new file mode 100644
index 000000000..613c14332
--- /dev/null
+++ b/iris/python/iris_data_test.py
@@ -0,0 +1,42 @@
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Test for the Iris dataset module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+import numpy as np
+
+import iris_data
+
+class IrisDataTest(unittest.TestCase):
+
+ def testLoadData(self):
+ iris_x, iris_y = iris_data.load()
+ self.assertEqual(2, len(iris_x.shape))
+ self.assertGreater(iris_x.shape[0], 0)
+ self.assertEqual(4, iris_x.shape[1])
+ self.assertEqual(iris_x.shape[0], iris_y.shape[0])
+ self.assertEqual(3, iris_y.shape[1])
+ self.assertTrue(
+ np.allclose(np.ones([iris_y.shape[0], 1]), np.sum(iris_y, axis=1)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/iris/python/iris_test.py b/iris/python/iris_test.py
new file mode 100644
index 000000000..3fed96d87
--- /dev/null
+++ b/iris/python/iris_test.py
@@ -0,0 +1,58 @@
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Test for the Iris model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+import shutil
+import tempfile
+import unittest
+
+import iris
+
+
+class IrisTest(unittest.TestCase):
+
+ def setUp(self):
+ self._tmp_dir = tempfile.mkdtemp()
+ super(IrisTest, self).setUp()
+
+ def tearDown(self):
+ if os.path.isdir(self._tmp_dir):
+ shutil.rmtree(self._tmp_dir)
+ super(IrisTest, self).tearDown()
+
+ def testTrainAndSaveNonSequential(self):
+ final_train_accuracy = iris.train(100, self._tmp_dir)
+ self.assertGreater(final_train_accuracy, 0.9)
+
+ # Check that the model json file is created.
+ json.load(open(os.path.join(self._tmp_dir, 'model.json'), 'rt'))
+
+ def testTrainAndSaveSequential(self):
+ final_train_accuracy = iris.train(100, self._tmp_dir, sequential=True)
+ self.assertGreater(final_train_accuracy, 0.9)
+
+ # Check that the model json file is created.
+ json.load(open(os.path.join(self._tmp_dir, 'model.json'), 'rt'))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/iris/serve.sh b/iris/serve.sh
new file mode 100755
index 000000000..e5ede63ed
--- /dev/null
+++ b/iris/serve.sh
@@ -0,0 +1,42 @@
+#!/usr/bin/env bash
+
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+# This script starts two HTTP servers on different ports:
+# * Port 1234 (using parcel) serves HTML and JavaScript.
+# * Port 1235 (using http-server) serves pretrained model resources.
+#
+# The reason for this arrangement is that Parcel currently has a limitation that
+# prevents it from serving the pretrained models; see
+# https://github.com/parcel-bundler/parcel/issues/1098. Once that issue is
+# resolved, a single Parcel server will be sufficient.
+
+NODE_ENV=development
+RESOURCE_PORT=1235
+
+# Ensure that http-server is available
+yarn
+
+echo Starting the pretrained model server...
+node_modules/http-server/bin/http-server dist --cors -p "${RESOURCE_PORT}" > /dev/null & HTTP_SERVER_PID=$!
+
+echo Starting the example html/js server...
+# This uses port 1234 by default.
+node_modules/parcel-bundler/bin/cli.js serve -d dist --open --no-hmr --public-url / index.html
+
+# When the Parcel server exits, kill the http-server too.
+kill $HTTP_SERVER_PID
+
diff --git a/iris/ui.js b/iris/ui.js
index ed080104e..6c62521e6 100644
--- a/iris/ui.js
+++ b/iris/ui.js
@@ -227,5 +227,11 @@ export function loadTrainParametersFromUI() {
}
export function status(statusText) {
+ console.log(statusText);
document.getElementById('demo-status').textContent = statusText;
}
+
+export function disableLoadModelButtons() {
+ document.getElementById('load-pretrained-remote').style.display = 'none';
+ document.getElementById('load-pretrained-local').style.display = 'none';
+}
diff --git a/iris/yarn.lock b/iris/yarn.lock
index d5fdf748a..024624301 100644
--- a/iris/yarn.lock
+++ b/iris/yarn.lock
@@ -705,6 +705,10 @@ binary-extensions@^1.0.0:
version "1.11.0"
resolved "https://registry.yarnpkg.com/binary-extensions/-/binary-extensions-1.11.0.tgz#46aa1751fb6a2f93ee5e689bb1087d4b14c6c205"
+bindings@~1.2.1:
+ version "1.2.1"
+ resolved "https://registry.yarnpkg.com/bindings/-/bindings-1.2.1.tgz#14ad6113812d2d37d72e67b4cacb4bb726505f11"
+
block-stream@*:
version "0.0.9"
resolved "https://registry.yarnpkg.com/block-stream/-/block-stream-0.0.9.tgz#13ebfe778a03205cfe03751481ebb4b3300c126a"
@@ -766,12 +770,6 @@ brorand@^1.0.1:
version "1.1.0"
resolved "https://registry.yarnpkg.com/brorand/-/brorand-1.1.0.tgz#12c25efe40a45e3c323eb8675a0a0ce57b22371f"
-browser-resolve@^1.11.2:
- version "1.11.2"
- resolved "https://registry.yarnpkg.com/browser-resolve/-/browser-resolve-1.11.2.tgz#8ff09b0a2c421718a1051c260b32e48f442938ce"
- dependencies:
- resolve "1.1.7"
-
browserify-aes@^1.0.0, browserify-aes@^1.0.4:
version "1.1.1"
resolved "https://registry.yarnpkg.com/browserify-aes/-/browserify-aes-1.1.1.tgz#38b7ab55edb806ff2dcda1a7f1620773a477c49f"
@@ -907,7 +905,7 @@ chalk@^1.1.3:
strip-ansi "^3.0.0"
supports-color "^2.0.0"
-chalk@^2.1.0, chalk@^2.3.1:
+chalk@^2.1.0, chalk@^2.3.2:
version "2.3.2"
resolved "https://registry.yarnpkg.com/chalk/-/chalk-2.3.2.tgz#250dc96b07491bfd601e648d66ddf5f60c7a5c65"
dependencies:
@@ -1038,6 +1036,10 @@ colormin@^1.0.5:
css-color-names "0.0.4"
has "^1.0.1"
+colors@1.0.3:
+ version "1.0.3"
+ resolved "https://registry.yarnpkg.com/colors/-/colors-1.0.3.tgz#0433f44d809680fdeb60ed260f1b0c262e82a40b"
+
colors@~1.1.2:
version "1.1.2"
resolved "https://registry.yarnpkg.com/colors/-/colors-1.1.2.tgz#168a4701756b6a7f51a12ce0c97bfa28c084ed63"
@@ -1117,6 +1119,10 @@ core-util-is@1.0.2, core-util-is@~1.0.0:
version "1.0.2"
resolved "https://registry.yarnpkg.com/core-util-is/-/core-util-is-1.0.2.tgz#b5fd54220aa2bc5ab57aab7140c940754503c1a7"
+corser@~2.0.0:
+ version "2.0.1"
+ resolved "https://registry.yarnpkg.com/corser/-/corser-2.0.1.tgz#8eda252ecaab5840dcd975ceb90d9370c819ff87"
+
create-ecdh@^4.0.0:
version "4.0.0"
resolved "https://registry.yarnpkg.com/create-ecdh/-/create-ecdh-4.0.0.tgz#888c723596cdf7612f6498233eebd7a35301737d"
@@ -1405,6 +1411,13 @@ date-now@^0.1.4:
version "0.1.4"
resolved "https://registry.yarnpkg.com/date-now/-/date-now-0.1.4.tgz#eaf439fd4d4848ad74e5cc7dbef200672b9e345b"
+deasync@^0.1.12:
+ version "0.1.12"
+ resolved "https://registry.yarnpkg.com/deasync/-/deasync-0.1.12.tgz#0159492a4133ab301d6c778cf01e74e63b10e549"
+ dependencies:
+ bindings "~1.2.1"
+ nan "^2.0.7"
+
debug@2.6.9, debug@^2.2.0, debug@^2.3.3, debug@^2.6.8:
version "2.6.9"
resolved "https://registry.yarnpkg.com/debug/-/debug-2.6.9.tgz#5d128515df134ff327e90a4c93f4e077a536341f"
@@ -1557,6 +1570,15 @@ ecc-jsbn@~0.1.1:
dependencies:
jsbn "~0.1.0"
+ecstatic@^2.0.0:
+ version "2.2.1"
+ resolved "https://registry.yarnpkg.com/ecstatic/-/ecstatic-2.2.1.tgz#b5087fad439dd9dd49d31e18131454817fe87769"
+ dependencies:
+ he "^1.1.1"
+ mime "^1.2.11"
+ minimist "^1.1.0"
+ url-join "^2.0.2"
+
editorconfig@^0.13.2:
version "0.13.3"
resolved "https://registry.yarnpkg.com/editorconfig/-/editorconfig-0.13.3.tgz#e5219e587951d60958fd94ea9a9a008cdeff1b34"
@@ -1662,6 +1684,10 @@ etag@~1.8.1:
version "1.8.1"
resolved "https://registry.yarnpkg.com/etag/-/etag-1.8.1.tgz#41ae2eeb65efa62268aebfea83ac7d79299b0887"
+eventemitter3@1.x.x:
+ version "1.2.0"
+ resolved "https://registry.yarnpkg.com/eventemitter3/-/eventemitter3-1.2.0.tgz#1c86991d816ad1e504750e73874224ecf3bec508"
+
events@^1.0.0:
version "1.1.1"
resolved "https://registry.yarnpkg.com/events/-/events-1.1.1.tgz#9ebdb7635ad099c70dcc4c2a1f5004288e8bd924"
@@ -1992,6 +2018,10 @@ hawk@3.1.3, hawk@~3.1.3:
hoek "2.x.x"
sntp "1.x.x"
+he@^1.1.1:
+ version "1.1.1"
+ resolved "https://registry.yarnpkg.com/he/-/he-1.1.1.tgz#93410fd21b009735151f8868c2f271f3427e23fd"
+
hmac-drbg@^1.0.0:
version "1.0.1"
resolved "https://registry.yarnpkg.com/hmac-drbg/-/hmac-drbg-1.0.1.tgz#d2745701025a6c775a6c545793ed502fc0c649a1"
@@ -2015,7 +2045,7 @@ html-comment-regex@^1.1.0:
version "1.1.1"
resolved "https://registry.yarnpkg.com/html-comment-regex/-/html-comment-regex-1.1.1.tgz#668b93776eaae55ebde8f3ad464b307a4963625e"
-htmlnano@^0.1.6:
+htmlnano@^0.1.7:
version "0.1.7"
resolved "https://registry.yarnpkg.com/htmlnano/-/htmlnano-0.1.7.tgz#1751937a05f122a3248dba1c63edf01f6d36cb84"
dependencies:
@@ -2046,6 +2076,26 @@ http-errors@~1.6.2:
setprototypeof "1.0.3"
statuses ">= 1.3.1 < 2"
+http-proxy@^1.8.1:
+ version "1.16.2"
+ resolved "https://registry.yarnpkg.com/http-proxy/-/http-proxy-1.16.2.tgz#06dff292952bf64dbe8471fa9df73066d4f37742"
+ dependencies:
+ eventemitter3 "1.x.x"
+ requires-port "1.x.x"
+
+http-server@~0.10.0:
+ version "0.10.0"
+ resolved "https://registry.yarnpkg.com/http-server/-/http-server-0.10.0.tgz#b2a446b16a9db87ed3c622ba9beb1b085b1234a7"
+ dependencies:
+ colors "1.0.3"
+ corser "~2.0.0"
+ ecstatic "^2.0.0"
+ http-proxy "^1.8.1"
+ opener "~1.4.0"
+ optimist "0.6.x"
+ portfinder "^1.0.13"
+ union "~0.4.3"
+
http-signature@~1.1.0:
version "1.1.1"
resolved "https://registry.yarnpkg.com/http-signature/-/http-signature-1.1.1.tgz#df72e267066cd0ac67fb76adf8e134a8fbcf91bf"
@@ -2537,6 +2587,10 @@ mime@1.4.1:
version "1.4.1"
resolved "https://registry.yarnpkg.com/mime/-/mime-1.4.1.tgz#121f9ebc49e3766f311a76e1fa1c8003c4b03aa6"
+mime@^1.2.11:
+ version "1.6.0"
+ resolved "https://registry.yarnpkg.com/mime/-/mime-1.6.0.tgz#32cd9e5c64553bd58d19a568af452acff04981b1"
+
mimic-fn@^1.0.0:
version "1.2.0"
resolved "https://registry.yarnpkg.com/mimic-fn/-/mimic-fn-1.2.0.tgz#820c86a39334640e99516928bd03fca88057d022"
@@ -2559,10 +2613,14 @@ minimist@0.0.8:
version "0.0.8"
resolved "https://registry.yarnpkg.com/minimist/-/minimist-0.0.8.tgz#857fcabfc3397d2625b8228262e86aa7a011b05d"
-minimist@^1.1.3, minimist@^1.2.0:
+minimist@^1.1.0, minimist@^1.1.3, minimist@^1.2.0:
version "1.2.0"
resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.0.tgz#a35008b20f41383eec1fb914f4cd5df79a264284"
+minimist@~0.0.1:
+ version "0.0.10"
+ resolved "https://registry.yarnpkg.com/minimist/-/minimist-0.0.10.tgz#de3f98543dbf96082be48ad1a0c7cda836301dcf"
+
mixin-deep@^1.2.0:
version "1.3.1"
resolved "https://registry.yarnpkg.com/mixin-deep/-/mixin-deep-1.3.1.tgz#a49e7268dce1a0d9698e45326c5626df3543d0fe"
@@ -2570,7 +2628,7 @@ mixin-deep@^1.2.0:
for-in "^1.0.2"
is-extendable "^1.0.1"
-"mkdirp@>=0.5 0", mkdirp@^0.5.1, mkdirp@~0.5.0, mkdirp@~0.5.1:
+mkdirp@0.5.x, "mkdirp@>=0.5 0", mkdirp@^0.5.1, mkdirp@~0.5.0, mkdirp@~0.5.1:
version "0.5.1"
resolved "https://registry.yarnpkg.com/mkdirp/-/mkdirp-0.5.1.tgz#30057438eac6cf7f8c4767f38648d6697d75c903"
dependencies:
@@ -2580,7 +2638,7 @@ ms@2.0.0:
version "2.0.0"
resolved "https://registry.yarnpkg.com/ms/-/ms-2.0.0.tgz#5608aeadfc00be6c2901df5f9861788de0d597c8"
-nan@^2.3.0:
+nan@^2.0.7, nan@^2.3.0:
version "2.10.0"
resolved "https://registry.yarnpkg.com/nan/-/nan-2.10.0.tgz#96d0cd610ebd58d4b4de9cc0c6828cda99c7548f"
@@ -2778,12 +2836,23 @@ once@^1.3.0, once@^1.3.3:
dependencies:
wrappy "1"
+opener@~1.4.0:
+ version "1.4.3"
+ resolved "https://registry.yarnpkg.com/opener/-/opener-1.4.3.tgz#5c6da2c5d7e5831e8ffa3964950f8d6674ac90b8"
+
opn@^5.1.0:
version "5.3.0"
resolved "https://registry.yarnpkg.com/opn/-/opn-5.3.0.tgz#64871565c863875f052cfdf53d3e3cb5adb53b1c"
dependencies:
is-wsl "^1.1.0"
+optimist@0.6.x:
+ version "0.6.1"
+ resolved "https://registry.yarnpkg.com/optimist/-/optimist-0.6.1.tgz#da3ea74686fa21a19a111c326e90eb15a0196686"
+ dependencies:
+ minimist "~0.0.1"
+ wordwrap "~0.0.2"
+
optionator@^0.8.1:
version "0.8.2"
resolved "https://registry.yarnpkg.com/optionator/-/optionator-0.8.2.tgz#364c5e409d3f4d6301d6c0b4c05bba50180aeb64"
@@ -2850,9 +2919,9 @@ pako@~1.0.5:
version "1.0.6"
resolved "https://registry.yarnpkg.com/pako/-/pako-1.0.6.tgz#0101211baa70c4bca4a0f63f2206e97b7dfaf258"
-parcel-bundler@~1.6.2:
- version "1.6.2"
- resolved "https://registry.yarnpkg.com/parcel-bundler/-/parcel-bundler-1.6.2.tgz#e415c4993b6f4c48cd410427594ce80bbb7d90f1"
+parcel-bundler@~1.7.0:
+ version "1.7.0"
+ resolved "https://registry.yarnpkg.com/parcel-bundler/-/parcel-bundler-1.7.0.tgz#8c2512615fd602d2f39bd97bfd128f8fe524b321"
dependencies:
babel-code-frame "^6.26.0"
babel-core "^6.25.0"
@@ -2865,7 +2934,6 @@ parcel-bundler@~1.6.2:
babel-types "^6.26.0"
babylon "^6.17.4"
babylon-walk "^1.0.2"
- browser-resolve "^1.11.2"
browserslist "^2.11.2"
chalk "^2.1.0"
chokidar "^2.0.1"
@@ -2873,12 +2941,13 @@ parcel-bundler@~1.6.2:
commander "^2.11.0"
cross-spawn "^6.0.4"
cssnano "^3.10.0"
+ deasync "^0.1.12"
dotenv "^5.0.0"
filesize "^3.6.0"
get-port "^3.2.0"
glob "^7.1.2"
grapheme-breaker "^0.3.2"
- htmlnano "^0.1.6"
+ htmlnano "^0.1.7"
is-url "^1.2.2"
js-yaml "^3.10.0"
json5 "^0.5.1"
@@ -2888,13 +2957,12 @@ parcel-bundler@~1.6.2:
node-libs-browser "^2.0.0"
opn "^5.1.0"
physical-cpu-count "^2.0.0"
- postcss "^6.0.10"
+ postcss "^6.0.19"
postcss-value-parser "^3.3.0"
posthtml "^0.11.2"
posthtml-parser "^0.4.0"
posthtml-render "^1.1.0"
resolve "^1.4.0"
- sanitize-filename "^1.6.1"
semver "^5.4.1"
serialize-to-js "^1.1.1"
serve-static "^1.12.4"
@@ -2967,6 +3035,14 @@ physical-cpu-count@^2.0.0:
version "2.0.0"
resolved "https://registry.yarnpkg.com/physical-cpu-count/-/physical-cpu-count-2.0.0.tgz#18de2f97e4bf7a9551ad7511942b5496f7aba660"
+portfinder@^1.0.13:
+ version "1.0.13"
+ resolved "https://registry.yarnpkg.com/portfinder/-/portfinder-1.0.13.tgz#bb32ecd87c27104ae6ee44b5a3ccbf0ebb1aede9"
+ dependencies:
+ async "^1.5.2"
+ debug "^2.2.0"
+ mkdirp "0.5.x"
+
posix-character-classes@^0.1.0:
version "0.1.1"
resolved "https://registry.yarnpkg.com/posix-character-classes/-/posix-character-classes-0.1.1.tgz#01eac0fe3b5af71a2a6c02feabb8c1fef7e00eab"
@@ -3182,13 +3258,13 @@ postcss@^5.0.10, postcss@^5.0.11, postcss@^5.0.12, postcss@^5.0.13, postcss@^5.0
source-map "^0.5.6"
supports-color "^3.2.3"
-postcss@^6.0.10:
- version "6.0.19"
- resolved "https://registry.yarnpkg.com/postcss/-/postcss-6.0.19.tgz#76a78386f670b9d9494a655bf23ac012effd1555"
+postcss@^6.0.19:
+ version "6.0.21"
+ resolved "https://registry.yarnpkg.com/postcss/-/postcss-6.0.21.tgz#8265662694eddf9e9a5960db6da33c39e4cd069d"
dependencies:
- chalk "^2.3.1"
+ chalk "^2.3.2"
source-map "^0.6.1"
- supports-color "^5.2.0"
+ supports-color "^5.3.0"
posthtml-parser@^0.3.3:
version "0.3.3"
@@ -3275,6 +3351,10 @@ q@^1.1.2:
version "1.5.1"
resolved "https://registry.yarnpkg.com/q/-/q-1.5.1.tgz#7e32f75b41381291d04611f1bf14109ac00651d7"
+qs@~2.3.3:
+ version "2.3.3"
+ resolved "https://registry.yarnpkg.com/qs/-/qs-2.3.3.tgz#e9e85adbe75da0bbe4c8e0476a086290f863b404"
+
qs@~6.4.0:
version "6.4.0"
resolved "https://registry.yarnpkg.com/qs/-/qs-6.4.0.tgz#13e26d28ad6b0ffaa91312cd3bf708ed351e7233"
@@ -3461,14 +3541,14 @@ require-main-filename@^1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/require-main-filename/-/require-main-filename-1.0.1.tgz#97f717b69d48784f5f526a6c5aa8ffdda055a4d1"
+requires-port@1.x.x:
+ version "1.0.0"
+ resolved "https://registry.yarnpkg.com/requires-port/-/requires-port-1.0.0.tgz#925d2601d39ac485e091cf0da5c6e694dc3dcaff"
+
resolve-url@^0.2.1:
version "0.2.1"
resolved "https://registry.yarnpkg.com/resolve-url/-/resolve-url-0.2.1.tgz#2c637fe77c893afd2a663fe21aa9080068e2052a"
-resolve@1.1.7:
- version "1.1.7"
- resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.1.7.tgz#203114d82ad2c5ed9e8e0411b3932875e889e97b"
-
resolve@^1.1.5, resolve@^1.1.6, resolve@^1.4.0:
version "1.5.0"
resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.5.0.tgz#1f09acce796c9a762579f31b2c1cc4c3cddf9f36"
@@ -3512,12 +3592,6 @@ safer-eval@^1.2.3:
dependencies:
clones "^1.1.0"
-sanitize-filename@^1.6.1:
- version "1.6.1"
- resolved "https://registry.yarnpkg.com/sanitize-filename/-/sanitize-filename-1.6.1.tgz#612da1c96473fa02dccda92dcd5b4ab164a6772a"
- dependencies:
- truncate-utf8-bytes "^1.0.0"
-
sax@~1.2.1, sax@~1.2.4:
version "1.2.4"
resolved "https://registry.yarnpkg.com/sax/-/sax-1.2.4.tgz#2816234e2378bddc4e5354fab5caa895df7100d9"
@@ -3844,7 +3918,7 @@ supports-color@^3.2.3:
dependencies:
has-flag "^1.0.0"
-supports-color@^5.2.0, supports-color@^5.3.0:
+supports-color@^5.3.0:
version "5.3.0"
resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-5.3.0.tgz#5b24ac15db80fa927cf5227a4a33fd3c4c7676c0"
dependencies:
@@ -3973,12 +4047,6 @@ trim-right@^1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/trim-right/-/trim-right-1.0.1.tgz#cb2e1203067e0c8de1f614094b9fe45704ea6003"
-truncate-utf8-bytes@^1.0.0:
- version "1.0.2"
- resolved "https://registry.yarnpkg.com/truncate-utf8-bytes/-/truncate-utf8-bytes-1.0.2.tgz#405923909592d56f78a5818434b0b78489ca5f2b"
- dependencies:
- utf8-byte-length "^1.0.1"
-
tslib@^1.9.0:
version "1.9.0"
resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.9.0.tgz#e37a86fda8cbbaf23a057f473c9f4dc64e5fc2e8"
@@ -4045,6 +4113,12 @@ union-value@^1.0.0:
is-extendable "^0.1.1"
set-value "^0.4.3"
+union@~0.4.3:
+ version "0.4.6"
+ resolved "https://registry.yarnpkg.com/union/-/union-0.4.6.tgz#198fbdaeba254e788b0efcb630bc11f24a2959e0"
+ dependencies:
+ qs "~2.3.3"
+
uniq@^1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/uniq/-/uniq-1.0.1.tgz#b31c5ae8254844a3a8281541ce2b04b865a734ff"
@@ -4078,6 +4152,10 @@ urix@^0.1.0:
version "0.1.0"
resolved "https://registry.yarnpkg.com/urix/-/urix-0.1.0.tgz#da937f7a62e21fec1fd18d49b35c2935067a6c72"
+url-join@^2.0.2:
+ version "2.0.5"
+ resolved "https://registry.yarnpkg.com/url-join/-/url-join-2.0.5.tgz#5af22f18c052a000a48d7b82c5e9c2e2feeda728"
+
url@^0.11.0:
version "0.11.0"
resolved "https://registry.yarnpkg.com/url/-/url-0.11.0.tgz#3838e97cfc60521eb73c525a8e55bfdd9e2e28f1"
@@ -4091,10 +4169,6 @@ use@^3.1.0:
dependencies:
kind-of "^6.0.2"
-utf8-byte-length@^1.0.1:
- version "1.0.4"
- resolved "https://registry.yarnpkg.com/utf8-byte-length/-/utf8-byte-length-1.0.4.tgz#f45f150c4c66eee968186505ab93fcbb8ad6bf61"
-
util-deprecate@~1.0.1:
version "1.0.2"
resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf"
@@ -4406,6 +4480,10 @@ wide-align@^1.1.0:
dependencies:
string-width "^1.0.2"
+wordwrap@~0.0.2:
+ version "0.0.3"
+ resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-0.0.3.tgz#a3d5da6cd5c0bc0008d37234bbaf1bed63059107"
+
wordwrap@~1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-1.0.0.tgz#27584810891456a4171c8d0226441ade90cbcaeb"
diff --git a/mnist-transfer-cnn/build-resources.sh b/mnist-transfer-cnn/build-resources.sh
new file mode 100755
index 000000000..b7edad544
--- /dev/null
+++ b/mnist-transfer-cnn/build-resources.sh
@@ -0,0 +1,63 @@
+#!/usr/bin/env bash
+
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+# Builds resources for the MNIST Transfer CNN demo.
+# Note this is not necessary to run the demo, because we already provide hosted
+# pre-built resources.
+# Usage example: do this from the 'mnist-transfer-cnn' directory:
+# ./build-resources.sh
+
+set -e
+
+DEMO_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+TRAIN_EPOCHS=5
+while true; do
+ if [[ "$1" == "--epochs" ]]; then
+ TRAIN_EPOCHS=$2
+ shift 2
+ elif [[ -z "$1" ]]; then
+ break
+ else
+ echo "ERROR: Unrecognized argument: $1"
+ exit 1
+ fi
+done
+
+RESOURCES_ROOT="${DEMO_DIR}/dist/resources"
+rm -rf "${RESOURCES_ROOT}"
+mkdir -p "${RESOURCES_ROOT}"
+
+# Run Python script to generate the pretrained model and weights files.
+# Make sure you have installed the tensorfowjs pip package first.
+
+python "${DEMO_DIR}/python/mnist_transfer_cnn.py" \
+ --epochs "${TRAIN_EPOCHS}" \
+ --artifacts_dir "${RESOURCES_ROOT}" \
+ --gte5_data_path_prefix "${RESOURCES_ROOT}/gte5" \
+ --gte5_cutoff 1024
+
+cd ${DEMO_DIR}
+yarn
+yarn build
+
+echo
+echo "-----------------------------------------------------------"
+echo "Resources written to ${RESOURCES_ROOT}."
+echo "You can now run the demo with 'yarn watch'."
+echo "-----------------------------------------------------------"
+echo
diff --git a/mnist-transfer-cnn/index.html b/mnist-transfer-cnn/index.html
index 1833ac96d..821ed5267 100644
--- a/mnist-transfer-cnn/index.html
+++ b/mnist-transfer-cnn/index.html
@@ -70,6 +70,11 @@
TensorFlow.js Layers: MNIST CNN Transfer Learning Demo
+
+
+
+
+
diff --git a/mnist-transfer-cnn/index.js b/mnist-transfer-cnn/index.js
index d144acbdc..4816cc026 100644
--- a/mnist-transfer-cnn/index.js
+++ b/mnist-transfer-cnn/index.js
@@ -20,19 +20,28 @@ import * as loader from './loader';
import * as ui from './ui';
import * as util from './util';
-const HOSTED_MODEL_JSON_URL =
- 'https://storage.googleapis.com/tfjs-models/tfjs/mnist_transfer_cnn_v1/model.json';
-const HOSTED_TRAIN_DATA_JSON_URL =
- 'https://storage.googleapis.com/tfjs-models/tfjs/mnist_transfer_cnn_v1/gte5.train.json';
-const HOSTED_TEST_DATA_JSON_URL =
- 'https://storage.googleapis.com/tfjs-models/tfjs/mnist_transfer_cnn_v1/gte5.test.json';
+const HOSTED_URLS = {
+ model:
+ 'https://storage.googleapis.com/tfjs-models/tfjs/mnist_transfer_cnn_v1/model.json',
+ train:
+ 'https://storage.googleapis.com/tfjs-models/tfjs/mnist_transfer_cnn_v1/gte5.train.json',
+ test:
+ 'https://storage.googleapis.com/tfjs-models/tfjs/mnist_transfer_cnn_v1/gte5.test.json'
+};
+
+const LOCAL_URLS = {
+ model: 'http://localhost:1235/resources/model.json',
+ train: 'http://localhost:1235/resources/gte5.train.json',
+ test: 'http://localhost:1235/resources/gte5.test.json'
+};
class MnistTransferCNNPredictor {
/**
* Initializes the MNIST Transfer CNN demo.
*/
- async init() {
- this.model = await loader.loadHostedPretrainedModel(HOSTED_MODEL_JSON_URL);
+ async init(urls) {
+ this.urls = urls;
+ this.model = await loader.loadHostedPretrainedModel(urls.model);
this.imageSize = this.model.layers[0].batchInputShape[1];
this.numClasses = 5;
@@ -44,10 +53,10 @@ class MnistTransferCNNPredictor {
async loadRetrainData() {
console.log('Loading data for transfer learning...');
ui.status('Loading data for transfer learning...');
- this.gte5TrainData = await loader.loadHostedData(
- HOSTED_TRAIN_DATA_JSON_URL, this.numClasses);
+ this.gte5TrainData =
+ await loader.loadHostedData(this.urls.train, this.numClasses);
this.gte5TestData =
- await loader.loadHostedData(HOSTED_TEST_DATA_JSON_URL, this.numClasses);
+ await loader.loadHostedData(this.urls.test, this.numClasses);
ui.status('Done loading data for transfer learning.');
}
@@ -135,10 +144,31 @@ class MnistTransferCNNPredictor {
* and retrain functions with the UI.
*/
async function setupMnistTransferCNN() {
- const predictor = await new MnistTransferCNNPredictor().init();
- ui.prepUI(
- x => predictor.predict(x), x => predictor.retrainModel(),
- predictor.testExamples, predictor.imageSize);
+ if (await loader.urlExists(HOSTED_URLS.model)) {
+ ui.status('Model available: ' + HOSTED_URLS.model);
+ const button = document.getElementById('load-pretrained-remote');
+ button.addEventListener('click', async () => {
+ const predictor = await new MnistTransferCNNPredictor().init(HOSTED_URLS);
+ ui.prepUI(
+ x => predictor.predict(x), x => predictor.retrainModel(),
+ predictor.testExamples, predictor.imageSize);
+ });
+ button.style.display = 'inline-block';
+ }
+
+ if (await loader.urlExists(LOCAL_URLS.model)) {
+ ui.status('Model available: ' + LOCAL_URLS.model);
+ const button = document.getElementById('load-pretrained-local');
+ button.addEventListener('click', async () => {
+ const predictor = await new MnistTransferCNNPredictor().init(LOCAL_URLS);
+ ui.prepUI(
+ x => predictor.predict(x), x => predictor.retrainModel(),
+ predictor.testExamples, predictor.imageSize);
+ });
+ button.style.display = 'inline-block';
+ }
+
+ ui.status('Standing by.');
}
setupMnistTransferCNN();
diff --git a/mnist-transfer-cnn/loader.js b/mnist-transfer-cnn/loader.js
index 92a7389b3..600b90b58 100644
--- a/mnist-transfer-cnn/loader.js
+++ b/mnist-transfer-cnn/loader.js
@@ -19,6 +19,19 @@ import * as tf from '@tensorflow/tfjs';
import * as ui from './ui';
import * as util from './util';
+/**
+ * Test whether a given URL is retrievable.
+ */
+export async function urlExists(url) {
+ ui.status('Testing url ' + url);
+ try {
+ const response = await fetch(url, {method: 'HEAD'});
+ return response.ok;
+ } catch (err) {
+ return false;
+ }
+}
+
/**
* Load pretrained model stored at a remote URL.
*
@@ -29,9 +42,13 @@ export async function loadHostedPretrainedModel(url) {
try {
const model = await tf.loadModel(url);
ui.status('Done loading pretrained model.');
+ // We can't load a model twice due to
+ // https://github.com/tensorflow/tfjs/issues/34
+ // Therefore we remove the load buttons to avoid user confusion.
+ ui.disableLoadModelButtons();
return model;
} catch (err) {
- console.log(err);
+ console.error(err);
ui.status('Loading pretrained model failed.');
}
}
@@ -51,7 +68,7 @@ export async function loadHostedData(url, numClasses) {
ui.status('Done loading data.');
return result;
} catch (err) {
- console.log(err);
+ console.error(err);
ui.status('Loading data failed.');
}
}
diff --git a/mnist-transfer-cnn/package.json b/mnist-transfer-cnn/package.json
index f05c6e629..6b22e2338 100644
--- a/mnist-transfer-cnn/package.json
+++ b/mnist-transfer-cnn/package.json
@@ -13,7 +13,7 @@
"vega-embed": "~3.0.0"
},
"scripts": {
- "watch": "NODE_ENV=development parcel --no-hmr --open index.html ",
+ "watch": "./serve.sh",
"build": "NODE_ENV=production parcel build index.html --no-minify --public-url ./"
},
"devDependencies": {
@@ -21,7 +21,8 @@
"babel-polyfill": "~6.26.0",
"babel-preset-env": "~1.6.1",
"clang-format": "~1.2.2",
- "parcel-bundler": "~1.6.2"
+ "http-server": "~0.10.0",
+ "parcel-bundler": "~1.7.0"
},
"babel": {
"presets": [
diff --git a/mnist-transfer-cnn/python/__init__.py b/mnist-transfer-cnn/python/__init__.py
new file mode 100644
index 000000000..636b70f0d
--- /dev/null
+++ b/mnist-transfer-cnn/python/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/mnist-transfer-cnn/python/mnist_transfer_cnn.py b/mnist-transfer-cnn/python/mnist_transfer_cnn.py
new file mode 100644
index 000000000..b9856b661
--- /dev/null
+++ b/mnist-transfer-cnn/python/mnist_transfer_cnn.py
@@ -0,0 +1,288 @@
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Train a simple CNN model for transfer learning in browser with TF.js Layers.
+
+The model architecture in this file is based on the Keras stock example at:
+ https://github.com/keras-team/keras/blob/master/examples/mnist_transfer_cnn.py.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import datetime
+import json
+
+import keras
+from keras import backend as K
+import tensorflowjs as tfjs
+
+# Input image dimensions.
+IMG_ROWS, IMG_COLS = 28, 28
+NUM_CLASSES = 5
+
+INPUT_SHAPE = ((1, IMG_ROWS, IMG_COLS)
+ if K.image_data_format() == 'channels_first'
+ else (IMG_ROWS, IMG_COLS, 1))
+
+
+def load_mnist_data(gte5_cutoff):
+ (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
+
+ # create two datasets one with digits below 5 and one with 5 and above
+ x_train_lt5 = x_train[y_train < 5]
+ y_train_lt5 = y_train[y_train < 5]
+ x_test_lt5 = x_test[y_test < 5]
+ y_test_lt5 = y_test[y_test < 5]
+
+ x_train_gte5 = x_train[y_train >= 5]
+ y_train_gte5 = y_train[y_train >= 5] - 5
+ x_test_gte5 = x_test[y_test >= 5]
+ y_test_gte5 = y_test[y_test >= 5] - 5
+
+ if gte5_cutoff > 0:
+ x_train_gte5 = x_train_gte5[:gte5_cutoff, ...]
+ y_train_gte5 = y_train_gte5[:gte5_cutoff, ...]
+ x_test_gte5 = x_test_gte5[:gte5_cutoff, ...]
+ y_test_gte5 = y_test_gte5[:gte5_cutoff, ...]
+
+ return (x_train_lt5, y_train_lt5, x_test_lt5, y_test_lt5,
+ x_train_gte5, y_train_gte5, x_test_gte5, y_test_gte5)
+
+
+def train_model(model,
+ optimizer,
+ train,
+ test,
+ num_classes,
+ batch_size=128,
+ epochs=5):
+ """Train or re-train a model using data.
+
+ Args:
+ model: `keras.Model` instance to (re-)train.
+ optimizer: Name of the optimizer to use.
+ train: Training data, of shape (NUM_EXAMPLES, IMG_ROWS, IMG_COLS).
+ test: Test data, of shape (NUM_EXAMPLES, IMG_ROWS, IMG_COLS).
+ num_classes: Number of classes.
+ batch_size: Batch size.
+ epochs: How many epochs to train.
+ """
+ x_train = train[0].reshape((train[0].shape[0],) + INPUT_SHAPE)
+ x_test = test[0].reshape((test[0].shape[0],) + INPUT_SHAPE)
+ x_train = x_train.astype('float32')
+ x_test = x_test.astype('float32')
+ x_train /= 255
+ x_test /= 255
+
+ # convert class vectors to binary class matrices
+ y_train = keras.utils.to_categorical(train[1], num_classes)
+ y_test = keras.utils.to_categorical(test[1], num_classes)
+
+ model.compile(loss='categorical_crossentropy',
+ optimizer=optimizer,
+ metrics=['accuracy'])
+
+ t = datetime.datetime.now()
+ model.fit(x_train, y_train,
+ batch_size=batch_size,
+ epochs=epochs,
+ verbose=1,
+ validation_data=(x_test, y_test))
+ print('Training time: %s' % (datetime.datetime.now() - t))
+ score = model.evaluate(x_test, y_test, verbose=0)
+ print('Test score:', score[0])
+ print('Test accuracy:', score[1])
+
+
+def write_mnist_examples_to_json_file(x, y, js_path):
+ """Write a batch of MNIST examples to a JavaScript (.js) file.
+
+ Args:
+ x: A numpy array representing the image data, with shape
+ (NUM_EXAMPLES, IMG_ROWS, IMG_COLS).
+ y: A numpy array representing the image labels (as integer indices),
+ with shape (NUM_EXAMPLES,).
+ js_path: Path to the JavaScript file to write to.
+ """
+ data = []
+ num_examples = x.shape[0]
+ for i in range(num_examples):
+ data.append({'x': x[i, ...].tolist(), 'y': int(y[i])})
+ with open(js_path, 'wt') as f:
+ f.write(json.dumps(data))
+
+
+def write_gte5_data(x_train_gte5,
+ y_train_gte5,
+ x_test_gte5,
+ y_test_gte5,
+ gte5_data_path_prefix):
+ """Write the transfer-learning data to .js files.
+
+ Args:
+ x_train_lt5: x (image) data for training: digits >= 5.
+ y_train_lt5: y (label) data for training: digits >= 5.
+ x_test_lt5: x (image) data for test (validation): digits >= 5.
+ y_test_lt5: y (label) data for test (validation): digits >= 5.
+ gte5_data_path_prefix: Path prefix for writing the files. For example,
+ if the value is '/tmp/foo', the train and test files will be written
+ at '/tmp/foo.train.js' and '/tmp/foo.test.js', respectively.
+ """
+ gte5_train_path = gte5_data_path_prefix + '.train.json'
+ write_mnist_examples_to_json_file(
+ x_train_gte5, y_train_gte5, gte5_train_path)
+ print('Wrote gte5 training data to: %s' % gte5_train_path)
+ gte5_test_path = gte5_data_path_prefix + '.test.json'
+ write_mnist_examples_to_json_file(
+ x_test_gte5, y_test_gte5, gte5_test_path)
+ print('Wrote gte5 test data to: %s' % gte5_test_path)
+
+
+def train_and_save_model(filters,
+ kernel_size,
+ pool_size,
+ batch_size,
+ epochs,
+ x_train_lt5,
+ y_train_lt5,
+ x_test_lt5,
+ y_test_lt5,
+ artifacts_dir,
+ optimizer='adam'):
+ """Train and save MNIST CNN model.
+
+ Args:
+ filters: number of filters for convolution layers.
+ kernel_size: kernel size for convolution layers.
+ pool_size: pooling kernel size for pooling layers.
+ batch_size: batch size.
+ epochs: number of epochs to train for.
+ x_train_lt5: x (image) data for training: digits < 5.
+ y_train_lt5: y (label) data for training: digits < 5.
+ x_test_lt5: x (image) data for test (validation): digits < 5.
+ y_test_lt5: y (label) data for test (validation): digits < 5.
+ artifacts_dir: Directory to save the model artifacts (model topology JSON,
+ weights and weight manifest) in.
+ optimizer: The name of the optimizer to use, as a string.
+ """
+
+ feature_layers = [
+ keras.layers.Conv2D(filters, kernel_size,
+ padding='valid',
+ input_shape=INPUT_SHAPE),
+ keras.layers.Activation('relu'),
+ keras.layers.Conv2D(filters, kernel_size),
+ keras.layers.Activation('relu'),
+ keras.layers.MaxPooling2D(pool_size=pool_size),
+ keras.layers.Dropout(0.25),
+ keras.layers.Flatten(),
+ ]
+ classification_layers = [
+ keras.layers.Dense(128),
+ keras.layers.Activation('relu'),
+ keras.layers.Dropout(0.5),
+ keras.layers.Dense(NUM_CLASSES),
+ keras.layers.Activation('softmax')
+ ]
+ model = keras.models.Sequential(feature_layers + classification_layers)
+
+ train_model(model,
+ optimizer,
+ (x_train_lt5, y_train_lt5),
+ (x_test_lt5, y_test_lt5),
+ NUM_CLASSES, batch_size=batch_size, epochs=epochs)
+ tfjs.converters.save_keras_model(model, artifacts_dir)
+
+
+def main():
+ (x_train_lt5, y_train_lt5, x_test_lt5, y_test_lt5,
+ x_train_gte5, y_train_gte5, x_test_gte5, y_test_gte5) = load_mnist_data(
+ FLAGS.gte5_cutoff)
+
+ write_gte5_data(x_train_gte5, y_train_gte5, x_test_gte5, y_test_gte5,
+ FLAGS.gte5_data_path_prefix)
+
+ train_and_save_model(FLAGS.filters,
+ FLAGS.kernel_size,
+ FLAGS.pool_size,
+ FLAGS.batch_size,
+ FLAGS.epochs,
+ x_train_lt5,
+ y_train_lt5,
+ x_test_lt5,
+ y_test_lt5,
+ FLAGS.artifacts_dir,
+ optimizer=FLAGS.optimizer)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser('MNIST model training and serialization')
+ parser.add_argument(
+ '--epochs',
+ type=int,
+ default=5,
+ help='Number of epochs to train the Keras model for.')
+ parser.add_argument(
+ '--batch_size',
+ type=int,
+ default=128,
+ help='Batch size for training.')
+ parser.add_argument(
+ '--filters',
+ type=int,
+ default=32,
+ help='Number of convolutional filters to use.')
+ parser.add_argument(
+ '--pool_size',
+ type=int,
+ default=2,
+ help='Size of pooling area for max pooling.')
+ parser.add_argument(
+ '--kernel_size',
+ type=int,
+ default=3,
+ help='Convolutional kernel size.')
+ parser.add_argument(
+ '--artifacts_dir',
+ type=str,
+ default='/tmp/mnist.keras',
+ help='Local path for saving the TensorFlow.js artifacts.')
+ parser.add_argument(
+ '--optimizer',
+ type=str,
+ default='adam',
+ help='Name of the optimizer to use for training.')
+ parser.add_argument(
+ '--gte5_cutoff',
+ type=int,
+ default=1024,
+ help='If value is > 0, will cause only the first this many examples '
+ 'in the gte5 (label >= 5) transfer learning subset to be written to '
+ 'JavaScript (.js) files.')
+ parser.add_argument(
+ '--gte5_data_path_prefix',
+ type=str,
+ default='/tmp/mnist_transfer_cnn.gte5',
+ help='Prefix for the label >= 5 data for transfer learning.'
+ 'For example, if the prefix is /tmp/foo.gte5, the train and test '
+ 'data will be written to /tmp/foo.gte5.train.js and '
+ '/tmp/foo.gte5.test.js, respectively.')
+ # TODO(cais, soergel): Eventually we want to use the dataset API and not write
+ # writing the data to file.
+
+ FLAGS, _ = parser.parse_known_args()
+ main()
diff --git a/mnist-transfer-cnn/python/mnist_transfer_cnn_test.py b/mnist-transfer-cnn/python/mnist_transfer_cnn_test.py
new file mode 100644
index 000000000..fb6aaacdc
--- /dev/null
+++ b/mnist-transfer-cnn/python/mnist_transfer_cnn_test.py
@@ -0,0 +1,78 @@
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Test for the MNIST transfer learning CNN model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+
+from . import mnist_transfer_cnn
+
+class MnistTest(unittest.TestCase):
+
+ def setUp(self):
+ self._tmp_dir = tempfile.mkdtemp()
+ super(MnistTest, self).setUp()
+
+ def tearDown(self):
+ if os.path.isdir(self._tmp_dir):
+ shutil.rmtree(self._tmp_dir)
+ super(MnistTest, self).tearDown()
+
+ def _getFakeMnistData(self):
+ """Generate some fake MNIST data for testing."""
+ x_train = np.random.rand(4, 28, 28)
+ y_train = np.random.rand(4,)
+ x_test = np.random.rand(4, 28, 28)
+ y_test = np.random.rand(4,)
+ return x_train, y_train, x_test, y_test
+
+ def testWriteGte5DataToJsFiles(self):
+ x_train, y_train, x_test, y_test = self._getFakeMnistData()
+ js_file_prefix = os.path.join(self._tmp_dir, 'gte5')
+ mnist_transfer_cnn.write_gte5_data(x_train, y_train, x_test, y_test,
+ js_file_prefix)
+
+ for js_path_suffix in ('.train.json', '.test.json'):
+ with open(js_file_prefix + js_path_suffix, 'rt') as f:
+ json_string = f.read()
+ data = json.loads(json_string)
+ if 'train' in js_path_suffix:
+ self.assertEqual(x_train.shape[0], len(data))
+ else:
+ self.assertEqual(x_test.shape[0], len(data))
+ self.assertEqual(28, len(data[0]['x']))
+ self.assertEqual(28, len(data[0]['x'][0]))
+
+ def testTrainWithFakeDataAndSave(self):
+ x_train, y_train, x_test, y_test = self._getFakeMnistData()
+ mnist_transfer_cnn.train_and_save_model(
+ 2, 2, 2, 2, 1, x_train, y_train, x_test, y_test,
+ self._tmp_dir, optimizer='adam')
+
+ # Check that the model json file is created.
+ json.load(open(os.path.join(self._tmp_dir, 'model.json'), 'rt'))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/mnist-transfer-cnn/serve.sh b/mnist-transfer-cnn/serve.sh
new file mode 100755
index 000000000..e5ede63ed
--- /dev/null
+++ b/mnist-transfer-cnn/serve.sh
@@ -0,0 +1,42 @@
+#!/usr/bin/env bash
+
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+# This script starts two HTTP servers on different ports:
+# * Port 1234 (using parcel) serves HTML and JavaScript.
+# * Port 1235 (using http-server) serves pretrained model resources.
+#
+# The reason for this arrangement is that Parcel currently has a limitation that
+# prevents it from serving the pretrained models; see
+# https://github.com/parcel-bundler/parcel/issues/1098. Once that issue is
+# resolved, a single Parcel server will be sufficient.
+
+NODE_ENV=development
+RESOURCE_PORT=1235
+
+# Ensure that http-server is available
+yarn
+
+echo Starting the pretrained model server...
+node_modules/http-server/bin/http-server dist --cors -p "${RESOURCE_PORT}" > /dev/null & HTTP_SERVER_PID=$!
+
+echo Starting the example html/js server...
+# This uses port 1234 by default.
+node_modules/parcel-bundler/bin/cli.js serve -d dist --open --no-hmr --public-url / index.html
+
+# When the Parcel server exits, kill the http-server too.
+kill $HTTP_SERVER_PID
+
diff --git a/mnist-transfer-cnn/ui.js b/mnist-transfer-cnn/ui.js
index ed22ab1ed..d13d7e7ed 100644
--- a/mnist-transfer-cnn/ui.js
+++ b/mnist-transfer-cnn/ui.js
@@ -19,6 +19,7 @@ import * as tf from '@tensorflow/tfjs';
import * as util from './util';
export function status(statusText, statusColor) {
+ console.log(statusText);
document.getElementById('status').textContent = statusText;
document.getElementById('status').style.color = statusColor;
}
@@ -106,3 +107,8 @@ export function setPredictResults(predictOut, winner) {
predictValues.innerHTML = valTds;
document.getElementById('winner').textContent = winner;
}
+
+export function disableLoadModelButtons() {
+ document.getElementById('load-pretrained-remote').style.display = 'none';
+ document.getElementById('load-pretrained-local').style.display = 'none';
+}
diff --git a/mnist-transfer-cnn/yarn.lock b/mnist-transfer-cnn/yarn.lock
index 2e61d4f3d..87bfdaf3c 100644
--- a/mnist-transfer-cnn/yarn.lock
+++ b/mnist-transfer-cnn/yarn.lock
@@ -705,6 +705,10 @@ binary-extensions@^1.0.0:
version "1.11.0"
resolved "https://registry.yarnpkg.com/binary-extensions/-/binary-extensions-1.11.0.tgz#46aa1751fb6a2f93ee5e689bb1087d4b14c6c205"
+bindings@~1.2.1:
+ version "1.2.1"
+ resolved "https://registry.yarnpkg.com/bindings/-/bindings-1.2.1.tgz#14ad6113812d2d37d72e67b4cacb4bb726505f11"
+
block-stream@*:
version "0.0.9"
resolved "https://registry.yarnpkg.com/block-stream/-/block-stream-0.0.9.tgz#13ebfe778a03205cfe03751481ebb4b3300c126a"
@@ -766,12 +770,6 @@ brorand@^1.0.1:
version "1.1.0"
resolved "https://registry.yarnpkg.com/brorand/-/brorand-1.1.0.tgz#12c25efe40a45e3c323eb8675a0a0ce57b22371f"
-browser-resolve@^1.11.2:
- version "1.11.2"
- resolved "https://registry.yarnpkg.com/browser-resolve/-/browser-resolve-1.11.2.tgz#8ff09b0a2c421718a1051c260b32e48f442938ce"
- dependencies:
- resolve "1.1.7"
-
browserify-aes@^1.0.0, browserify-aes@^1.0.4:
version "1.1.1"
resolved "https://registry.yarnpkg.com/browserify-aes/-/browserify-aes-1.1.1.tgz#38b7ab55edb806ff2dcda1a7f1620773a477c49f"
@@ -1066,6 +1064,10 @@ colormin@^1.0.5:
css-color-names "0.0.4"
has "^1.0.1"
+colors@1.0.3:
+ version "1.0.3"
+ resolved "https://registry.yarnpkg.com/colors/-/colors-1.0.3.tgz#0433f44d809680fdeb60ed260f1b0c262e82a40b"
+
colors@~1.1.2:
version "1.1.2"
resolved "https://registry.yarnpkg.com/colors/-/colors-1.1.2.tgz#168a4701756b6a7f51a12ce0c97bfa28c084ed63"
@@ -1145,6 +1147,10 @@ core-util-is@1.0.2, core-util-is@~1.0.0:
version "1.0.2"
resolved "https://registry.yarnpkg.com/core-util-is/-/core-util-is-1.0.2.tgz#b5fd54220aa2bc5ab57aab7140c940754503c1a7"
+corser@~2.0.0:
+ version "2.0.1"
+ resolved "https://registry.yarnpkg.com/corser/-/corser-2.0.1.tgz#8eda252ecaab5840dcd975ceb90d9370c819ff87"
+
create-ecdh@^4.0.0:
version "4.0.0"
resolved "https://registry.yarnpkg.com/create-ecdh/-/create-ecdh-4.0.0.tgz#888c723596cdf7612f6498233eebd7a35301737d"
@@ -1433,6 +1439,13 @@ date-now@^0.1.4:
version "0.1.4"
resolved "https://registry.yarnpkg.com/date-now/-/date-now-0.1.4.tgz#eaf439fd4d4848ad74e5cc7dbef200672b9e345b"
+deasync@^0.1.12:
+ version "0.1.12"
+ resolved "https://registry.yarnpkg.com/deasync/-/deasync-0.1.12.tgz#0159492a4133ab301d6c778cf01e74e63b10e549"
+ dependencies:
+ bindings "~1.2.1"
+ nan "^2.0.7"
+
debug@2.6.9, debug@^2.2.0, debug@^2.3.3, debug@^2.6.8:
version "2.6.9"
resolved "https://registry.yarnpkg.com/debug/-/debug-2.6.9.tgz#5d128515df134ff327e90a4c93f4e077a536341f"
@@ -1585,6 +1598,15 @@ ecc-jsbn@~0.1.1:
dependencies:
jsbn "~0.1.0"
+ecstatic@^2.0.0:
+ version "2.2.1"
+ resolved "https://registry.yarnpkg.com/ecstatic/-/ecstatic-2.2.1.tgz#b5087fad439dd9dd49d31e18131454817fe87769"
+ dependencies:
+ he "^1.1.1"
+ mime "^1.2.11"
+ minimist "^1.1.0"
+ url-join "^2.0.2"
+
editorconfig@^0.13.2:
version "0.13.3"
resolved "https://registry.yarnpkg.com/editorconfig/-/editorconfig-0.13.3.tgz#e5219e587951d60958fd94ea9a9a008cdeff1b34"
@@ -1696,6 +1718,10 @@ etag@~1.8.1:
version "1.8.1"
resolved "https://registry.yarnpkg.com/etag/-/etag-1.8.1.tgz#41ae2eeb65efa62268aebfea83ac7d79299b0887"
+eventemitter3@1.x.x:
+ version "1.2.0"
+ resolved "https://registry.yarnpkg.com/eventemitter3/-/eventemitter3-1.2.0.tgz#1c86991d816ad1e504750e73874224ecf3bec508"
+
events@^1.0.0:
version "1.1.1"
resolved "https://registry.yarnpkg.com/events/-/events-1.1.1.tgz#9ebdb7635ad099c70dcc4c2a1f5004288e8bd924"
@@ -2033,6 +2059,10 @@ hawk@3.1.3, hawk@~3.1.3:
hoek "2.x.x"
sntp "1.x.x"
+he@^1.1.1:
+ version "1.1.1"
+ resolved "https://registry.yarnpkg.com/he/-/he-1.1.1.tgz#93410fd21b009735151f8868c2f271f3427e23fd"
+
hmac-drbg@^1.0.0:
version "1.0.1"
resolved "https://registry.yarnpkg.com/hmac-drbg/-/hmac-drbg-1.0.1.tgz#d2745701025a6c775a6c545793ed502fc0c649a1"
@@ -2060,7 +2090,7 @@ html-comment-regex@^1.1.0:
version "1.1.1"
resolved "https://registry.yarnpkg.com/html-comment-regex/-/html-comment-regex-1.1.1.tgz#668b93776eaae55ebde8f3ad464b307a4963625e"
-htmlnano@^0.1.6:
+htmlnano@^0.1.7:
version "0.1.7"
resolved "https://registry.yarnpkg.com/htmlnano/-/htmlnano-0.1.7.tgz#1751937a05f122a3248dba1c63edf01f6d36cb84"
dependencies:
@@ -2091,6 +2121,26 @@ http-errors@~1.6.2:
setprototypeof "1.0.3"
statuses ">= 1.3.1 < 2"
+http-proxy@^1.8.1:
+ version "1.16.2"
+ resolved "https://registry.yarnpkg.com/http-proxy/-/http-proxy-1.16.2.tgz#06dff292952bf64dbe8471fa9df73066d4f37742"
+ dependencies:
+ eventemitter3 "1.x.x"
+ requires-port "1.x.x"
+
+http-server@~0.10.0:
+ version "0.10.0"
+ resolved "https://registry.yarnpkg.com/http-server/-/http-server-0.10.0.tgz#b2a446b16a9db87ed3c622ba9beb1b085b1234a7"
+ dependencies:
+ colors "1.0.3"
+ corser "~2.0.0"
+ ecstatic "^2.0.0"
+ http-proxy "^1.8.1"
+ opener "~1.4.0"
+ optimist "0.6.x"
+ portfinder "^1.0.13"
+ union "~0.4.3"
+
http-signature@~1.1.0:
version "1.1.1"
resolved "https://registry.yarnpkg.com/http-signature/-/http-signature-1.1.1.tgz#df72e267066cd0ac67fb76adf8e134a8fbcf91bf"
@@ -2610,6 +2660,10 @@ mime@1.4.1:
version "1.4.1"
resolved "https://registry.yarnpkg.com/mime/-/mime-1.4.1.tgz#121f9ebc49e3766f311a76e1fa1c8003c4b03aa6"
+mime@^1.2.11:
+ version "1.6.0"
+ resolved "https://registry.yarnpkg.com/mime/-/mime-1.6.0.tgz#32cd9e5c64553bd58d19a568af452acff04981b1"
+
mimic-fn@^1.0.0:
version "1.2.0"
resolved "https://registry.yarnpkg.com/mimic-fn/-/mimic-fn-1.2.0.tgz#820c86a39334640e99516928bd03fca88057d022"
@@ -2632,10 +2686,14 @@ minimist@0.0.8:
version "0.0.8"
resolved "https://registry.yarnpkg.com/minimist/-/minimist-0.0.8.tgz#857fcabfc3397d2625b8228262e86aa7a011b05d"
-minimist@^1.1.3, minimist@^1.2.0:
+minimist@^1.1.0, minimist@^1.1.3, minimist@^1.2.0:
version "1.2.0"
resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.0.tgz#a35008b20f41383eec1fb914f4cd5df79a264284"
+minimist@~0.0.1:
+ version "0.0.10"
+ resolved "https://registry.yarnpkg.com/minimist/-/minimist-0.0.10.tgz#de3f98543dbf96082be48ad1a0c7cda836301dcf"
+
mixin-deep@^1.2.0:
version "1.3.1"
resolved "https://registry.yarnpkg.com/mixin-deep/-/mixin-deep-1.3.1.tgz#a49e7268dce1a0d9698e45326c5626df3543d0fe"
@@ -2643,7 +2701,7 @@ mixin-deep@^1.2.0:
for-in "^1.0.2"
is-extendable "^1.0.1"
-"mkdirp@>=0.5 0", mkdirp@^0.5.1, mkdirp@~0.5.0, mkdirp@~0.5.1:
+mkdirp@0.5.x, "mkdirp@>=0.5 0", mkdirp@^0.5.1, mkdirp@~0.5.0, mkdirp@~0.5.1:
version "0.5.1"
resolved "https://registry.yarnpkg.com/mkdirp/-/mkdirp-0.5.1.tgz#30057438eac6cf7f8c4767f38648d6697d75c903"
dependencies:
@@ -2653,7 +2711,7 @@ ms@2.0.0:
version "2.0.0"
resolved "https://registry.yarnpkg.com/ms/-/ms-2.0.0.tgz#5608aeadfc00be6c2901df5f9861788de0d597c8"
-nan@^2.3.0, nan@^2.4.0:
+nan@^2.0.7, nan@^2.3.0, nan@^2.4.0:
version "2.10.0"
resolved "https://registry.yarnpkg.com/nan/-/nan-2.10.0.tgz#96d0cd610ebd58d4b4de9cc0c6828cda99c7548f"
@@ -2860,12 +2918,23 @@ once@^1.3.0, once@^1.3.3:
dependencies:
wrappy "1"
+opener@~1.4.0:
+ version "1.4.3"
+ resolved "https://registry.yarnpkg.com/opener/-/opener-1.4.3.tgz#5c6da2c5d7e5831e8ffa3964950f8d6674ac90b8"
+
opn@^5.1.0:
version "5.3.0"
resolved "https://registry.yarnpkg.com/opn/-/opn-5.3.0.tgz#64871565c863875f052cfdf53d3e3cb5adb53b1c"
dependencies:
is-wsl "^1.1.0"
+optimist@0.6.x:
+ version "0.6.1"
+ resolved "https://registry.yarnpkg.com/optimist/-/optimist-0.6.1.tgz#da3ea74686fa21a19a111c326e90eb15a0196686"
+ dependencies:
+ minimist "~0.0.1"
+ wordwrap "~0.0.2"
+
optionator@^0.8.1:
version "0.8.2"
resolved "https://registry.yarnpkg.com/optionator/-/optionator-0.8.2.tgz#364c5e409d3f4d6301d6c0b4c05bba50180aeb64"
@@ -2938,9 +3007,9 @@ pako@~1.0.5:
version "1.0.6"
resolved "https://registry.yarnpkg.com/pako/-/pako-1.0.6.tgz#0101211baa70c4bca4a0f63f2206e97b7dfaf258"
-parcel-bundler@~1.6.2:
- version "1.6.2"
- resolved "https://registry.yarnpkg.com/parcel-bundler/-/parcel-bundler-1.6.2.tgz#e415c4993b6f4c48cd410427594ce80bbb7d90f1"
+parcel-bundler@~1.7.0:
+ version "1.7.0"
+ resolved "https://registry.yarnpkg.com/parcel-bundler/-/parcel-bundler-1.7.0.tgz#8c2512615fd602d2f39bd97bfd128f8fe524b321"
dependencies:
babel-code-frame "^6.26.0"
babel-core "^6.25.0"
@@ -2953,7 +3022,6 @@ parcel-bundler@~1.6.2:
babel-types "^6.26.0"
babylon "^6.17.4"
babylon-walk "^1.0.2"
- browser-resolve "^1.11.2"
browserslist "^2.11.2"
chalk "^2.1.0"
chokidar "^2.0.1"
@@ -2961,12 +3029,13 @@ parcel-bundler@~1.6.2:
commander "^2.11.0"
cross-spawn "^6.0.4"
cssnano "^3.10.0"
+ deasync "^0.1.12"
dotenv "^5.0.0"
filesize "^3.6.0"
get-port "^3.2.0"
glob "^7.1.2"
grapheme-breaker "^0.3.2"
- htmlnano "^0.1.6"
+ htmlnano "^0.1.7"
is-url "^1.2.2"
js-yaml "^3.10.0"
json5 "^0.5.1"
@@ -2976,13 +3045,12 @@ parcel-bundler@~1.6.2:
node-libs-browser "^2.0.0"
opn "^5.1.0"
physical-cpu-count "^2.0.0"
- postcss "^6.0.10"
+ postcss "^6.0.19"
postcss-value-parser "^3.3.0"
posthtml "^0.11.2"
posthtml-parser "^0.4.0"
posthtml-render "^1.1.0"
resolve "^1.4.0"
- sanitize-filename "^1.6.1"
semver "^5.4.1"
serialize-to-js "^1.1.1"
serve-static "^1.12.4"
@@ -3089,6 +3157,14 @@ pinkie@^2.0.0:
version "2.0.4"
resolved "https://registry.yarnpkg.com/pinkie/-/pinkie-2.0.4.tgz#72556b80cfa0d48a974e80e77248e80ed4f7f870"
+portfinder@^1.0.13:
+ version "1.0.13"
+ resolved "https://registry.yarnpkg.com/portfinder/-/portfinder-1.0.13.tgz#bb32ecd87c27104ae6ee44b5a3ccbf0ebb1aede9"
+ dependencies:
+ async "^1.5.2"
+ debug "^2.2.0"
+ mkdirp "0.5.x"
+
posix-character-classes@^0.1.0:
version "0.1.1"
resolved "https://registry.yarnpkg.com/posix-character-classes/-/posix-character-classes-0.1.1.tgz#01eac0fe3b5af71a2a6c02feabb8c1fef7e00eab"
@@ -3304,9 +3380,9 @@ postcss@^5.0.10, postcss@^5.0.11, postcss@^5.0.12, postcss@^5.0.13, postcss@^5.0
source-map "^0.5.6"
supports-color "^3.2.3"
-postcss@^6.0.10:
- version "6.0.20"
- resolved "https://registry.yarnpkg.com/postcss/-/postcss-6.0.20.tgz#686107e743a12d5530cb68438c590d5b2bf72c3c"
+postcss@^6.0.19:
+ version "6.0.21"
+ resolved "https://registry.yarnpkg.com/postcss/-/postcss-6.0.21.tgz#8265662694eddf9e9a5960db6da33c39e4cd069d"
dependencies:
chalk "^2.3.2"
source-map "^0.6.1"
@@ -3397,6 +3473,10 @@ q@^1.1.2:
version "1.5.1"
resolved "https://registry.yarnpkg.com/q/-/q-1.5.1.tgz#7e32f75b41381291d04611f1bf14109ac00651d7"
+qs@~2.3.3:
+ version "2.3.3"
+ resolved "https://registry.yarnpkg.com/qs/-/qs-2.3.3.tgz#e9e85adbe75da0bbe4c8e0476a086290f863b404"
+
qs@~6.4.0:
version "6.4.0"
resolved "https://registry.yarnpkg.com/qs/-/qs-6.4.0.tgz#13e26d28ad6b0ffaa91312cd3bf708ed351e7233"
@@ -3598,14 +3678,14 @@ require-main-filename@^1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/require-main-filename/-/require-main-filename-1.0.1.tgz#97f717b69d48784f5f526a6c5aa8ffdda055a4d1"
+requires-port@1.x.x:
+ version "1.0.0"
+ resolved "https://registry.yarnpkg.com/requires-port/-/requires-port-1.0.0.tgz#925d2601d39ac485e091cf0da5c6e694dc3dcaff"
+
resolve-url@^0.2.1:
version "0.2.1"
resolved "https://registry.yarnpkg.com/resolve-url/-/resolve-url-0.2.1.tgz#2c637fe77c893afd2a663fe21aa9080068e2052a"
-resolve@1.1.7:
- version "1.1.7"
- resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.1.7.tgz#203114d82ad2c5ed9e8e0411b3932875e889e97b"
-
resolve@^1.1.5, resolve@^1.1.6, resolve@^1.4.0:
version "1.6.0"
resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.6.0.tgz#0fbd21278b27b4004481c395349e7aba60a9ff5c"
@@ -3649,12 +3729,6 @@ safer-eval@^1.2.3:
dependencies:
clones "^1.1.0"
-sanitize-filename@^1.6.1:
- version "1.6.1"
- resolved "https://registry.yarnpkg.com/sanitize-filename/-/sanitize-filename-1.6.1.tgz#612da1c96473fa02dccda92dcd5b4ab164a6772a"
- dependencies:
- truncate-utf8-bytes "^1.0.0"
-
sax@~1.2.1, sax@~1.2.4:
version "1.2.4"
resolved "https://registry.yarnpkg.com/sax/-/sax-1.2.4.tgz#2816234e2378bddc4e5354fab5caa895df7100d9"
@@ -4138,12 +4212,6 @@ trim-right@^1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/trim-right/-/trim-right-1.0.1.tgz#cb2e1203067e0c8de1f614094b9fe45704ea6003"
-truncate-utf8-bytes@^1.0.0:
- version "1.0.2"
- resolved "https://registry.yarnpkg.com/truncate-utf8-bytes/-/truncate-utf8-bytes-1.0.2.tgz#405923909592d56f78a5818434b0b78489ca5f2b"
- dependencies:
- utf8-byte-length "^1.0.1"
-
tslib@^1.9.0:
version "1.9.0"
resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.9.0.tgz#e37a86fda8cbbaf23a057f473c9f4dc64e5fc2e8"
@@ -4210,6 +4278,12 @@ union-value@^1.0.0:
is-extendable "^0.1.1"
set-value "^0.4.3"
+union@~0.4.3:
+ version "0.4.6"
+ resolved "https://registry.yarnpkg.com/union/-/union-0.4.6.tgz#198fbdaeba254e788b0efcb630bc11f24a2959e0"
+ dependencies:
+ qs "~2.3.3"
+
uniq@^1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/uniq/-/uniq-1.0.1.tgz#b31c5ae8254844a3a8281541ce2b04b865a734ff"
@@ -4243,6 +4317,10 @@ urix@^0.1.0:
version "0.1.0"
resolved "https://registry.yarnpkg.com/urix/-/urix-0.1.0.tgz#da937f7a62e21fec1fd18d49b35c2935067a6c72"
+url-join@^2.0.2:
+ version "2.0.5"
+ resolved "https://registry.yarnpkg.com/url-join/-/url-join-2.0.5.tgz#5af22f18c052a000a48d7b82c5e9c2e2feeda728"
+
url@^0.11.0:
version "0.11.0"
resolved "https://registry.yarnpkg.com/url/-/url-0.11.0.tgz#3838e97cfc60521eb73c525a8e55bfdd9e2e28f1"
@@ -4256,10 +4334,6 @@ use@^3.1.0:
dependencies:
kind-of "^6.0.2"
-utf8-byte-length@^1.0.1:
- version "1.0.4"
- resolved "https://registry.yarnpkg.com/utf8-byte-length/-/utf8-byte-length-1.0.4.tgz#f45f150c4c66eee968186505ab93fcbb8ad6bf61"
-
util-deprecate@~1.0.1:
version "1.0.2"
resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf"
@@ -4589,6 +4663,10 @@ window-size@^0.2.0:
version "0.2.0"
resolved "https://registry.yarnpkg.com/window-size/-/window-size-0.2.0.tgz#b4315bb4214a3d7058ebeee892e13fa24d98b075"
+wordwrap@~0.0.2:
+ version "0.0.3"
+ resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-0.0.3.tgz#a3d5da6cd5c0bc0008d37234bbaf1bed63059107"
+
wordwrap@~1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-1.0.0.tgz#27584810891456a4171c8d0226441ade90cbcaeb"
diff --git a/mobilenet/imagenet_classes.js b/mobilenet/imagenet_classes.js
index e6b0777a6..09f705d79 100644
--- a/mobilenet/imagenet_classes.js
+++ b/mobilenet/imagenet_classes.js
@@ -1,6 +1,6 @@
/**
* @license
- * Copyright 2017 Google Inc. All Rights Reserved.
+ * Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
diff --git a/sentiment/build-resources.sh b/sentiment/build-resources.sh
new file mode 100755
index 000000000..2f0b2fd91
--- /dev/null
+++ b/sentiment/build-resources.sh
@@ -0,0 +1,73 @@
+#!/usr/bin/env bash
+
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+# Builds resources for the Sentiment demo.
+# Note this is not necessary to run the demo, because we already provide hosted
+# pre-built resources.
+# Usage example: do this from the 'sentiment' directory:
+# ./build-resources.sh lstm
+
+set -e
+
+DEMO_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+if [[ $# -lt 1 ]]; then
+ echo "Usage:"
+ echo " build-imdb-demo.sh "
+ echo
+ echo "MODEL_TYPE options: lstm | cnn"
+ exit 1
+fi
+MODEL_TYPE=$1
+shift
+echo "Using model type: ${MODEL_TYPE}"
+
+TRAIN_EPOCHS=5
+while true; do
+ if [[ "$1" == "--epochs" ]]; then
+ TRAIN_EPOCHS=$2
+ shift 2
+ elif [[ -z "$1" ]]; then
+ break
+ else
+ echo "ERROR: Unrecognized argument: $1"
+ exit 1
+ fi
+done
+
+RESOURCES_ROOT="${DEMO_DIR}/dist/resources"
+rm -rf "${RESOURCES_ROOT}"
+mkdir -p "${RESOURCES_ROOT}"
+
+# Run Python script to generate the pretrained model and weights files.
+# Make sure you install the tensorflowjs pip package first.
+
+python "${DEMO_DIR}/python/imdb.py" \
+ "${MODEL_TYPE}" \
+ --epochs "${TRAIN_EPOCHS}" \
+ --artifacts_dir "${RESOURCES_ROOT}"
+
+cd ${DEMO_DIR}
+yarn
+yarn build
+
+echo
+echo "-----------------------------------------------------------"
+echo "Resources written to ${RESOURCES_ROOT}."
+echo "You can now run the demo with 'yarn watch'."
+echo "-----------------------------------------------------------"
+echo
diff --git a/sentiment/index.html b/sentiment/index.html
index fde239d6d..60bd040e4 100644
--- a/sentiment/index.html
+++ b/sentiment/index.html
@@ -32,6 +32,10 @@
TensorFlow.js Layers: Sentiment Analysis Demo
+
+
+
+
Model type:
diff --git a/sentiment/index.js b/sentiment/index.js
index ffcb08a07..66f47ef34 100644
--- a/sentiment/index.js
+++ b/sentiment/index.js
@@ -19,24 +19,33 @@ import * as tf from '@tensorflow/tfjs';
import * as loader from './loader';
import * as ui from './ui';
-const HOSTED_MODEL_JSON_URL =
- 'https://storage.googleapis.com/tfjs-models/tfjs/sentiment_cnn_v1/model.json';
-const HOSTED_METADATA_JSON_URL =
- 'https://storage.googleapis.com/tfjs-models/tfjs/sentiment_cnn_v1/metadata.json';
+
+const HOSTED_URLS = {
+ model:
+ 'https://storage.googleapis.com/tfjs-models/tfjs/sentiment_cnn_v1/model.json',
+ metadata:
+ 'https://storage.googleapis.com/tfjs-models/tfjs/sentiment_cnn_v1/metadata.json'
+};
+
+const LOCAL_URLS = {
+ model: 'http://localhost:1235/resources/model.json',
+ metadata: 'http://localhost:1235/resources/metadata.json'
+};
class SentimentPredictor {
/**
* Initializes the Sentiment demo.
*/
- async init() {
- this.model = await loader.loadHostedPretrainedModel(HOSTED_MODEL_JSON_URL);
+ async init(urls) {
+ this.urls = urls;
+ this.model = await loader.loadHostedPretrainedModel(urls.model);
await this.loadMetadata();
return this;
}
async loadMetadata() {
const sentimentMetadata =
- await loader.loadHostedMetadata(HOSTED_METADATA_JSON_URL);
+ await loader.loadHostedMetadata(this.urls.metadata);
ui.showMetadata(sentimentMetadata);
this.indexFrom = sentimentMetadata['index_from'];
this.maxLen = sentimentMetadata['max_len'];
@@ -50,13 +59,11 @@ class SentimentPredictor {
// Convert to lower case and remove all punctuations.
const inputText =
text.trim().toLowerCase().replace(/(\.|\,|\!)/g, '').split(' ');
- ui.status(inputText);
// Look up word indices.
const inputBuffer = tf.buffer([1, this.maxLen], 'float32');
for (let i = 0; i < inputText.length; ++i) {
// TODO(cais): Deal with OOV words.
const word = inputText[i];
- ui.status(word);
inputBuffer.set(this.wordIndex[word] + this.indexFrom, 0, i);
}
const input = inputBuffer.toTensor();
@@ -78,9 +85,27 @@ class SentimentPredictor {
* function with the UI.
*/
async function setupSentiment() {
- const predictor = await new SentimentPredictor().init();
- ui.setPredictFunction(x => predictor.predict(x));
- ui.prepUI(x => predictor.predict(x));
+ if (await loader.urlExists(HOSTED_URLS.model)) {
+ ui.status('Model available: ' + HOSTED_URLS.model);
+ const button = document.getElementById('load-pretrained-remote');
+ button.addEventListener('click', async () => {
+ const predictor = await new SentimentPredictor().init(HOSTED_URLS);
+ ui.prepUI(x => predictor.predict(x));
+ });
+ button.style.display = 'inline-block';
+ }
+
+ if (await loader.urlExists(LOCAL_URLS.model)) {
+ ui.status('Model available: ' + LOCAL_URLS.model);
+ const button = document.getElementById('load-pretrained-local');
+ button.addEventListener('click', async () => {
+ const predictor = await new SentimentPredictor().init(LOCAL_URLS);
+ ui.prepUI(x => predictor.predict(x));
+ });
+ button.style.display = 'inline-block';
+ }
+
+ ui.status('Standing by.');
}
setupSentiment();
diff --git a/sentiment/loader.js b/sentiment/loader.js
index 968bdffc5..263729b23 100644
--- a/sentiment/loader.js
+++ b/sentiment/loader.js
@@ -18,6 +18,19 @@
import * as tf from '@tensorflow/tfjs';
import * as ui from './ui';
+/**
+ * Test whether a given URL is retrievable.
+ */
+export async function urlExists(url) {
+ ui.status('Testing url ' + url);
+ try {
+ const response = await fetch(url, {method: 'HEAD'});
+ return response.ok;
+ } catch (err) {
+ return false;
+ }
+}
+
/**
* Load pretrained model stored at a remote URL.
*
@@ -28,9 +41,13 @@ export async function loadHostedPretrainedModel(url) {
try {
const model = await tf.loadModel(url);
ui.status('Done loading pretrained model.');
+ // We can't load a model twice due to
+ // https://github.com/tensorflow/tfjs/issues/34
+ // Therefore we remove the load buttons to avoid user confusion.
+ ui.disableLoadModelButtons();
return model;
} catch (err) {
- console.log(err);
+ console.error(err);
ui.status('Loading pretrained model failed.');
}
}
@@ -48,7 +65,7 @@ export async function loadHostedMetadata(url) {
ui.status('Done loading metadata.');
return metadata;
} catch (err) {
- console.log(err);
+ console.error(err);
ui.status('Loading metadata failed.');
}
}
diff --git a/sentiment/package.json b/sentiment/package.json
index c3fd6a4e7..01eb77ae4 100644
--- a/sentiment/package.json
+++ b/sentiment/package.json
@@ -13,7 +13,7 @@
"vega-embed": "~3.0.0"
},
"scripts": {
- "watch": "NODE_ENV=development parcel --no-hmr --open index.html ",
+ "watch": "./serve.sh",
"build": "NODE_ENV=production parcel build index.html --no-minify --public-url ./"
},
"devDependencies": {
@@ -21,7 +21,8 @@
"babel-polyfill": "~6.26.0",
"babel-preset-env": "~1.6.1",
"clang-format": "~1.2.2",
- "parcel-bundler": "~1.6.2"
+ "http-server": "~0.10.0",
+ "parcel-bundler": "~1.7.0"
},
"babel": {
"presets": [
diff --git a/sentiment/python/__init__.py b/sentiment/python/__init__.py
new file mode 100644
index 000000000..636b70f0d
--- /dev/null
+++ b/sentiment/python/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/sentiment/python/imdb.py b/sentiment/python/imdb.py
new file mode 100644
index 000000000..2abe3a4e3
--- /dev/null
+++ b/sentiment/python/imdb.py
@@ -0,0 +1,258 @@
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""IMDB sentiment classification example.
+
+Based on Python Keras examples:
+ https://github.com/keras-team/keras/blob/master/examples/imdb_cnn.py
+ https://github.com/keras-team/keras/blob/master/examples/imdb_lstm.py
+
+TODO(cais): Add
+ https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py
+ once b/74429960 is fixed.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import json
+import os
+
+import keras
+import tensorflowjs as tfjs
+
+
+INDEX_FROM = 3
+# Offset in word index. Used during word index lookup and reverse lookup.
+
+
+def get_word_index(reverse=False):
+ """Get word index.
+
+ Args:
+ reverse: Reverse the index, so that the returned index is from index values
+ to words.
+
+ Returns:
+ The word index as a `dict`.
+ """
+ word_index = keras.datasets.imdb.get_word_index()
+ if reverse:
+ word_index = dict((word_index[key], key) for key in word_index)
+ return word_index
+
+
+def indices_to_words(reverse_index, indices):
+ """Convert an iterable of word indices into words.
+
+ Args:
+ reverse_index: An `dict` mapping word index (as `int`) to word (as `str`).
+ indices: An iterable of word indices.
+
+ Returns:
+ Mapped words as a `list` of `str`s.
+ """
+ return [reverse_index[i - INDEX_FROM] if i >= INDEX_FROM else 'OOV'
+ for i in indices]
+
+
+def get_imdb_data(vocabulary_size, max_len):
+ """Get IMDB data for training and validation.
+
+ Args:
+ vocabulary_size: Size of the vocabulary, as an `int`.
+ max_len: Cut text after this number of words.
+
+ Returns:
+ x_train: An int array of shape `(num_exapmles, max_len)`: index-encoded
+ sentences.
+ y_train: An int array of shape `(num_exapmles,)`: labels for the sentences.
+ x_test: Same as `x_train`, but for test.
+ y_test: Same as `y_train`, but for test.
+ """
+ print("Getting IMDB data with vocabulary_size %d" % vocabulary_size)
+ (x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(
+ num_words=vocabulary_size)
+ x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=max_len)
+ x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=max_len)
+ return x_train, y_train, x_test, y_test
+
+
+def train_model(model_type,
+ vocabulary_size,
+ embedding_size,
+ x_train,
+ y_train,
+ x_test,
+ y_test,
+ epochs,
+ batch_size):
+ """Train a model for IMDB sentiment classification.
+
+ Args:
+ model_type: Type of the model to train, as a `str`.
+ vocabulary_size: Vocabulary size.
+ embedding_size: Embedding dimensions.
+ x_train: An int array of shape `(num_exapmles, max_len)`: index-encoded
+ sentences.
+ y_train: An int array of shape `(num_exapmles,)`: labels for the sentences.
+ x_test: Same as `x_train`, but for test.
+ y_test: Same as `y_train`, but for test.
+ epochs: Number of epochs to train the model for.
+ batch_size: Batch size to use during trainng.
+
+ Returns:
+ The trained model instance.
+
+ Raises:
+ ValueError: on invalid model type.
+ """
+
+ model = keras.Sequential()
+ model.add(keras.layers.Embedding(vocabulary_size, embedding_size))
+ if model_type == 'bidirectional_lstm':
+ # TODO(cais): Uncomment the following once bug b/74429960 is fixed.
+ # model.add(keras.layers.Embedding(
+ # vocabulary_size, 128, input_length=maxlen))
+ # model.add(keras.layers.Bidirectional(
+ # keras.layers.LSTM(64,
+ # kernel_initializer='glorot_normal',
+ # recurrent_initializer ='glorot_normal')))
+ # model.add(keras.layers.Dropout(0.5))
+ raise NotImplementedError()
+ elif model_type == 'cnn':
+ model.add(keras.layers.Dropout(0.2))
+ model.add(keras.layers.Conv1D(250,
+ 3,
+ padding='valid',
+ activation='relu',
+ strides=1))
+ model.add(keras.layers.GlobalMaxPooling1D())
+ model.add(keras.layers.Dense(250, activation='relu'))
+ elif model_type == 'lstm':
+ model.add(
+ keras.layers.LSTM(
+ 128,
+ kernel_initializer='glorot_normal',
+ recurrent_initializer='glorot_normal'))
+ # TODO(cais): Remove glorot_normal and use the default orthogonal once
+ # SVD is available.
+ else:
+ raise ValueError("Invalid model type: '%s'" % model_type)
+ model.add(keras.layers.Dense(1, activation='sigmoid'))
+
+ model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
+ model.fit(x_train, y_train,
+ batch_size=batch_size,
+ epochs=epochs,
+ validation_data=[x_test, y_test])
+ return model
+
+
+def main():
+ x_train, y_train, x_test, y_test = (
+ get_imdb_data(FLAGS.vocabulary_size, FLAGS.max_len))
+
+ model = train_model(FLAGS.model_type,
+ FLAGS.vocabulary_size,
+ FLAGS.embedding_size,
+ x_train,
+ y_train,
+ x_test,
+ y_test,
+ FLAGS.epochs,
+ FLAGS.batch_size)
+
+ # Display a number test phrases and their final classification.
+ forward_index = get_word_index()
+ reverse_index = get_word_index(reverse=True)
+ print('\n')
+ for i in range(FLAGS.num_show):
+ print('--- Test Case %d ---' % (i + 1))
+ print('Sentence: "' +
+ ' '.join(indices_to_words(reverse_index, x_test[i, :])) + '"')
+ print('Truth: %d' % y_test[i])
+ print('Prediction: %s\n' % model.predict(x_test[i : i + 1, :])[0][0])
+
+ # Save metadata, including word index, INDEX_FROM and max_len and model
+ # hyperparameters.
+ metadata = {
+ 'word_index': forward_index,
+ 'index_from': INDEX_FROM,
+ 'max_len': FLAGS.max_len,
+ 'model_type': FLAGS.model_type,
+ 'vocabulary_size': FLAGS.vocabulary_size,
+ 'embedding_size': FLAGS.embedding_size,
+ 'epochs': FLAGS.epochs,
+ 'batch_size': FLAGS.batch_size,
+ }
+
+ if not os.path.isdir(FLAGS.artifacts_dir):
+ os.makedirs(FLAGS.artifacts_dir)
+ metadata_json_path = os.path.join(FLAGS.artifacts_dir, 'metadata.json')
+ json.dump(metadata, open(metadata_json_path, 'wt'))
+ print('\nSaved model metadata at: %s' % metadata_json_path)
+
+ tfjs.converters.save_keras_model(model, FLAGS.artifacts_dir)
+ print('\nSaved model artifcats in directory: %s' % FLAGS.artifacts_dir)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser('IMDB sentiment classification model')
+ parser.add_argument(
+ 'model_type',
+ type=str,
+ help='Type of model to train for the IMDB sentiment classification task: '
+ '(cnn | lstm)')
+ parser.add_argument(
+ '--vocabulary_size',
+ type=int,
+ default=20000,
+ help='Vocabulary size.')
+ parser.add_argument(
+ '--embedding_size',
+ type=int,
+ default=128,
+ help='Embedding size.')
+ parser.add_argument(
+ '--max_len',
+ type=int,
+ default=100,
+ help='Cut text after this number of words.')
+ parser.add_argument(
+ '--epochs',
+ type=int,
+ default=5,
+ help='Number of epochs to train the model for.')
+ parser.add_argument(
+ '--batch_size',
+ type=int,
+ default=32,
+ help='Batch size used during training.')
+ parser.add_argument(
+ '--num_show',
+ type=int,
+ default=5,
+ help='Number of sentences to show prediction score on after training.')
+ parser.add_argument(
+ '--artifacts_dir',
+ type=str,
+ default='/tmp/mnist.keras',
+ help='Local path for saving the TensorFlow.js artifacts.')
+
+ FLAGS, _ = parser.parse_known_args()
+ main()
diff --git a/sentiment/python/imdb_test.py b/sentiment/python/imdb_test.py
new file mode 100644
index 000000000..003acf95d
--- /dev/null
+++ b/sentiment/python/imdb_test.py
@@ -0,0 +1,96 @@
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Test for the IMDB model and supporting functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+import keras
+
+from . import imdb
+
+class IMDBTest(unittest.TestCase):
+
+ def setUp(self):
+ self._tmp_dir = tempfile.mkdtemp()
+ super(IMDBTest, self).setUp()
+
+ def tearDown(self):
+ if os.path.isdir(self._tmp_dir):
+ shutil.rmtree(self._tmp_dir)
+ super(IMDBTest, self).tearDown()
+
+ def testGetWordIndexForward(self):
+ word_index = imdb.get_word_index()
+ self.assertGreater(word_index['bar'], 0)
+
+ def testIndicesToWordsReverse(self):
+ forward_index = imdb.get_word_index()
+ reverse_index = imdb.get_word_index(reverse=True)
+ self.assertEqual('bar', reverse_index[forward_index['bar']])
+
+ def testIndicesToWords(self):
+ forward_index = imdb.get_word_index()
+ reverse_index = imdb.get_word_index(reverse=True)
+ indices = [forward_index[word] + imdb.INDEX_FROM
+ for word in ['one', 'two', 'three']]
+ self.assertEqual(['one', 'two', 'three'],
+ imdb.indices_to_words(reverse_index, indices))
+
+ def testTrainLSTMModel(self):
+ data_size = 10
+ x_train = np.random.randint(0, 100, (data_size,))
+ y_train = np.random.randint(0, 2, (data_size,))
+ x_test = np.random.randint(0, 100, (data_size,))
+ y_test = np.random.randint(0, 2, (data_size,))
+
+ vocabulary_size = 100
+ embedding_size = 32
+ epochs = 1
+ batch_size = data_size
+ model = imdb.train_model(
+ 'lstm', vocabulary_size, embedding_size,
+ x_train, y_train, x_test, y_test,
+ epochs, batch_size)
+ self.assertTrue(model.layers)
+
+ def testTrainModelWithInvalidModelTypeRaisesError(self):
+ data_size = 10
+ x_train = np.random.randint(0, 100, (data_size,))
+ y_train = np.random.randint(0, 2, (data_size,))
+ x_test = np.random.randint(0, 100, (data_size,))
+ y_test = np.random.randint(0, 2, (data_size,))
+
+ vocabulary_size = 100
+ embedding_size = 32
+ epochs = 1
+ batch_size = data_size
+ with self.assertRaises(ValueError):
+ imdb.train_model(
+ 'nonsensical_model_type', vocabulary_size, embedding_size,
+ x_train, y_train, x_test, y_test,
+ epochs, batch_size)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/sentiment/serve.sh b/sentiment/serve.sh
new file mode 100755
index 000000000..e5ede63ed
--- /dev/null
+++ b/sentiment/serve.sh
@@ -0,0 +1,42 @@
+#!/usr/bin/env bash
+
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+# This script starts two HTTP servers on different ports:
+# * Port 1234 (using parcel) serves HTML and JavaScript.
+# * Port 1235 (using http-server) serves pretrained model resources.
+#
+# The reason for this arrangement is that Parcel currently has a limitation that
+# prevents it from serving the pretrained models; see
+# https://github.com/parcel-bundler/parcel/issues/1098. Once that issue is
+# resolved, a single Parcel server will be sufficient.
+
+NODE_ENV=development
+RESOURCE_PORT=1235
+
+# Ensure that http-server is available
+yarn
+
+echo Starting the pretrained model server...
+node_modules/http-server/bin/http-server dist --cors -p "${RESOURCE_PORT}" > /dev/null & HTTP_SERVER_PID=$!
+
+echo Starting the example html/js server...
+# This uses port 1234 by default.
+node_modules/parcel-bundler/bin/cli.js serve -d dist --open --no-hmr --public-url / index.html
+
+# When the Parcel server exits, kill the http-server too.
+kill $HTTP_SERVER_PID
+
diff --git a/sentiment/ui.js b/sentiment/ui.js
index d037cc57c..52158e29d 100644
--- a/sentiment/ui.js
+++ b/sentiment/ui.js
@@ -23,6 +23,7 @@ const exampleReviews = {
};
export function status(statusText) {
+ console.log(statusText);
document.getElementById('status').textContent = statusText;
}
@@ -36,6 +37,7 @@ export function showMetadata(sentimentMetadataJSON) {
}
export function prepUI(predict) {
+ setPredictFunction(predict);
const testExampleSelect = document.getElementById('test-example-select');
testExampleSelect.addEventListener('change', () => {
setReviewText(exampleReviews[testExampleSelect.value], predict);
@@ -62,7 +64,12 @@ function setReviewText(text, predict) {
doPredict(predict);
}
-export function setPredictFunction(predict) {
+function setPredictFunction(predict) {
const reviewText = document.getElementById('review-text');
reviewText.addEventListener('input', () => doPredict(predict));
}
+
+export function disableLoadModelButtons() {
+ document.getElementById('load-pretrained-remote').style.display = 'none';
+ document.getElementById('load-pretrained-local').style.display = 'none';
+}
diff --git a/sentiment/yarn.lock b/sentiment/yarn.lock
index 2e61d4f3d..87bfdaf3c 100644
--- a/sentiment/yarn.lock
+++ b/sentiment/yarn.lock
@@ -705,6 +705,10 @@ binary-extensions@^1.0.0:
version "1.11.0"
resolved "https://registry.yarnpkg.com/binary-extensions/-/binary-extensions-1.11.0.tgz#46aa1751fb6a2f93ee5e689bb1087d4b14c6c205"
+bindings@~1.2.1:
+ version "1.2.1"
+ resolved "https://registry.yarnpkg.com/bindings/-/bindings-1.2.1.tgz#14ad6113812d2d37d72e67b4cacb4bb726505f11"
+
block-stream@*:
version "0.0.9"
resolved "https://registry.yarnpkg.com/block-stream/-/block-stream-0.0.9.tgz#13ebfe778a03205cfe03751481ebb4b3300c126a"
@@ -766,12 +770,6 @@ brorand@^1.0.1:
version "1.1.0"
resolved "https://registry.yarnpkg.com/brorand/-/brorand-1.1.0.tgz#12c25efe40a45e3c323eb8675a0a0ce57b22371f"
-browser-resolve@^1.11.2:
- version "1.11.2"
- resolved "https://registry.yarnpkg.com/browser-resolve/-/browser-resolve-1.11.2.tgz#8ff09b0a2c421718a1051c260b32e48f442938ce"
- dependencies:
- resolve "1.1.7"
-
browserify-aes@^1.0.0, browserify-aes@^1.0.4:
version "1.1.1"
resolved "https://registry.yarnpkg.com/browserify-aes/-/browserify-aes-1.1.1.tgz#38b7ab55edb806ff2dcda1a7f1620773a477c49f"
@@ -1066,6 +1064,10 @@ colormin@^1.0.5:
css-color-names "0.0.4"
has "^1.0.1"
+colors@1.0.3:
+ version "1.0.3"
+ resolved "https://registry.yarnpkg.com/colors/-/colors-1.0.3.tgz#0433f44d809680fdeb60ed260f1b0c262e82a40b"
+
colors@~1.1.2:
version "1.1.2"
resolved "https://registry.yarnpkg.com/colors/-/colors-1.1.2.tgz#168a4701756b6a7f51a12ce0c97bfa28c084ed63"
@@ -1145,6 +1147,10 @@ core-util-is@1.0.2, core-util-is@~1.0.0:
version "1.0.2"
resolved "https://registry.yarnpkg.com/core-util-is/-/core-util-is-1.0.2.tgz#b5fd54220aa2bc5ab57aab7140c940754503c1a7"
+corser@~2.0.0:
+ version "2.0.1"
+ resolved "https://registry.yarnpkg.com/corser/-/corser-2.0.1.tgz#8eda252ecaab5840dcd975ceb90d9370c819ff87"
+
create-ecdh@^4.0.0:
version "4.0.0"
resolved "https://registry.yarnpkg.com/create-ecdh/-/create-ecdh-4.0.0.tgz#888c723596cdf7612f6498233eebd7a35301737d"
@@ -1433,6 +1439,13 @@ date-now@^0.1.4:
version "0.1.4"
resolved "https://registry.yarnpkg.com/date-now/-/date-now-0.1.4.tgz#eaf439fd4d4848ad74e5cc7dbef200672b9e345b"
+deasync@^0.1.12:
+ version "0.1.12"
+ resolved "https://registry.yarnpkg.com/deasync/-/deasync-0.1.12.tgz#0159492a4133ab301d6c778cf01e74e63b10e549"
+ dependencies:
+ bindings "~1.2.1"
+ nan "^2.0.7"
+
debug@2.6.9, debug@^2.2.0, debug@^2.3.3, debug@^2.6.8:
version "2.6.9"
resolved "https://registry.yarnpkg.com/debug/-/debug-2.6.9.tgz#5d128515df134ff327e90a4c93f4e077a536341f"
@@ -1585,6 +1598,15 @@ ecc-jsbn@~0.1.1:
dependencies:
jsbn "~0.1.0"
+ecstatic@^2.0.0:
+ version "2.2.1"
+ resolved "https://registry.yarnpkg.com/ecstatic/-/ecstatic-2.2.1.tgz#b5087fad439dd9dd49d31e18131454817fe87769"
+ dependencies:
+ he "^1.1.1"
+ mime "^1.2.11"
+ minimist "^1.1.0"
+ url-join "^2.0.2"
+
editorconfig@^0.13.2:
version "0.13.3"
resolved "https://registry.yarnpkg.com/editorconfig/-/editorconfig-0.13.3.tgz#e5219e587951d60958fd94ea9a9a008cdeff1b34"
@@ -1696,6 +1718,10 @@ etag@~1.8.1:
version "1.8.1"
resolved "https://registry.yarnpkg.com/etag/-/etag-1.8.1.tgz#41ae2eeb65efa62268aebfea83ac7d79299b0887"
+eventemitter3@1.x.x:
+ version "1.2.0"
+ resolved "https://registry.yarnpkg.com/eventemitter3/-/eventemitter3-1.2.0.tgz#1c86991d816ad1e504750e73874224ecf3bec508"
+
events@^1.0.0:
version "1.1.1"
resolved "https://registry.yarnpkg.com/events/-/events-1.1.1.tgz#9ebdb7635ad099c70dcc4c2a1f5004288e8bd924"
@@ -2033,6 +2059,10 @@ hawk@3.1.3, hawk@~3.1.3:
hoek "2.x.x"
sntp "1.x.x"
+he@^1.1.1:
+ version "1.1.1"
+ resolved "https://registry.yarnpkg.com/he/-/he-1.1.1.tgz#93410fd21b009735151f8868c2f271f3427e23fd"
+
hmac-drbg@^1.0.0:
version "1.0.1"
resolved "https://registry.yarnpkg.com/hmac-drbg/-/hmac-drbg-1.0.1.tgz#d2745701025a6c775a6c545793ed502fc0c649a1"
@@ -2060,7 +2090,7 @@ html-comment-regex@^1.1.0:
version "1.1.1"
resolved "https://registry.yarnpkg.com/html-comment-regex/-/html-comment-regex-1.1.1.tgz#668b93776eaae55ebde8f3ad464b307a4963625e"
-htmlnano@^0.1.6:
+htmlnano@^0.1.7:
version "0.1.7"
resolved "https://registry.yarnpkg.com/htmlnano/-/htmlnano-0.1.7.tgz#1751937a05f122a3248dba1c63edf01f6d36cb84"
dependencies:
@@ -2091,6 +2121,26 @@ http-errors@~1.6.2:
setprototypeof "1.0.3"
statuses ">= 1.3.1 < 2"
+http-proxy@^1.8.1:
+ version "1.16.2"
+ resolved "https://registry.yarnpkg.com/http-proxy/-/http-proxy-1.16.2.tgz#06dff292952bf64dbe8471fa9df73066d4f37742"
+ dependencies:
+ eventemitter3 "1.x.x"
+ requires-port "1.x.x"
+
+http-server@~0.10.0:
+ version "0.10.0"
+ resolved "https://registry.yarnpkg.com/http-server/-/http-server-0.10.0.tgz#b2a446b16a9db87ed3c622ba9beb1b085b1234a7"
+ dependencies:
+ colors "1.0.3"
+ corser "~2.0.0"
+ ecstatic "^2.0.0"
+ http-proxy "^1.8.1"
+ opener "~1.4.0"
+ optimist "0.6.x"
+ portfinder "^1.0.13"
+ union "~0.4.3"
+
http-signature@~1.1.0:
version "1.1.1"
resolved "https://registry.yarnpkg.com/http-signature/-/http-signature-1.1.1.tgz#df72e267066cd0ac67fb76adf8e134a8fbcf91bf"
@@ -2610,6 +2660,10 @@ mime@1.4.1:
version "1.4.1"
resolved "https://registry.yarnpkg.com/mime/-/mime-1.4.1.tgz#121f9ebc49e3766f311a76e1fa1c8003c4b03aa6"
+mime@^1.2.11:
+ version "1.6.0"
+ resolved "https://registry.yarnpkg.com/mime/-/mime-1.6.0.tgz#32cd9e5c64553bd58d19a568af452acff04981b1"
+
mimic-fn@^1.0.0:
version "1.2.0"
resolved "https://registry.yarnpkg.com/mimic-fn/-/mimic-fn-1.2.0.tgz#820c86a39334640e99516928bd03fca88057d022"
@@ -2632,10 +2686,14 @@ minimist@0.0.8:
version "0.0.8"
resolved "https://registry.yarnpkg.com/minimist/-/minimist-0.0.8.tgz#857fcabfc3397d2625b8228262e86aa7a011b05d"
-minimist@^1.1.3, minimist@^1.2.0:
+minimist@^1.1.0, minimist@^1.1.3, minimist@^1.2.0:
version "1.2.0"
resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.0.tgz#a35008b20f41383eec1fb914f4cd5df79a264284"
+minimist@~0.0.1:
+ version "0.0.10"
+ resolved "https://registry.yarnpkg.com/minimist/-/minimist-0.0.10.tgz#de3f98543dbf96082be48ad1a0c7cda836301dcf"
+
mixin-deep@^1.2.0:
version "1.3.1"
resolved "https://registry.yarnpkg.com/mixin-deep/-/mixin-deep-1.3.1.tgz#a49e7268dce1a0d9698e45326c5626df3543d0fe"
@@ -2643,7 +2701,7 @@ mixin-deep@^1.2.0:
for-in "^1.0.2"
is-extendable "^1.0.1"
-"mkdirp@>=0.5 0", mkdirp@^0.5.1, mkdirp@~0.5.0, mkdirp@~0.5.1:
+mkdirp@0.5.x, "mkdirp@>=0.5 0", mkdirp@^0.5.1, mkdirp@~0.5.0, mkdirp@~0.5.1:
version "0.5.1"
resolved "https://registry.yarnpkg.com/mkdirp/-/mkdirp-0.5.1.tgz#30057438eac6cf7f8c4767f38648d6697d75c903"
dependencies:
@@ -2653,7 +2711,7 @@ ms@2.0.0:
version "2.0.0"
resolved "https://registry.yarnpkg.com/ms/-/ms-2.0.0.tgz#5608aeadfc00be6c2901df5f9861788de0d597c8"
-nan@^2.3.0, nan@^2.4.0:
+nan@^2.0.7, nan@^2.3.0, nan@^2.4.0:
version "2.10.0"
resolved "https://registry.yarnpkg.com/nan/-/nan-2.10.0.tgz#96d0cd610ebd58d4b4de9cc0c6828cda99c7548f"
@@ -2860,12 +2918,23 @@ once@^1.3.0, once@^1.3.3:
dependencies:
wrappy "1"
+opener@~1.4.0:
+ version "1.4.3"
+ resolved "https://registry.yarnpkg.com/opener/-/opener-1.4.3.tgz#5c6da2c5d7e5831e8ffa3964950f8d6674ac90b8"
+
opn@^5.1.0:
version "5.3.0"
resolved "https://registry.yarnpkg.com/opn/-/opn-5.3.0.tgz#64871565c863875f052cfdf53d3e3cb5adb53b1c"
dependencies:
is-wsl "^1.1.0"
+optimist@0.6.x:
+ version "0.6.1"
+ resolved "https://registry.yarnpkg.com/optimist/-/optimist-0.6.1.tgz#da3ea74686fa21a19a111c326e90eb15a0196686"
+ dependencies:
+ minimist "~0.0.1"
+ wordwrap "~0.0.2"
+
optionator@^0.8.1:
version "0.8.2"
resolved "https://registry.yarnpkg.com/optionator/-/optionator-0.8.2.tgz#364c5e409d3f4d6301d6c0b4c05bba50180aeb64"
@@ -2938,9 +3007,9 @@ pako@~1.0.5:
version "1.0.6"
resolved "https://registry.yarnpkg.com/pako/-/pako-1.0.6.tgz#0101211baa70c4bca4a0f63f2206e97b7dfaf258"
-parcel-bundler@~1.6.2:
- version "1.6.2"
- resolved "https://registry.yarnpkg.com/parcel-bundler/-/parcel-bundler-1.6.2.tgz#e415c4993b6f4c48cd410427594ce80bbb7d90f1"
+parcel-bundler@~1.7.0:
+ version "1.7.0"
+ resolved "https://registry.yarnpkg.com/parcel-bundler/-/parcel-bundler-1.7.0.tgz#8c2512615fd602d2f39bd97bfd128f8fe524b321"
dependencies:
babel-code-frame "^6.26.0"
babel-core "^6.25.0"
@@ -2953,7 +3022,6 @@ parcel-bundler@~1.6.2:
babel-types "^6.26.0"
babylon "^6.17.4"
babylon-walk "^1.0.2"
- browser-resolve "^1.11.2"
browserslist "^2.11.2"
chalk "^2.1.0"
chokidar "^2.0.1"
@@ -2961,12 +3029,13 @@ parcel-bundler@~1.6.2:
commander "^2.11.0"
cross-spawn "^6.0.4"
cssnano "^3.10.0"
+ deasync "^0.1.12"
dotenv "^5.0.0"
filesize "^3.6.0"
get-port "^3.2.0"
glob "^7.1.2"
grapheme-breaker "^0.3.2"
- htmlnano "^0.1.6"
+ htmlnano "^0.1.7"
is-url "^1.2.2"
js-yaml "^3.10.0"
json5 "^0.5.1"
@@ -2976,13 +3045,12 @@ parcel-bundler@~1.6.2:
node-libs-browser "^2.0.0"
opn "^5.1.0"
physical-cpu-count "^2.0.0"
- postcss "^6.0.10"
+ postcss "^6.0.19"
postcss-value-parser "^3.3.0"
posthtml "^0.11.2"
posthtml-parser "^0.4.0"
posthtml-render "^1.1.0"
resolve "^1.4.0"
- sanitize-filename "^1.6.1"
semver "^5.4.1"
serialize-to-js "^1.1.1"
serve-static "^1.12.4"
@@ -3089,6 +3157,14 @@ pinkie@^2.0.0:
version "2.0.4"
resolved "https://registry.yarnpkg.com/pinkie/-/pinkie-2.0.4.tgz#72556b80cfa0d48a974e80e77248e80ed4f7f870"
+portfinder@^1.0.13:
+ version "1.0.13"
+ resolved "https://registry.yarnpkg.com/portfinder/-/portfinder-1.0.13.tgz#bb32ecd87c27104ae6ee44b5a3ccbf0ebb1aede9"
+ dependencies:
+ async "^1.5.2"
+ debug "^2.2.0"
+ mkdirp "0.5.x"
+
posix-character-classes@^0.1.0:
version "0.1.1"
resolved "https://registry.yarnpkg.com/posix-character-classes/-/posix-character-classes-0.1.1.tgz#01eac0fe3b5af71a2a6c02feabb8c1fef7e00eab"
@@ -3304,9 +3380,9 @@ postcss@^5.0.10, postcss@^5.0.11, postcss@^5.0.12, postcss@^5.0.13, postcss@^5.0
source-map "^0.5.6"
supports-color "^3.2.3"
-postcss@^6.0.10:
- version "6.0.20"
- resolved "https://registry.yarnpkg.com/postcss/-/postcss-6.0.20.tgz#686107e743a12d5530cb68438c590d5b2bf72c3c"
+postcss@^6.0.19:
+ version "6.0.21"
+ resolved "https://registry.yarnpkg.com/postcss/-/postcss-6.0.21.tgz#8265662694eddf9e9a5960db6da33c39e4cd069d"
dependencies:
chalk "^2.3.2"
source-map "^0.6.1"
@@ -3397,6 +3473,10 @@ q@^1.1.2:
version "1.5.1"
resolved "https://registry.yarnpkg.com/q/-/q-1.5.1.tgz#7e32f75b41381291d04611f1bf14109ac00651d7"
+qs@~2.3.3:
+ version "2.3.3"
+ resolved "https://registry.yarnpkg.com/qs/-/qs-2.3.3.tgz#e9e85adbe75da0bbe4c8e0476a086290f863b404"
+
qs@~6.4.0:
version "6.4.0"
resolved "https://registry.yarnpkg.com/qs/-/qs-6.4.0.tgz#13e26d28ad6b0ffaa91312cd3bf708ed351e7233"
@@ -3598,14 +3678,14 @@ require-main-filename@^1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/require-main-filename/-/require-main-filename-1.0.1.tgz#97f717b69d48784f5f526a6c5aa8ffdda055a4d1"
+requires-port@1.x.x:
+ version "1.0.0"
+ resolved "https://registry.yarnpkg.com/requires-port/-/requires-port-1.0.0.tgz#925d2601d39ac485e091cf0da5c6e694dc3dcaff"
+
resolve-url@^0.2.1:
version "0.2.1"
resolved "https://registry.yarnpkg.com/resolve-url/-/resolve-url-0.2.1.tgz#2c637fe77c893afd2a663fe21aa9080068e2052a"
-resolve@1.1.7:
- version "1.1.7"
- resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.1.7.tgz#203114d82ad2c5ed9e8e0411b3932875e889e97b"
-
resolve@^1.1.5, resolve@^1.1.6, resolve@^1.4.0:
version "1.6.0"
resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.6.0.tgz#0fbd21278b27b4004481c395349e7aba60a9ff5c"
@@ -3649,12 +3729,6 @@ safer-eval@^1.2.3:
dependencies:
clones "^1.1.0"
-sanitize-filename@^1.6.1:
- version "1.6.1"
- resolved "https://registry.yarnpkg.com/sanitize-filename/-/sanitize-filename-1.6.1.tgz#612da1c96473fa02dccda92dcd5b4ab164a6772a"
- dependencies:
- truncate-utf8-bytes "^1.0.0"
-
sax@~1.2.1, sax@~1.2.4:
version "1.2.4"
resolved "https://registry.yarnpkg.com/sax/-/sax-1.2.4.tgz#2816234e2378bddc4e5354fab5caa895df7100d9"
@@ -4138,12 +4212,6 @@ trim-right@^1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/trim-right/-/trim-right-1.0.1.tgz#cb2e1203067e0c8de1f614094b9fe45704ea6003"
-truncate-utf8-bytes@^1.0.0:
- version "1.0.2"
- resolved "https://registry.yarnpkg.com/truncate-utf8-bytes/-/truncate-utf8-bytes-1.0.2.tgz#405923909592d56f78a5818434b0b78489ca5f2b"
- dependencies:
- utf8-byte-length "^1.0.1"
-
tslib@^1.9.0:
version "1.9.0"
resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.9.0.tgz#e37a86fda8cbbaf23a057f473c9f4dc64e5fc2e8"
@@ -4210,6 +4278,12 @@ union-value@^1.0.0:
is-extendable "^0.1.1"
set-value "^0.4.3"
+union@~0.4.3:
+ version "0.4.6"
+ resolved "https://registry.yarnpkg.com/union/-/union-0.4.6.tgz#198fbdaeba254e788b0efcb630bc11f24a2959e0"
+ dependencies:
+ qs "~2.3.3"
+
uniq@^1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/uniq/-/uniq-1.0.1.tgz#b31c5ae8254844a3a8281541ce2b04b865a734ff"
@@ -4243,6 +4317,10 @@ urix@^0.1.0:
version "0.1.0"
resolved "https://registry.yarnpkg.com/urix/-/urix-0.1.0.tgz#da937f7a62e21fec1fd18d49b35c2935067a6c72"
+url-join@^2.0.2:
+ version "2.0.5"
+ resolved "https://registry.yarnpkg.com/url-join/-/url-join-2.0.5.tgz#5af22f18c052a000a48d7b82c5e9c2e2feeda728"
+
url@^0.11.0:
version "0.11.0"
resolved "https://registry.yarnpkg.com/url/-/url-0.11.0.tgz#3838e97cfc60521eb73c525a8e55bfdd9e2e28f1"
@@ -4256,10 +4334,6 @@ use@^3.1.0:
dependencies:
kind-of "^6.0.2"
-utf8-byte-length@^1.0.1:
- version "1.0.4"
- resolved "https://registry.yarnpkg.com/utf8-byte-length/-/utf8-byte-length-1.0.4.tgz#f45f150c4c66eee968186505ab93fcbb8ad6bf61"
-
util-deprecate@~1.0.1:
version "1.0.2"
resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf"
@@ -4589,6 +4663,10 @@ window-size@^0.2.0:
version "0.2.0"
resolved "https://registry.yarnpkg.com/window-size/-/window-size-0.2.0.tgz#b4315bb4214a3d7058ebeee892e13fa24d98b075"
+wordwrap@~0.0.2:
+ version "0.0.3"
+ resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-0.0.3.tgz#a3d5da6cd5c0bc0008d37234bbaf1bed63059107"
+
wordwrap@~1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-1.0.0.tgz#27584810891456a4171c8d0226441ade90cbcaeb"
diff --git a/translation/build-resources.sh b/translation/build-resources.sh
new file mode 100755
index 000000000..162a363b4
--- /dev/null
+++ b/translation/build-resources.sh
@@ -0,0 +1,84 @@
+#!/usr/bin/env bash
+
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+# Builds resources for the sequence-to-sequence English-French translation demo.
+# Note this is not necessary to run the demo, because we already provide hosted
+# pre-built resources.
+# Usage example: do this in the 'translation' directory:
+# ./build.sh ~/ml-data/fra-eng/fra.txt
+#
+# You can specify the number of training epochs by using the --epochs flag.
+# For example:
+# ./build-resources.sh ~/ml-data/fra-eng/fra.txt --epochs 10
+
+set -e
+
+DEMO_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+TRAIN_DATA_PATH="$1"
+if [[ -z "${TRAIN_DATA_PATH}" ]]; then
+ echo "ERROR: TRAIN_DATA_PATH is not specified."
+ echo "You can download the training data with a command such as:"
+ echo " wget http://www.manythings.org/anki/fra-eng.zip"
+ exit 1
+fi
+shift 1
+
+if [[ ! -f ${TRAIN_DATA_PATH} ]]; then
+ echo "ERROR: Cannot find training data at path '${TRAIN_DATA_PATH}'"
+ exit 1
+fi
+
+TRAIN_EPOCHS=100
+while true; do
+ if [[ "$1" == "--epochs" ]]; then
+ TRAIN_EPOCHS=$2
+ shift 2
+ elif [[ -z "$1" ]]; then
+ break
+ else
+ echo "ERROR: Unrecognized argument: $1"
+ exit 1
+ fi
+done
+
+RESOURCES_ROOT="${DEMO_DIR}/dist/resources"
+rm -rf "${RESOURCES_ROOT}"
+mkdir -p "${RESOURCES_ROOT}"
+
+# Run Python script to generate the pretrained model and weights files.
+# Make sure you install the tensorflowjs pip package first.
+
+python "${DEMO_DIR}/python/translation.py" \
+ "${TRAIN_DATA_PATH}" \
+ --recurrent_initializer glorot_uniform \
+ --artifacts_dir "${RESOURCES_ROOT}" \
+ --epochs "${TRAIN_EPOCHS}"
+# TODO(cais): This --recurrent_initializer is a workaround for the limitation
+# in TensorFlow.js Layers that the default recurrent initializer "Orthogonal" is
+# currently not supported. Remove this once "Orthogonal" becomes available.
+
+cd ${DEMO_DIR}
+yarn
+yarn build
+
+echo
+echo "-----------------------------------------------------------"
+echo "Resources written to ${RESOURCES_ROOT}."
+echo "You can now run the demo with 'yarn watch'."
+echo "-----------------------------------------------------------"
+echo
diff --git a/translation/index.html b/translation/index.html
index 2e0cb35f5..e42756c8e 100644
--- a/translation/index.html
+++ b/translation/index.html
@@ -36,6 +36,10 @@