diff --git a/deploy.sh b/deploy.sh
index b96f8f19e..5ffb8766f 100755
--- a/deploy.sh
+++ b/deploy.sh
@@ -1,5 +1,5 @@
 #!/usr/bin/env bash
-# Copyright 2018 Google Inc. All Rights Reserved.
+# Copyright 2018 Google LLC. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/iris/build-resources.sh b/iris/build-resources.sh
new file mode 100755
index 000000000..708383b23
--- /dev/null
+++ b/iris/build-resources.sh
@@ -0,0 +1,61 @@
+#!/usr/bin/env bash
+
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+# Builds resources for the Iris demo.
+# Note this is not necessary to run the demo, because we already provide hosted
+# pre-built resources.
+# Usage example: do this from the 'iris' directory:
+#   ./build-resources.sh
+
+set -e
+
+DEMO_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+TRAIN_EPOCHS=100
+while true; do
+  if [[ "$1" == "--epochs" ]]; then
+    TRAIN_EPOCHS=$2
+    shift 2
+  elif [[ -z "$1" ]]; then
+    break
+  else
+    echo "ERROR: Unrecognized argument: $1"
+    exit 1
+  fi
+done
+
+RESOURCES_ROOT="${DEMO_DIR}/dist/resources"
+rm -rf "${RESOURCES_ROOT}"
+mkdir -p "${RESOURCES_ROOT}"
+
+# Run the Python script to generate the pretrained model and weights files.
+# Make sure you have installed the tensorflowjs pip package first.
+
+python "${DEMO_DIR}/python/iris.py" \
+    --epochs "${TRAIN_EPOCHS}" \
+    --artifacts_dir "${RESOURCES_ROOT}"
+
+cd "${DEMO_DIR}"
+yarn
+yarn build
+
+echo
+echo "-----------------------------------------------------------"
+echo "Resources written to ${RESOURCES_ROOT}."
+echo "You can now run the demo with 'yarn watch'."
+echo "-----------------------------------------------------------"
+echo
diff --git a/iris/index.html b/iris/index.html
index 47ab5a3af..338b9f745 100644
--- a/iris/index.html
+++ b/iris/index.html
@@ -117,7 +117,8 @@

   <h1>TensorFlow.js Layers: Iris Demo</h1>
 
-  <button id="load-pretrained">Load pretrained model</button>
+  <button id="load-pretrained-remote" style="display:none">Load hosted pretrained model</button>
+  <button id="load-pretrained-local" style="display:none">Load local pretrained model</button>
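Note: the index.js changes below gate each of these buttons on a cheap HEAD probe of the corresponding model.json URL, and reveal a button only if its probe succeeds. A minimal sketch of that probe-then-reveal pattern, assuming the demo's `loader` module and the page-level `model` variable introduced in this diff (the actual wiring lives in the iris() function below):

// Sketch (not the exact demo code): probe a model.json URL with a HEAD
// request and reveal the matching hidden button only when the probe succeeds.
async function enableLoadButtonIfAvailable(buttonId, modelJsonUrl) {
  // loader.urlExists() below performs the same check via
  // fetch(url, {method: 'HEAD'}) and response.ok.
  if (!(await loader.urlExists(modelJsonUrl))) {
    return;  // The button keeps its style="display:none" from the HTML above.
  }
  const button = document.getElementById(buttonId);
  button.addEventListener('click', async () => {
    model = await loader.loadHostedPretrainedModel(modelJsonUrl);
  });
  button.style.display = 'inline-block';
}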
diff --git a/iris/index.js b/iris/index.js index 9ded1c69e..ba9750562 100644 --- a/iris/index.js +++ b/iris/index.js @@ -17,28 +17,12 @@ import * as tf from '@tensorflow/tfjs'; -import {getIrisData, IRIS_CLASSES, IRIS_NUM_CLASSES} from './data'; -import {clearEvaluateTable, getManualInputData, loadTrainParametersFromUI, plotAccuracies, plotLosses, renderEvaluateTable, renderLogitsForManualInput, setManualInputWinnerMessage, status, wireUpEvaluateTableCallbacks} from './ui'; +import * as data from './data'; +import * as loader from './loader'; +import * as ui from './ui'; let model; -/** - * Load pretrained model stored at a remote URL. - * - * @return An instance of `tf.Model` with model topology and weights loaded. - */ -async function loadHostedPretrainedModel() { - const HOSTED_MODEL_JSON_URL = - 'https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json'; - status('Loading pretrained model from ' + HOSTED_MODEL_JSON_URL); - try { - model = await tf.loadModel(HOSTED_MODEL_JSON_URL); - status('Done loading pretrained model.'); - } catch (err) { - status('Loading pretrained model failed.'); - } -} - /** * Train a `tf.Model` to recognize Iris flower type. * @@ -53,9 +37,9 @@ async function loadHostedPretrainedModel() { * @returns The trained `tf.Model` instance. */ async function trainModel(xTrain, yTrain, xTest, yTest) { - status('Training model... Please wait.'); + ui.status('Training model... Please wait.'); - const params = loadTrainParametersFromUI(); + const params = ui.loadTrainParametersFromUI(); // Define the topology of the model: two dense layers. const model = tf.sequential(); @@ -79,8 +63,8 @@ async function trainModel(xTrain, yTrain, xTest, yTest) { callbacks: { onEpochEnd: async (epoch, logs) => { // Plot the loss and accuracy values at the end of every training epoch. - plotLosses(lossValues, epoch, logs.loss, logs.val_loss); - plotAccuracies(accuracyValues, epoch, logs.acc, logs.val_acc); + ui.plotLosses(lossValues, epoch, logs.loss, logs.val_loss); + ui.plotAccuracies(accuracyValues, epoch, logs.acc, logs.val_acc); // Await web page DOM to refresh for the most recently plotted values. await tf.nextFrame(); @@ -88,7 +72,7 @@ async function trainModel(xTrain, yTrain, xTest, yTest) { } }); - status('Model training complete.'); + ui.status('Model training complete.'); return model; } @@ -99,7 +83,7 @@ async function trainModel(xTrain, yTrain, xTest, yTest) { */ async function predictOnManualInput(model) { if (model == null) { - setManualInputWinnerMessage('ERROR: Please load or train model first.'); + ui.setManualInputWinnerMessage('ERROR: Please load or train model first.'); return; } @@ -107,7 +91,7 @@ async function predictOnManualInput(model) { // `predict` call is released at the end. tf.tidy(() => { // Prepare input data as a 2D `tf.Tensor`. 
- const inputData = getManualInputData(); + const inputData = ui.getManualInputData(); const input = tf.tensor2d([inputData], [1, 4]); // Call `model.predict` to get the prediction output as probabilities for @@ -115,9 +99,9 @@ async function predictOnManualInput(model) { const predictOut = model.predict(input); const logits = Array.from(predictOut.dataSync()); - const winner = IRIS_CLASSES[predictOut.argMax(-1).dataSync()[0]]; - setManualInputWinnerMessage(winner); - renderLogitsForManualInput(logits); + const winner = data.IRIS_CLASSES[predictOut.argMax(-1).dataSync()[0]]; + ui.setManualInputWinnerMessage(winner); + ui.renderLogitsForManualInput(logits); }); } @@ -130,24 +114,29 @@ async function predictOnManualInput(model) { * [numTestExamples, 3]. */ async function evaluateModelOnTestData(model, xTest, yTest) { - clearEvaluateTable(); + ui.clearEvaluateTable(); tf.tidy(() => { const xData = xTest.dataSync(); const yTrue = yTest.argMax(-1).dataSync(); const predictOut = model.predict(xTest); const yPred = predictOut.argMax(-1); - renderEvaluateTable(xData, yTrue, yPred.dataSync(), predictOut.dataSync()); + ui.renderEvaluateTable( + xData, yTrue, yPred.dataSync(), predictOut.dataSync()); }); predictOnManualInput(model); } +const LOCAL_MODEL_JSON_URL = 'http://localhost:1235/resources/model.json'; +const HOSTED_MODEL_JSON_URL = + 'https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json'; + /** * The main function of the Iris demo. */ async function iris() { - const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15); + const [xTrain, yTrain, xTest, yTest] = data.getIrisData(0.15); document.getElementById('train-from-scratch') .addEventListener('click', async () => { @@ -155,14 +144,32 @@ async function iris() { evaluateModelOnTestData(model, xTest, yTest); }); - document.getElementById('load-pretrained') - .addEventListener('click', async () => { - clearEvaluateTable(); - await loadHostedPretrainedModel(); - predictOnManualInput(model); - }); + if (await loader.urlExists(HOSTED_MODEL_JSON_URL)) { + ui.status('Model available: ' + HOSTED_MODEL_JSON_URL); + const button = document.getElementById('load-pretrained-remote'); + button.addEventListener('click', async () => { + ui.clearEvaluateTable(); + model = await loader.loadHostedPretrainedModel(HOSTED_MODEL_JSON_URL); + predictOnManualInput(model); + }); + // button.style.visibility = 'visible'; + button.style.display = 'inline-block'; + } + + if (await loader.urlExists(LOCAL_MODEL_JSON_URL)) { + ui.status('Model available: ' + LOCAL_MODEL_JSON_URL); + const button = document.getElementById('load-pretrained-local'); + button.addEventListener('click', async () => { + ui.clearEvaluateTable(); + model = await loader.loadHostedPretrainedModel(LOCAL_MODEL_JSON_URL); + predictOnManualInput(model); + }); + // button.style.visibility = 'visible'; + button.style.display = 'inline-block'; + } - wireUpEvaluateTableCallbacks(() => predictOnManualInput(model)); + ui.status('Standing by.'); + ui.wireUpEvaluateTableCallbacks(() => predictOnManualInput(model)); } iris(); diff --git a/iris/loader.js b/iris/loader.js new file mode 100644 index 000000000..51296e4e9 --- /dev/null +++ b/iris/loader.js @@ -0,0 +1,53 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '@tensorflow/tfjs'; +import * as ui from './ui'; + +/** + * Test whether a given URL is retrievable. + */ +export async function urlExists(url) { + ui.status('Testing url ' + url); + try { + const response = await fetch(url, {method: 'HEAD'}); + return response.ok; + } catch (err) { + return false; + } +} + +/** + * Load pretrained model stored at a remote URL. + * + * @return An instance of `tf.Model` with model topology and weights loaded. + */ +export async function loadHostedPretrainedModel(url) { + ui.status('Loading pretrained model from ' + url); + try { + const model = await tf.loadModel(url); + ui.status('Done loading pretrained model.'); + // We can't load a model twice due to + // https://github.com/tensorflow/tfjs/issues/34 + // Therefore we remove the load buttons to avoid user confusion. + ui.disableLoadModelButtons(); + return model; + } catch (err) { + console.error(err); + ui.status('Loading pretrained model failed.'); + } +} diff --git a/iris/package.json b/iris/package.json index a0d3c4a43..855c3384f 100644 --- a/iris/package.json +++ b/iris/package.json @@ -13,15 +13,16 @@ "vega-embed": "^3.0.0" }, "scripts": { - "watch": "NODE_ENV=development parcel --no-hmr --open index.html ", - "build": "NODE_ENV=production parcel build index.html --no-minify --public-url ./" + "watch": "./serve.sh", + "build": "NODE_ENV=production parcel build index.html --no-minify --public-url /" }, "devDependencies": { "babel-plugin-transform-runtime": "~6.23.0", "babel-polyfill": "~6.26.0", "babel-preset-env": "~1.6.1", "clang-format": "~1.2.2", - "parcel-bundler": "~1.6.2" + "http-server": "~0.10.0", + "parcel-bundler": "~1.7.0" }, "babel": { "presets": [ diff --git a/iris/python/__init__.py b/iris/python/__init__.py new file mode 100644 index 000000000..636b70f0d --- /dev/null +++ b/iris/python/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2018 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/iris/python/iris.py b/iris/python/iris.py new file mode 100644 index 000000000..a8f882349 --- /dev/null +++ b/iris/python/iris.py @@ -0,0 +1,102 @@ +# Copyright 2018 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Train a simple model for the Iris dataset; export the model and weights."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+
+import keras
+import numpy as np
+import tensorflowjs as tfjs
+
+import iris_data
+
+
+def train(epochs,
+          artifacts_dir,
+          sequential=False):
+  """Train a Keras model for Iris data classification and save result as JSON.
+
+  Args:
+    epochs: Number of epochs to train the Keras model for.
+    artifacts_dir: Directory to save the model artifacts (model topology JSON,
+      weights and weight manifest) in.
+    sequential: Whether to use a Keras Sequential model, instead of the default
+      functional model.
+
+  Returns:
+    Final classification accuracy on the training set.
+  """
+  data_x, data_y = iris_data.load()
+
+  if sequential:
+    model = keras.models.Sequential()
+    model.add(keras.layers.Dense(
+        10, input_shape=[data_x.shape[1]], use_bias=True, activation='sigmoid',
+        name='Dense1'))
+    model.add(keras.layers.Dense(
+        3, use_bias=True, activation='softmax', name='Dense2'))
+  else:
+    iris_x = keras.layers.Input((4,))
+    dense1 = keras.layers.Dense(
+        10, use_bias=True, name='Dense1', activation='sigmoid')(iris_x)
+    dense2 = keras.layers.Dense(
+        3, use_bias=True, name='Dense2', activation='softmax')(dense1)
+    # pylint:disable=redefined-variable-type
+    model = keras.models.Model(inputs=[iris_x], outputs=[dense2])
+    # pylint:enable=redefined-variable-type
+  model.compile(loss='categorical_crossentropy', optimizer='adam')
+
+  model.fit(data_x, data_y, batch_size=8, epochs=epochs)
+
+  # Run prediction on the training set.
+  pred_ys = np.argmax(model.predict(data_x), axis=1)
+  true_ys = np.argmax(data_y, axis=1)
+  final_train_accuracy = np.mean((pred_ys == true_ys).astype(np.float32))
+  print('Accuracy on the training set: %g' % final_train_accuracy)
+
+  tfjs.converters.save_keras_model(model, artifacts_dir)
+
+  return final_train_accuracy
+
+
+def main():
+  train(FLAGS.epochs, FLAGS.artifacts_dir, sequential=FLAGS.sequential)
+
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser('Iris model training and serialization')
+  parser.add_argument(
+      '--sequential',
+      action='store_true',
+      help='Use a Keras Sequential model, instead of the default functional '
+      'model.')
+  parser.add_argument(
+      '--epochs',
+      type=int,
+      default=100,
+      help='Number of epochs to train the Keras model for.')
+  parser.add_argument(
+      '--artifacts_dir',
+      type=str,
+      default='/tmp/iris.keras',
+      help='Local path for saving the TensorFlow.js artifacts.')
+
+  FLAGS, _ = parser.parse_known_args()
+  main()
diff --git a/iris/python/iris_data.py b/iris/python/iris_data.py
new file mode 100644
index 000000000..c72e39f71
--- /dev/null
+++ b/iris/python/iris_data.py
@@ -0,0 +1,217 @@
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Iris dataset (see https://en.wikipedia.org/wiki/Iris_flower_data_set).""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +IRIS_CLASSES = ('setosa', 'versicolor', 'virginica') +IRIS_DATA = [ + '5.1,3.5,1.4,0.2,Iris-setosa', + '4.9,3.0,1.4,0.2,Iris-setosa', + '4.7,3.2,1.3,0.2,Iris-setosa', + '4.6,3.1,1.5,0.2,Iris-setosa', + '5.0,3.6,1.4,0.2,Iris-setosa', + '5.4,3.9,1.7,0.4,Iris-setosa', + '4.6,3.4,1.4,0.3,Iris-setosa', + '5.0,3.4,1.5,0.2,Iris-setosa', + '4.4,2.9,1.4,0.2,Iris-setosa', + '4.9,3.1,1.5,0.1,Iris-setosa', + '5.4,3.7,1.5,0.2,Iris-setosa', + '4.8,3.4,1.6,0.2,Iris-setosa', + '4.8,3.0,1.4,0.1,Iris-setosa', + '4.3,3.0,1.1,0.1,Iris-setosa', + '5.8,4.0,1.2,0.2,Iris-setosa', + '5.7,4.4,1.5,0.4,Iris-setosa', + '5.4,3.9,1.3,0.4,Iris-setosa', + '5.1,3.5,1.4,0.3,Iris-setosa', + '5.7,3.8,1.7,0.3,Iris-setosa', + '5.1,3.8,1.5,0.3,Iris-setosa', + '5.4,3.4,1.7,0.2,Iris-setosa', + '5.1,3.7,1.5,0.4,Iris-setosa', + '4.6,3.6,1.0,0.2,Iris-setosa', + '5.1,3.3,1.7,0.5,Iris-setosa', + '4.8,3.4,1.9,0.2,Iris-setosa', + '5.0,3.0,1.6,0.2,Iris-setosa', + '5.0,3.4,1.6,0.4,Iris-setosa', + '5.2,3.5,1.5,0.2,Iris-setosa', + '5.2,3.4,1.4,0.2,Iris-setosa', + '4.7,3.2,1.6,0.2,Iris-setosa', + '4.8,3.1,1.6,0.2,Iris-setosa', + '5.4,3.4,1.5,0.4,Iris-setosa', + '5.2,4.1,1.5,0.1,Iris-setosa', + '5.5,4.2,1.4,0.2,Iris-setosa', + '4.9,3.1,1.5,0.1,Iris-setosa', + '5.0,3.2,1.2,0.2,Iris-setosa', + '5.5,3.5,1.3,0.2,Iris-setosa', + '4.9,3.1,1.5,0.1,Iris-setosa', + '4.4,3.0,1.3,0.2,Iris-setosa', + '5.1,3.4,1.5,0.2,Iris-setosa', + '5.0,3.5,1.3,0.3,Iris-setosa', + '4.5,2.3,1.3,0.3,Iris-setosa', + '4.4,3.2,1.3,0.2,Iris-setosa', + '5.0,3.5,1.6,0.6,Iris-setosa', + '5.1,3.8,1.9,0.4,Iris-setosa', + '4.8,3.0,1.4,0.3,Iris-setosa', + '5.1,3.8,1.6,0.2,Iris-setosa', + '4.6,3.2,1.4,0.2,Iris-setosa', + '5.3,3.7,1.5,0.2,Iris-setosa', + '5.0,3.3,1.4,0.2,Iris-setosa', + '7.0,3.2,4.7,1.4,Iris-versicolor', + '6.4,3.2,4.5,1.5,Iris-versicolor', + '6.9,3.1,4.9,1.5,Iris-versicolor', + '5.5,2.3,4.0,1.3,Iris-versicolor', + '6.5,2.8,4.6,1.5,Iris-versicolor', + '5.7,2.8,4.5,1.3,Iris-versicolor', + '6.3,3.3,4.7,1.6,Iris-versicolor', + '4.9,2.4,3.3,1.0,Iris-versicolor', + '6.6,2.9,4.6,1.3,Iris-versicolor', + '5.2,2.7,3.9,1.4,Iris-versicolor', + '5.0,2.0,3.5,1.0,Iris-versicolor', + '5.9,3.0,4.2,1.5,Iris-versicolor', + '6.0,2.2,4.0,1.0,Iris-versicolor', + '6.1,2.9,4.7,1.4,Iris-versicolor', + '5.6,2.9,3.6,1.3,Iris-versicolor', + '6.7,3.1,4.4,1.4,Iris-versicolor', + '5.6,3.0,4.5,1.5,Iris-versicolor', + '5.8,2.7,4.1,1.0,Iris-versicolor', + '6.2,2.2,4.5,1.5,Iris-versicolor', + '5.6,2.5,3.9,1.1,Iris-versicolor', + '5.9,3.2,4.8,1.8,Iris-versicolor', + '6.1,2.8,4.0,1.3,Iris-versicolor', + '6.3,2.5,4.9,1.5,Iris-versicolor', + '6.1,2.8,4.7,1.2,Iris-versicolor', + '6.4,2.9,4.3,1.3,Iris-versicolor', + '6.6,3.0,4.4,1.4,Iris-versicolor', + '6.8,2.8,4.8,1.4,Iris-versicolor', + '6.7,3.0,5.0,1.7,Iris-versicolor', + '6.0,2.9,4.5,1.5,Iris-versicolor', + '5.7,2.6,3.5,1.0,Iris-versicolor', + 
'5.5,2.4,3.8,1.1,Iris-versicolor',
+    '5.5,2.4,3.7,1.0,Iris-versicolor',
+    '5.8,2.7,3.9,1.2,Iris-versicolor',
+    '6.0,2.7,5.1,1.6,Iris-versicolor',
+    '5.4,3.0,4.5,1.5,Iris-versicolor',
+    '6.0,3.4,4.5,1.6,Iris-versicolor',
+    '6.7,3.1,4.7,1.5,Iris-versicolor',
+    '6.3,2.3,4.4,1.3,Iris-versicolor',
+    '5.6,3.0,4.1,1.3,Iris-versicolor',
+    '5.5,2.5,4.0,1.3,Iris-versicolor',
+    '5.5,2.6,4.4,1.2,Iris-versicolor',
+    '6.1,3.0,4.6,1.4,Iris-versicolor',
+    '5.8,2.6,4.0,1.2,Iris-versicolor',
+    '5.0,2.3,3.3,1.0,Iris-versicolor',
+    '5.6,2.7,4.2,1.3,Iris-versicolor',
+    '5.7,3.0,4.2,1.2,Iris-versicolor',
+    '5.7,2.9,4.2,1.3,Iris-versicolor',
+    '6.2,2.9,4.3,1.3,Iris-versicolor',
+    '5.1,2.5,3.0,1.1,Iris-versicolor',
+    '5.7,2.8,4.1,1.3,Iris-versicolor',
+    '6.3,3.3,6.0,2.5,Iris-virginica',
+    '5.8,2.7,5.1,1.9,Iris-virginica',
+    '7.1,3.0,5.9,2.1,Iris-virginica',
+    '6.3,2.9,5.6,1.8,Iris-virginica',
+    '6.5,3.0,5.8,2.2,Iris-virginica',
+    '7.6,3.0,6.6,2.1,Iris-virginica',
+    '4.9,2.5,4.5,1.7,Iris-virginica',
+    '7.3,2.9,6.3,1.8,Iris-virginica',
+    '6.7,2.5,5.8,1.8,Iris-virginica',
+    '7.2,3.6,6.1,2.5,Iris-virginica',
+    '6.5,3.2,5.1,2.0,Iris-virginica',
+    '6.4,2.7,5.3,1.9,Iris-virginica',
+    '6.8,3.0,5.5,2.1,Iris-virginica',
+    '5.7,2.5,5.0,2.0,Iris-virginica',
+    '5.8,2.8,5.1,2.4,Iris-virginica',
+    '6.4,3.2,5.3,2.3,Iris-virginica',
+    '6.5,3.0,5.5,1.8,Iris-virginica',
+    '7.7,3.8,6.7,2.2,Iris-virginica',
+    '7.7,2.6,6.9,2.3,Iris-virginica',
+    '6.0,2.2,5.0,1.5,Iris-virginica',
+    '6.9,3.2,5.7,2.3,Iris-virginica',
+    '5.6,2.8,4.9,2.0,Iris-virginica',
+    '7.7,2.8,6.7,2.0,Iris-virginica',
+    '6.3,2.7,4.9,1.8,Iris-virginica',
+    '6.7,3.3,5.7,2.1,Iris-virginica',
+    '7.2,3.2,6.0,1.8,Iris-virginica',
+    '6.2,2.8,4.8,1.8,Iris-virginica',
+    '6.1,3.0,4.9,1.8,Iris-virginica',
+    '6.4,2.8,5.6,2.1,Iris-virginica',
+    '7.2,3.0,5.8,1.6,Iris-virginica',
+    '7.4,2.8,6.1,1.9,Iris-virginica',
+    '7.9,3.8,6.4,2.0,Iris-virginica',
+    '6.4,2.8,5.6,2.2,Iris-virginica',
+    '6.3,2.8,5.1,1.5,Iris-virginica',
+    '6.1,2.6,5.6,1.4,Iris-virginica',
+    '7.7,3.0,6.1,2.3,Iris-virginica',
+    '6.3,3.4,5.6,2.4,Iris-virginica',
+    '6.4,3.1,5.5,1.8,Iris-virginica',
+    '6.0,3.0,4.8,1.8,Iris-virginica',
+    '6.9,3.1,5.4,2.1,Iris-virginica',
+    '6.7,3.1,5.6,2.4,Iris-virginica',
+    '6.9,3.1,5.1,2.3,Iris-virginica',
+    '5.8,2.7,5.1,1.9,Iris-virginica',
+    '6.8,3.2,5.9,2.3,Iris-virginica',
+    '6.7,3.3,5.7,2.5,Iris-virginica',
+    '6.7,3.0,5.2,2.3,Iris-virginica',
+    '6.3,2.5,5.0,1.9,Iris-virginica',
+    '6.5,3.0,5.2,2.0,Iris-virginica',
+    '6.2,3.4,5.4,2.3,Iris-virginica',
+    '5.9,3.0,5.1,1.8,Iris-virginica']
+
+
+def load():
+  """Load Iris data.
+
+  Returns:
+    Iris data as a numpy ndarray of shape [n, 4] and dtype `float32`, n being
+      the number of available samples.
+    Iris classification targets as a numpy ndarray of shape [n, 3] and dtype
+      `float32`. The order of the data is randomly shuffled.
+  """
+  iris_x = []
+  iris_y = []
+  for line in IRIS_DATA:
+    items = line.split(',')
+    xs = [float(x) for x in items[:4]]
+    iris_x.append(xs)
+    assert items[-1].startswith('Iris-')
+    iris_y.append(IRIS_CLASSES.index(items[-1].replace('Iris-', '')))
+
+  # Randomly shuffle the data.
+  iris_xy = list(zip(iris_x, iris_y))
+  np.random.shuffle(iris_xy)
+  iris_x, iris_y = zip(*iris_xy)
+  return (np.array(iris_x, dtype=np.float32),
+          _to_one_hot(iris_y, len(IRIS_CLASSES)))
+
+
+def _to_one_hot(indices, num_classes):
+  """Convert indices to one-hot encoding.
+
+  Args:
+    indices: A list of `int` indices with length `n`, each element of which is
+      assumed to be a zero-based class index >= 0 and < `num_classes`.
+ num_classes: Total number of possible classes as an `int`. + + Returns: + A numpy ndarray of shape [n, num_classes] and dtype `float32`. + """ + one_hot = np.zeros([len(indices), num_classes], dtype=np.float32) + one_hot[np.arange(len(indices)), indices] = 1 + return one_hot diff --git a/iris/python/iris_data_test.py b/iris/python/iris_data_test.py new file mode 100644 index 000000000..613c14332 --- /dev/null +++ b/iris/python/iris_data_test.py @@ -0,0 +1,42 @@ +# Copyright 2018 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Test for the Iris dataset module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +import numpy as np + +import iris_data + +class IrisDataTest(unittest.TestCase): + + def testLoadData(self): + iris_x, iris_y = iris_data.load() + self.assertEqual(2, len(iris_x.shape)) + self.assertGreater(iris_x.shape[0], 0) + self.assertEqual(4, iris_x.shape[1]) + self.assertEqual(iris_x.shape[0], iris_y.shape[0]) + self.assertEqual(3, iris_y.shape[1]) + self.assertTrue( + np.allclose(np.ones([iris_y.shape[0], 1]), np.sum(iris_y, axis=1))) + + +if __name__ == '__main__': + unittest.main() diff --git a/iris/python/iris_test.py b/iris/python/iris_test.py new file mode 100644 index 000000000..3fed96d87 --- /dev/null +++ b/iris/python/iris_test.py @@ -0,0 +1,58 @@ +# Copyright 2018 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Test for the Iris model.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import os +import shutil +import tempfile +import unittest + +import iris + + +class IrisTest(unittest.TestCase): + + def setUp(self): + self._tmp_dir = tempfile.mkdtemp() + super(IrisTest, self).setUp() + + def tearDown(self): + if os.path.isdir(self._tmp_dir): + shutil.rmtree(self._tmp_dir) + super(IrisTest, self).tearDown() + + def testTrainAndSaveNonSequential(self): + final_train_accuracy = iris.train(100, self._tmp_dir) + self.assertGreater(final_train_accuracy, 0.9) + + # Check that the model json file is created. 
+ json.load(open(os.path.join(self._tmp_dir, 'model.json'), 'rt')) + + def testTrainAndSaveSequential(self): + final_train_accuracy = iris.train(100, self._tmp_dir, sequential=True) + self.assertGreater(final_train_accuracy, 0.9) + + # Check that the model json file is created. + json.load(open(os.path.join(self._tmp_dir, 'model.json'), 'rt')) + + +if __name__ == '__main__': + unittest.main() diff --git a/iris/serve.sh b/iris/serve.sh new file mode 100755 index 000000000..e5ede63ed --- /dev/null +++ b/iris/serve.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash + +# Copyright 2018 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +# This script starts two HTTP servers on different ports: +# * Port 1234 (using parcel) serves HTML and JavaScript. +# * Port 1235 (using http-server) serves pretrained model resources. +# +# The reason for this arrangement is that Parcel currently has a limitation that +# prevents it from serving the pretrained models; see +# https://github.com/parcel-bundler/parcel/issues/1098. Once that issue is +# resolved, a single Parcel server will be sufficient. + +NODE_ENV=development +RESOURCE_PORT=1235 + +# Ensure that http-server is available +yarn + +echo Starting the pretrained model server... +node_modules/http-server/bin/http-server dist --cors -p "${RESOURCE_PORT}" > /dev/null & HTTP_SERVER_PID=$! + +echo Starting the example html/js server... +# This uses port 1234 by default. +node_modules/parcel-bundler/bin/cli.js serve -d dist --open --no-hmr --public-url / index.html + +# When the Parcel server exits, kill the http-server too. 
+kill $HTTP_SERVER_PID + diff --git a/iris/ui.js b/iris/ui.js index ed080104e..6c62521e6 100644 --- a/iris/ui.js +++ b/iris/ui.js @@ -227,5 +227,11 @@ export function loadTrainParametersFromUI() { } export function status(statusText) { + console.log(statusText); document.getElementById('demo-status').textContent = statusText; } + +export function disableLoadModelButtons() { + document.getElementById('load-pretrained-remote').style.display = 'none'; + document.getElementById('load-pretrained-local').style.display = 'none'; +} diff --git a/iris/yarn.lock b/iris/yarn.lock index d5fdf748a..024624301 100644 --- a/iris/yarn.lock +++ b/iris/yarn.lock @@ -705,6 +705,10 @@ binary-extensions@^1.0.0: version "1.11.0" resolved "https://registry.yarnpkg.com/binary-extensions/-/binary-extensions-1.11.0.tgz#46aa1751fb6a2f93ee5e689bb1087d4b14c6c205" +bindings@~1.2.1: + version "1.2.1" + resolved "https://registry.yarnpkg.com/bindings/-/bindings-1.2.1.tgz#14ad6113812d2d37d72e67b4cacb4bb726505f11" + block-stream@*: version "0.0.9" resolved "https://registry.yarnpkg.com/block-stream/-/block-stream-0.0.9.tgz#13ebfe778a03205cfe03751481ebb4b3300c126a" @@ -766,12 +770,6 @@ brorand@^1.0.1: version "1.1.0" resolved "https://registry.yarnpkg.com/brorand/-/brorand-1.1.0.tgz#12c25efe40a45e3c323eb8675a0a0ce57b22371f" -browser-resolve@^1.11.2: - version "1.11.2" - resolved "https://registry.yarnpkg.com/browser-resolve/-/browser-resolve-1.11.2.tgz#8ff09b0a2c421718a1051c260b32e48f442938ce" - dependencies: - resolve "1.1.7" - browserify-aes@^1.0.0, browserify-aes@^1.0.4: version "1.1.1" resolved "https://registry.yarnpkg.com/browserify-aes/-/browserify-aes-1.1.1.tgz#38b7ab55edb806ff2dcda1a7f1620773a477c49f" @@ -907,7 +905,7 @@ chalk@^1.1.3: strip-ansi "^3.0.0" supports-color "^2.0.0" -chalk@^2.1.0, chalk@^2.3.1: +chalk@^2.1.0, chalk@^2.3.2: version "2.3.2" resolved "https://registry.yarnpkg.com/chalk/-/chalk-2.3.2.tgz#250dc96b07491bfd601e648d66ddf5f60c7a5c65" dependencies: @@ -1038,6 +1036,10 @@ colormin@^1.0.5: css-color-names "0.0.4" has "^1.0.1" +colors@1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/colors/-/colors-1.0.3.tgz#0433f44d809680fdeb60ed260f1b0c262e82a40b" + colors@~1.1.2: version "1.1.2" resolved "https://registry.yarnpkg.com/colors/-/colors-1.1.2.tgz#168a4701756b6a7f51a12ce0c97bfa28c084ed63" @@ -1117,6 +1119,10 @@ core-util-is@1.0.2, core-util-is@~1.0.0: version "1.0.2" resolved "https://registry.yarnpkg.com/core-util-is/-/core-util-is-1.0.2.tgz#b5fd54220aa2bc5ab57aab7140c940754503c1a7" +corser@~2.0.0: + version "2.0.1" + resolved "https://registry.yarnpkg.com/corser/-/corser-2.0.1.tgz#8eda252ecaab5840dcd975ceb90d9370c819ff87" + create-ecdh@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/create-ecdh/-/create-ecdh-4.0.0.tgz#888c723596cdf7612f6498233eebd7a35301737d" @@ -1405,6 +1411,13 @@ date-now@^0.1.4: version "0.1.4" resolved "https://registry.yarnpkg.com/date-now/-/date-now-0.1.4.tgz#eaf439fd4d4848ad74e5cc7dbef200672b9e345b" +deasync@^0.1.12: + version "0.1.12" + resolved "https://registry.yarnpkg.com/deasync/-/deasync-0.1.12.tgz#0159492a4133ab301d6c778cf01e74e63b10e549" + dependencies: + bindings "~1.2.1" + nan "^2.0.7" + debug@2.6.9, debug@^2.2.0, debug@^2.3.3, debug@^2.6.8: version "2.6.9" resolved "https://registry.yarnpkg.com/debug/-/debug-2.6.9.tgz#5d128515df134ff327e90a4c93f4e077a536341f" @@ -1557,6 +1570,15 @@ ecc-jsbn@~0.1.1: dependencies: jsbn "~0.1.0" +ecstatic@^2.0.0: + version "2.2.1" + resolved 
"https://registry.yarnpkg.com/ecstatic/-/ecstatic-2.2.1.tgz#b5087fad439dd9dd49d31e18131454817fe87769" + dependencies: + he "^1.1.1" + mime "^1.2.11" + minimist "^1.1.0" + url-join "^2.0.2" + editorconfig@^0.13.2: version "0.13.3" resolved "https://registry.yarnpkg.com/editorconfig/-/editorconfig-0.13.3.tgz#e5219e587951d60958fd94ea9a9a008cdeff1b34" @@ -1662,6 +1684,10 @@ etag@~1.8.1: version "1.8.1" resolved "https://registry.yarnpkg.com/etag/-/etag-1.8.1.tgz#41ae2eeb65efa62268aebfea83ac7d79299b0887" +eventemitter3@1.x.x: + version "1.2.0" + resolved "https://registry.yarnpkg.com/eventemitter3/-/eventemitter3-1.2.0.tgz#1c86991d816ad1e504750e73874224ecf3bec508" + events@^1.0.0: version "1.1.1" resolved "https://registry.yarnpkg.com/events/-/events-1.1.1.tgz#9ebdb7635ad099c70dcc4c2a1f5004288e8bd924" @@ -1992,6 +2018,10 @@ hawk@3.1.3, hawk@~3.1.3: hoek "2.x.x" sntp "1.x.x" +he@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/he/-/he-1.1.1.tgz#93410fd21b009735151f8868c2f271f3427e23fd" + hmac-drbg@^1.0.0: version "1.0.1" resolved "https://registry.yarnpkg.com/hmac-drbg/-/hmac-drbg-1.0.1.tgz#d2745701025a6c775a6c545793ed502fc0c649a1" @@ -2015,7 +2045,7 @@ html-comment-regex@^1.1.0: version "1.1.1" resolved "https://registry.yarnpkg.com/html-comment-regex/-/html-comment-regex-1.1.1.tgz#668b93776eaae55ebde8f3ad464b307a4963625e" -htmlnano@^0.1.6: +htmlnano@^0.1.7: version "0.1.7" resolved "https://registry.yarnpkg.com/htmlnano/-/htmlnano-0.1.7.tgz#1751937a05f122a3248dba1c63edf01f6d36cb84" dependencies: @@ -2046,6 +2076,26 @@ http-errors@~1.6.2: setprototypeof "1.0.3" statuses ">= 1.3.1 < 2" +http-proxy@^1.8.1: + version "1.16.2" + resolved "https://registry.yarnpkg.com/http-proxy/-/http-proxy-1.16.2.tgz#06dff292952bf64dbe8471fa9df73066d4f37742" + dependencies: + eventemitter3 "1.x.x" + requires-port "1.x.x" + +http-server@~0.10.0: + version "0.10.0" + resolved "https://registry.yarnpkg.com/http-server/-/http-server-0.10.0.tgz#b2a446b16a9db87ed3c622ba9beb1b085b1234a7" + dependencies: + colors "1.0.3" + corser "~2.0.0" + ecstatic "^2.0.0" + http-proxy "^1.8.1" + opener "~1.4.0" + optimist "0.6.x" + portfinder "^1.0.13" + union "~0.4.3" + http-signature@~1.1.0: version "1.1.1" resolved "https://registry.yarnpkg.com/http-signature/-/http-signature-1.1.1.tgz#df72e267066cd0ac67fb76adf8e134a8fbcf91bf" @@ -2537,6 +2587,10 @@ mime@1.4.1: version "1.4.1" resolved "https://registry.yarnpkg.com/mime/-/mime-1.4.1.tgz#121f9ebc49e3766f311a76e1fa1c8003c4b03aa6" +mime@^1.2.11: + version "1.6.0" + resolved "https://registry.yarnpkg.com/mime/-/mime-1.6.0.tgz#32cd9e5c64553bd58d19a568af452acff04981b1" + mimic-fn@^1.0.0: version "1.2.0" resolved "https://registry.yarnpkg.com/mimic-fn/-/mimic-fn-1.2.0.tgz#820c86a39334640e99516928bd03fca88057d022" @@ -2559,10 +2613,14 @@ minimist@0.0.8: version "0.0.8" resolved "https://registry.yarnpkg.com/minimist/-/minimist-0.0.8.tgz#857fcabfc3397d2625b8228262e86aa7a011b05d" -minimist@^1.1.3, minimist@^1.2.0: +minimist@^1.1.0, minimist@^1.1.3, minimist@^1.2.0: version "1.2.0" resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.0.tgz#a35008b20f41383eec1fb914f4cd5df79a264284" +minimist@~0.0.1: + version "0.0.10" + resolved "https://registry.yarnpkg.com/minimist/-/minimist-0.0.10.tgz#de3f98543dbf96082be48ad1a0c7cda836301dcf" + mixin-deep@^1.2.0: version "1.3.1" resolved "https://registry.yarnpkg.com/mixin-deep/-/mixin-deep-1.3.1.tgz#a49e7268dce1a0d9698e45326c5626df3543d0fe" @@ -2570,7 +2628,7 @@ mixin-deep@^1.2.0: for-in "^1.0.2" is-extendable "^1.0.1" 
-"mkdirp@>=0.5 0", mkdirp@^0.5.1, mkdirp@~0.5.0, mkdirp@~0.5.1: +mkdirp@0.5.x, "mkdirp@>=0.5 0", mkdirp@^0.5.1, mkdirp@~0.5.0, mkdirp@~0.5.1: version "0.5.1" resolved "https://registry.yarnpkg.com/mkdirp/-/mkdirp-0.5.1.tgz#30057438eac6cf7f8c4767f38648d6697d75c903" dependencies: @@ -2580,7 +2638,7 @@ ms@2.0.0: version "2.0.0" resolved "https://registry.yarnpkg.com/ms/-/ms-2.0.0.tgz#5608aeadfc00be6c2901df5f9861788de0d597c8" -nan@^2.3.0: +nan@^2.0.7, nan@^2.3.0: version "2.10.0" resolved "https://registry.yarnpkg.com/nan/-/nan-2.10.0.tgz#96d0cd610ebd58d4b4de9cc0c6828cda99c7548f" @@ -2778,12 +2836,23 @@ once@^1.3.0, once@^1.3.3: dependencies: wrappy "1" +opener@~1.4.0: + version "1.4.3" + resolved "https://registry.yarnpkg.com/opener/-/opener-1.4.3.tgz#5c6da2c5d7e5831e8ffa3964950f8d6674ac90b8" + opn@^5.1.0: version "5.3.0" resolved "https://registry.yarnpkg.com/opn/-/opn-5.3.0.tgz#64871565c863875f052cfdf53d3e3cb5adb53b1c" dependencies: is-wsl "^1.1.0" +optimist@0.6.x: + version "0.6.1" + resolved "https://registry.yarnpkg.com/optimist/-/optimist-0.6.1.tgz#da3ea74686fa21a19a111c326e90eb15a0196686" + dependencies: + minimist "~0.0.1" + wordwrap "~0.0.2" + optionator@^0.8.1: version "0.8.2" resolved "https://registry.yarnpkg.com/optionator/-/optionator-0.8.2.tgz#364c5e409d3f4d6301d6c0b4c05bba50180aeb64" @@ -2850,9 +2919,9 @@ pako@~1.0.5: version "1.0.6" resolved "https://registry.yarnpkg.com/pako/-/pako-1.0.6.tgz#0101211baa70c4bca4a0f63f2206e97b7dfaf258" -parcel-bundler@~1.6.2: - version "1.6.2" - resolved "https://registry.yarnpkg.com/parcel-bundler/-/parcel-bundler-1.6.2.tgz#e415c4993b6f4c48cd410427594ce80bbb7d90f1" +parcel-bundler@~1.7.0: + version "1.7.0" + resolved "https://registry.yarnpkg.com/parcel-bundler/-/parcel-bundler-1.7.0.tgz#8c2512615fd602d2f39bd97bfd128f8fe524b321" dependencies: babel-code-frame "^6.26.0" babel-core "^6.25.0" @@ -2865,7 +2934,6 @@ parcel-bundler@~1.6.2: babel-types "^6.26.0" babylon "^6.17.4" babylon-walk "^1.0.2" - browser-resolve "^1.11.2" browserslist "^2.11.2" chalk "^2.1.0" chokidar "^2.0.1" @@ -2873,12 +2941,13 @@ parcel-bundler@~1.6.2: commander "^2.11.0" cross-spawn "^6.0.4" cssnano "^3.10.0" + deasync "^0.1.12" dotenv "^5.0.0" filesize "^3.6.0" get-port "^3.2.0" glob "^7.1.2" grapheme-breaker "^0.3.2" - htmlnano "^0.1.6" + htmlnano "^0.1.7" is-url "^1.2.2" js-yaml "^3.10.0" json5 "^0.5.1" @@ -2888,13 +2957,12 @@ parcel-bundler@~1.6.2: node-libs-browser "^2.0.0" opn "^5.1.0" physical-cpu-count "^2.0.0" - postcss "^6.0.10" + postcss "^6.0.19" postcss-value-parser "^3.3.0" posthtml "^0.11.2" posthtml-parser "^0.4.0" posthtml-render "^1.1.0" resolve "^1.4.0" - sanitize-filename "^1.6.1" semver "^5.4.1" serialize-to-js "^1.1.1" serve-static "^1.12.4" @@ -2967,6 +3035,14 @@ physical-cpu-count@^2.0.0: version "2.0.0" resolved "https://registry.yarnpkg.com/physical-cpu-count/-/physical-cpu-count-2.0.0.tgz#18de2f97e4bf7a9551ad7511942b5496f7aba660" +portfinder@^1.0.13: + version "1.0.13" + resolved "https://registry.yarnpkg.com/portfinder/-/portfinder-1.0.13.tgz#bb32ecd87c27104ae6ee44b5a3ccbf0ebb1aede9" + dependencies: + async "^1.5.2" + debug "^2.2.0" + mkdirp "0.5.x" + posix-character-classes@^0.1.0: version "0.1.1" resolved "https://registry.yarnpkg.com/posix-character-classes/-/posix-character-classes-0.1.1.tgz#01eac0fe3b5af71a2a6c02feabb8c1fef7e00eab" @@ -3182,13 +3258,13 @@ postcss@^5.0.10, postcss@^5.0.11, postcss@^5.0.12, postcss@^5.0.13, postcss@^5.0 source-map "^0.5.6" supports-color "^3.2.3" -postcss@^6.0.10: - version "6.0.19" - resolved 
"https://registry.yarnpkg.com/postcss/-/postcss-6.0.19.tgz#76a78386f670b9d9494a655bf23ac012effd1555" +postcss@^6.0.19: + version "6.0.21" + resolved "https://registry.yarnpkg.com/postcss/-/postcss-6.0.21.tgz#8265662694eddf9e9a5960db6da33c39e4cd069d" dependencies: - chalk "^2.3.1" + chalk "^2.3.2" source-map "^0.6.1" - supports-color "^5.2.0" + supports-color "^5.3.0" posthtml-parser@^0.3.3: version "0.3.3" @@ -3275,6 +3351,10 @@ q@^1.1.2: version "1.5.1" resolved "https://registry.yarnpkg.com/q/-/q-1.5.1.tgz#7e32f75b41381291d04611f1bf14109ac00651d7" +qs@~2.3.3: + version "2.3.3" + resolved "https://registry.yarnpkg.com/qs/-/qs-2.3.3.tgz#e9e85adbe75da0bbe4c8e0476a086290f863b404" + qs@~6.4.0: version "6.4.0" resolved "https://registry.yarnpkg.com/qs/-/qs-6.4.0.tgz#13e26d28ad6b0ffaa91312cd3bf708ed351e7233" @@ -3461,14 +3541,14 @@ require-main-filename@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/require-main-filename/-/require-main-filename-1.0.1.tgz#97f717b69d48784f5f526a6c5aa8ffdda055a4d1" +requires-port@1.x.x: + version "1.0.0" + resolved "https://registry.yarnpkg.com/requires-port/-/requires-port-1.0.0.tgz#925d2601d39ac485e091cf0da5c6e694dc3dcaff" + resolve-url@^0.2.1: version "0.2.1" resolved "https://registry.yarnpkg.com/resolve-url/-/resolve-url-0.2.1.tgz#2c637fe77c893afd2a663fe21aa9080068e2052a" -resolve@1.1.7: - version "1.1.7" - resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.1.7.tgz#203114d82ad2c5ed9e8e0411b3932875e889e97b" - resolve@^1.1.5, resolve@^1.1.6, resolve@^1.4.0: version "1.5.0" resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.5.0.tgz#1f09acce796c9a762579f31b2c1cc4c3cddf9f36" @@ -3512,12 +3592,6 @@ safer-eval@^1.2.3: dependencies: clones "^1.1.0" -sanitize-filename@^1.6.1: - version "1.6.1" - resolved "https://registry.yarnpkg.com/sanitize-filename/-/sanitize-filename-1.6.1.tgz#612da1c96473fa02dccda92dcd5b4ab164a6772a" - dependencies: - truncate-utf8-bytes "^1.0.0" - sax@~1.2.1, sax@~1.2.4: version "1.2.4" resolved "https://registry.yarnpkg.com/sax/-/sax-1.2.4.tgz#2816234e2378bddc4e5354fab5caa895df7100d9" @@ -3844,7 +3918,7 @@ supports-color@^3.2.3: dependencies: has-flag "^1.0.0" -supports-color@^5.2.0, supports-color@^5.3.0: +supports-color@^5.3.0: version "5.3.0" resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-5.3.0.tgz#5b24ac15db80fa927cf5227a4a33fd3c4c7676c0" dependencies: @@ -3973,12 +4047,6 @@ trim-right@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/trim-right/-/trim-right-1.0.1.tgz#cb2e1203067e0c8de1f614094b9fe45704ea6003" -truncate-utf8-bytes@^1.0.0: - version "1.0.2" - resolved "https://registry.yarnpkg.com/truncate-utf8-bytes/-/truncate-utf8-bytes-1.0.2.tgz#405923909592d56f78a5818434b0b78489ca5f2b" - dependencies: - utf8-byte-length "^1.0.1" - tslib@^1.9.0: version "1.9.0" resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.9.0.tgz#e37a86fda8cbbaf23a057f473c9f4dc64e5fc2e8" @@ -4045,6 +4113,12 @@ union-value@^1.0.0: is-extendable "^0.1.1" set-value "^0.4.3" +union@~0.4.3: + version "0.4.6" + resolved "https://registry.yarnpkg.com/union/-/union-0.4.6.tgz#198fbdaeba254e788b0efcb630bc11f24a2959e0" + dependencies: + qs "~2.3.3" + uniq@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/uniq/-/uniq-1.0.1.tgz#b31c5ae8254844a3a8281541ce2b04b865a734ff" @@ -4078,6 +4152,10 @@ urix@^0.1.0: version "0.1.0" resolved "https://registry.yarnpkg.com/urix/-/urix-0.1.0.tgz#da937f7a62e21fec1fd18d49b35c2935067a6c72" +url-join@^2.0.2: + version "2.0.5" + resolved 
"https://registry.yarnpkg.com/url-join/-/url-join-2.0.5.tgz#5af22f18c052a000a48d7b82c5e9c2e2feeda728" + url@^0.11.0: version "0.11.0" resolved "https://registry.yarnpkg.com/url/-/url-0.11.0.tgz#3838e97cfc60521eb73c525a8e55bfdd9e2e28f1" @@ -4091,10 +4169,6 @@ use@^3.1.0: dependencies: kind-of "^6.0.2" -utf8-byte-length@^1.0.1: - version "1.0.4" - resolved "https://registry.yarnpkg.com/utf8-byte-length/-/utf8-byte-length-1.0.4.tgz#f45f150c4c66eee968186505ab93fcbb8ad6bf61" - util-deprecate@~1.0.1: version "1.0.2" resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" @@ -4406,6 +4480,10 @@ wide-align@^1.1.0: dependencies: string-width "^1.0.2" +wordwrap@~0.0.2: + version "0.0.3" + resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-0.0.3.tgz#a3d5da6cd5c0bc0008d37234bbaf1bed63059107" + wordwrap@~1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-1.0.0.tgz#27584810891456a4171c8d0226441ade90cbcaeb" diff --git a/mnist-transfer-cnn/build-resources.sh b/mnist-transfer-cnn/build-resources.sh new file mode 100755 index 000000000..b7edad544 --- /dev/null +++ b/mnist-transfer-cnn/build-resources.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash + +# Copyright 2018 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +# Builds resources for the MNIST Transfer CNN demo. +# Note this is not necessary to run the demo, because we already provide hosted +# pre-built resources. +# Usage example: do this from the 'mnist-transfer-cnn' directory: +# ./build-resources.sh + +set -e + +DEMO_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +TRAIN_EPOCHS=5 +while true; do + if [[ "$1" == "--epochs" ]]; then + TRAIN_EPOCHS=$2 + shift 2 + elif [[ -z "$1" ]]; then + break + else + echo "ERROR: Unrecognized argument: $1" + exit 1 + fi +done + +RESOURCES_ROOT="${DEMO_DIR}/dist/resources" +rm -rf "${RESOURCES_ROOT}" +mkdir -p "${RESOURCES_ROOT}" + +# Run Python script to generate the pretrained model and weights files. +# Make sure you have installed the tensorfowjs pip package first. + +python "${DEMO_DIR}/python/mnist_transfer_cnn.py" \ + --epochs "${TRAIN_EPOCHS}" \ + --artifacts_dir "${RESOURCES_ROOT}" \ + --gte5_data_path_prefix "${RESOURCES_ROOT}/gte5" \ + --gte5_cutoff 1024 + +cd ${DEMO_DIR} +yarn +yarn build + +echo +echo "-----------------------------------------------------------" +echo "Resources written to ${RESOURCES_ROOT}." +echo "You can now run the demo with 'yarn watch'." +echo "-----------------------------------------------------------" +echo diff --git a/mnist-transfer-cnn/index.html b/mnist-transfer-cnn/index.html index 1833ac96d..821ed5267 100644 --- a/mnist-transfer-cnn/index.html +++ b/mnist-transfer-cnn/index.html @@ -70,6 +70,11 @@

   <h1>TensorFlow.js Layers: MNIST CNN Transfer Learning Demo</h1>
 
+  <div>
+    <button id="load-pretrained-remote" style="display:none">Load hosted pretrained model</button>
+    <button id="load-pretrained-local" style="display:none">Load local pretrained model</button>
+  </div>
+
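Note: the buttons added above get the same probe-then-reveal wiring as in the iris demo, and the local data they trigger loading of comes from the .json files written by write_mnist_examples_to_json_file later in this diff. A hedged sketch of that file format and of reading it in the browser; the demo's own loader.loadHostedData does the equivalent plus one-hot encoding, and fetchGte5Examples here is an illustrative name, not part of this PR:

// gte5.train.json / gte5.test.json are arrays of records of the form
//   {x: 28x28 rows of grayscale values (0-255), y: class index 0-4,
//    where y is the original MNIST digit minus 5}.
async function fetchGte5Examples(url) {
  const response = await fetch(url);
  const examples = await response.json();  // [{x: number[][], y: number}, ...]
  return examples.map(ex => ({
    pixels: ex.x,     // 28 rows of 28 pixel values
    digit: ex.y + 5,  // Recover the original MNIST digit (5-9).
  }));
}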
diff --git a/mnist-transfer-cnn/index.js b/mnist-transfer-cnn/index.js index d144acbdc..4816cc026 100644 --- a/mnist-transfer-cnn/index.js +++ b/mnist-transfer-cnn/index.js @@ -20,19 +20,28 @@ import * as loader from './loader'; import * as ui from './ui'; import * as util from './util'; -const HOSTED_MODEL_JSON_URL = - 'https://storage.googleapis.com/tfjs-models/tfjs/mnist_transfer_cnn_v1/model.json'; -const HOSTED_TRAIN_DATA_JSON_URL = - 'https://storage.googleapis.com/tfjs-models/tfjs/mnist_transfer_cnn_v1/gte5.train.json'; -const HOSTED_TEST_DATA_JSON_URL = - 'https://storage.googleapis.com/tfjs-models/tfjs/mnist_transfer_cnn_v1/gte5.test.json'; +const HOSTED_URLS = { + model: + 'https://storage.googleapis.com/tfjs-models/tfjs/mnist_transfer_cnn_v1/model.json', + train: + 'https://storage.googleapis.com/tfjs-models/tfjs/mnist_transfer_cnn_v1/gte5.train.json', + test: + 'https://storage.googleapis.com/tfjs-models/tfjs/mnist_transfer_cnn_v1/gte5.test.json' +}; + +const LOCAL_URLS = { + model: 'http://localhost:1235/resources/model.json', + train: 'http://localhost:1235/resources/gte5.train.json', + test: 'http://localhost:1235/resources/gte5.test.json' +}; class MnistTransferCNNPredictor { /** * Initializes the MNIST Transfer CNN demo. */ - async init() { - this.model = await loader.loadHostedPretrainedModel(HOSTED_MODEL_JSON_URL); + async init(urls) { + this.urls = urls; + this.model = await loader.loadHostedPretrainedModel(urls.model); this.imageSize = this.model.layers[0].batchInputShape[1]; this.numClasses = 5; @@ -44,10 +53,10 @@ class MnistTransferCNNPredictor { async loadRetrainData() { console.log('Loading data for transfer learning...'); ui.status('Loading data for transfer learning...'); - this.gte5TrainData = await loader.loadHostedData( - HOSTED_TRAIN_DATA_JSON_URL, this.numClasses); + this.gte5TrainData = + await loader.loadHostedData(this.urls.train, this.numClasses); this.gte5TestData = - await loader.loadHostedData(HOSTED_TEST_DATA_JSON_URL, this.numClasses); + await loader.loadHostedData(this.urls.test, this.numClasses); ui.status('Done loading data for transfer learning.'); } @@ -135,10 +144,31 @@ class MnistTransferCNNPredictor { * and retrain functions with the UI. 
*/ async function setupMnistTransferCNN() { - const predictor = await new MnistTransferCNNPredictor().init(); - ui.prepUI( - x => predictor.predict(x), x => predictor.retrainModel(), - predictor.testExamples, predictor.imageSize); + if (await loader.urlExists(HOSTED_URLS.model)) { + ui.status('Model available: ' + HOSTED_URLS.model); + const button = document.getElementById('load-pretrained-remote'); + button.addEventListener('click', async () => { + const predictor = await new MnistTransferCNNPredictor().init(HOSTED_URLS); + ui.prepUI( + x => predictor.predict(x), x => predictor.retrainModel(), + predictor.testExamples, predictor.imageSize); + }); + button.style.display = 'inline-block'; + } + + if (await loader.urlExists(LOCAL_URLS.model)) { + ui.status('Model available: ' + LOCAL_URLS.model); + const button = document.getElementById('load-pretrained-local'); + button.addEventListener('click', async () => { + const predictor = await new MnistTransferCNNPredictor().init(LOCAL_URLS); + ui.prepUI( + x => predictor.predict(x), x => predictor.retrainModel(), + predictor.testExamples, predictor.imageSize); + }); + button.style.display = 'inline-block'; + } + + ui.status('Standing by.'); } setupMnistTransferCNN(); diff --git a/mnist-transfer-cnn/loader.js b/mnist-transfer-cnn/loader.js index 92a7389b3..600b90b58 100644 --- a/mnist-transfer-cnn/loader.js +++ b/mnist-transfer-cnn/loader.js @@ -19,6 +19,19 @@ import * as tf from '@tensorflow/tfjs'; import * as ui from './ui'; import * as util from './util'; +/** + * Test whether a given URL is retrievable. + */ +export async function urlExists(url) { + ui.status('Testing url ' + url); + try { + const response = await fetch(url, {method: 'HEAD'}); + return response.ok; + } catch (err) { + return false; + } +} + /** * Load pretrained model stored at a remote URL. * @@ -29,9 +42,13 @@ export async function loadHostedPretrainedModel(url) { try { const model = await tf.loadModel(url); ui.status('Done loading pretrained model.'); + // We can't load a model twice due to + // https://github.com/tensorflow/tfjs/issues/34 + // Therefore we remove the load buttons to avoid user confusion. + ui.disableLoadModelButtons(); return model; } catch (err) { - console.log(err); + console.error(err); ui.status('Loading pretrained model failed.'); } } @@ -51,7 +68,7 @@ export async function loadHostedData(url, numClasses) { ui.status('Done loading data.'); return result; } catch (err) { - console.log(err); + console.error(err); ui.status('Loading data failed.'); } } diff --git a/mnist-transfer-cnn/package.json b/mnist-transfer-cnn/package.json index f05c6e629..6b22e2338 100644 --- a/mnist-transfer-cnn/package.json +++ b/mnist-transfer-cnn/package.json @@ -13,7 +13,7 @@ "vega-embed": "~3.0.0" }, "scripts": { - "watch": "NODE_ENV=development parcel --no-hmr --open index.html ", + "watch": "./serve.sh", "build": "NODE_ENV=production parcel build index.html --no-minify --public-url ./" }, "devDependencies": { @@ -21,7 +21,8 @@ "babel-polyfill": "~6.26.0", "babel-preset-env": "~1.6.1", "clang-format": "~1.2.2", - "parcel-bundler": "~1.6.2" + "http-server": "~0.10.0", + "parcel-bundler": "~1.7.0" }, "babel": { "presets": [ diff --git a/mnist-transfer-cnn/python/__init__.py b/mnist-transfer-cnn/python/__init__.py new file mode 100644 index 000000000..636b70f0d --- /dev/null +++ b/mnist-transfer-cnn/python/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2018 Google LLC. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/mnist-transfer-cnn/python/mnist_transfer_cnn.py b/mnist-transfer-cnn/python/mnist_transfer_cnn.py
new file mode 100644
index 000000000..b9856b661
--- /dev/null
+++ b/mnist-transfer-cnn/python/mnist_transfer_cnn.py
@@ -0,0 +1,288 @@
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Train a simple CNN model for transfer learning in browser with TF.js Layers.
+
+The model architecture in this file is based on the Keras stock example at:
+  https://github.com/keras-team/keras/blob/master/examples/mnist_transfer_cnn.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import datetime
+import json
+
+import keras
+from keras import backend as K
+import tensorflowjs as tfjs
+
+# Input image dimensions.
+IMG_ROWS, IMG_COLS = 28, 28
+NUM_CLASSES = 5
+
+INPUT_SHAPE = ((1, IMG_ROWS, IMG_COLS)
+               if K.image_data_format() == 'channels_first'
+               else (IMG_ROWS, IMG_COLS, 1))
+
+
+def load_mnist_data(gte5_cutoff):
+  (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
+
+  # Create two datasets: one with digits below 5 and one with digits 5 and
+  # above.
+  x_train_lt5 = x_train[y_train < 5]
+  y_train_lt5 = y_train[y_train < 5]
+  x_test_lt5 = x_test[y_test < 5]
+  y_test_lt5 = y_test[y_test < 5]
+
+  x_train_gte5 = x_train[y_train >= 5]
+  y_train_gte5 = y_train[y_train >= 5] - 5
+  x_test_gte5 = x_test[y_test >= 5]
+  y_test_gte5 = y_test[y_test >= 5] - 5
+
+  if gte5_cutoff > 0:
+    x_train_gte5 = x_train_gte5[:gte5_cutoff, ...]
+    y_train_gte5 = y_train_gte5[:gte5_cutoff, ...]
+    x_test_gte5 = x_test_gte5[:gte5_cutoff, ...]
+    y_test_gte5 = y_test_gte5[:gte5_cutoff, ...]
+
+  return (x_train_lt5, y_train_lt5, x_test_lt5, y_test_lt5,
+          x_train_gte5, y_train_gte5, x_test_gte5, y_test_gte5)
+
+
+def train_model(model,
+                optimizer,
+                train,
+                test,
+                num_classes,
+                batch_size=128,
+                epochs=5):
+  """Train or re-train a model using data.
+
+  Args:
+    model: `keras.Model` instance to (re-)train.
+    optimizer: Name of the optimizer to use.
+    train: Training data, as an (x, y) tuple; x has shape
+      (NUM_EXAMPLES, IMG_ROWS, IMG_COLS).
+    test: Test data, as an (x, y) tuple; x has shape
+      (NUM_EXAMPLES, IMG_ROWS, IMG_COLS).
+    num_classes: Number of classes.
+    batch_size: Batch size.
+    epochs: How many epochs to train.
+  """
+  x_train = train[0].reshape((train[0].shape[0],) + INPUT_SHAPE)
+  x_test = test[0].reshape((test[0].shape[0],) + INPUT_SHAPE)
+  x_train = x_train.astype('float32')
+  x_test = x_test.astype('float32')
+  x_train /= 255
+  x_test /= 255
+
+  # Convert class vectors to binary class matrices.
+  y_train = keras.utils.to_categorical(train[1], num_classes)
+  y_test = keras.utils.to_categorical(test[1], num_classes)
+
+  model.compile(loss='categorical_crossentropy',
+                optimizer=optimizer,
+                metrics=['accuracy'])
+
+  t = datetime.datetime.now()
+  model.fit(x_train, y_train,
+            batch_size=batch_size,
+            epochs=epochs,
+            verbose=1,
+            validation_data=(x_test, y_test))
+  print('Training time: %s' % (datetime.datetime.now() - t))
+  score = model.evaluate(x_test, y_test, verbose=0)
+  print('Test score:', score[0])
+  print('Test accuracy:', score[1])
+
+
+def write_mnist_examples_to_json_file(x, y, js_path):
+  """Write a batch of MNIST examples to a JSON (.json) file.
+
+  Args:
+    x: A numpy array representing the image data, with shape
+      (NUM_EXAMPLES, IMG_ROWS, IMG_COLS).
+    y: A numpy array representing the image labels (as integer indices),
+      with shape (NUM_EXAMPLES,).
+    js_path: Path to the JSON file to write to.
+  """
+  data = []
+  num_examples = x.shape[0]
+  for i in range(num_examples):
+    data.append({'x': x[i, ...].tolist(), 'y': int(y[i])})
+  with open(js_path, 'wt') as f:
+    f.write(json.dumps(data))
+
+
+def write_gte5_data(x_train_gte5,
+                    y_train_gte5,
+                    x_test_gte5,
+                    y_test_gte5,
+                    gte5_data_path_prefix):
+  """Write the transfer-learning data to .json files.
+
+  Args:
+    x_train_gte5: x (image) data for training: digits >= 5.
+    y_train_gte5: y (label) data for training: digits >= 5.
+    x_test_gte5: x (image) data for test (validation): digits >= 5.
+    y_test_gte5: y (label) data for test (validation): digits >= 5.
+    gte5_data_path_prefix: Path prefix for writing the files. For example,
+      if the value is '/tmp/foo', the train and test files will be written
+      at '/tmp/foo.train.json' and '/tmp/foo.test.json', respectively.
+  """
+  gte5_train_path = gte5_data_path_prefix + '.train.json'
+  write_mnist_examples_to_json_file(
+      x_train_gte5, y_train_gte5, gte5_train_path)
+  print('Wrote gte5 training data to: %s' % gte5_train_path)
+  gte5_test_path = gte5_data_path_prefix + '.test.json'
+  write_mnist_examples_to_json_file(
+      x_test_gte5, y_test_gte5, gte5_test_path)
+  print('Wrote gte5 test data to: %s' % gte5_test_path)
+
+
+def train_and_save_model(filters,
+                         kernel_size,
+                         pool_size,
+                         batch_size,
+                         epochs,
+                         x_train_lt5,
+                         y_train_lt5,
+                         x_test_lt5,
+                         y_test_lt5,
+                         artifacts_dir,
+                         optimizer='adam'):
+  """Train and save the MNIST CNN model.
+
+  Args:
+    filters: number of filters for the convolution layers.
+    kernel_size: kernel size for the convolution layers.
+    pool_size: pooling kernel size for the pooling layers.
+    batch_size: batch size.
+    epochs: number of epochs to train for.
+    x_train_lt5: x (image) data for training: digits < 5.
+    y_train_lt5: y (label) data for training: digits < 5.
+    x_test_lt5: x (image) data for test (validation): digits < 5.
+    y_test_lt5: y (label) data for test (validation): digits < 5.
+    artifacts_dir: Directory to save the model artifacts (model topology JSON,
+      weights and weight manifest) in.
+    optimizer: The name of the optimizer to use, as a string.
+
+
+def train_and_save_model(filters,
+                         kernel_size,
+                         pool_size,
+                         batch_size,
+                         epochs,
+                         x_train_lt5,
+                         y_train_lt5,
+                         x_test_lt5,
+                         y_test_lt5,
+                         artifacts_dir,
+                         optimizer='adam'):
+  """Train and save MNIST CNN model.
+
+  Args:
+    filters: number of filters for convolution layers.
+    kernel_size: kernel size for convolution layers.
+    pool_size: pooling kernel size for pooling layers.
+    batch_size: batch size.
+    epochs: number of epochs to train for.
+    x_train_lt5: x (image) data for training: digits < 5.
+    y_train_lt5: y (label) data for training: digits < 5.
+    x_test_lt5: x (image) data for test (validation): digits < 5.
+    y_test_lt5: y (label) data for test (validation): digits < 5.
+    artifacts_dir: Directory to save the model artifacts (model topology JSON,
+      weights and weight manifest) in.
+    optimizer: The name of the optimizer to use, as a string.
+  """
+
+  feature_layers = [
+      keras.layers.Conv2D(filters, kernel_size,
+                          padding='valid',
+                          input_shape=INPUT_SHAPE),
+      keras.layers.Activation('relu'),
+      keras.layers.Conv2D(filters, kernel_size),
+      keras.layers.Activation('relu'),
+      keras.layers.MaxPooling2D(pool_size=pool_size),
+      keras.layers.Dropout(0.25),
+      keras.layers.Flatten(),
+  ]
+  classification_layers = [
+      keras.layers.Dense(128),
+      keras.layers.Activation('relu'),
+      keras.layers.Dropout(0.5),
+      keras.layers.Dense(NUM_CLASSES),
+      keras.layers.Activation('softmax')
+  ]
+  model = keras.models.Sequential(feature_layers + classification_layers)
+
+  train_model(model,
+              optimizer,
+              (x_train_lt5, y_train_lt5),
+              (x_test_lt5, y_test_lt5),
+              NUM_CLASSES, batch_size=batch_size, epochs=epochs)
+  tfjs.converters.save_keras_model(model, artifacts_dir)
+
+
+def main():
+  (x_train_lt5, y_train_lt5, x_test_lt5, y_test_lt5,
+   x_train_gte5, y_train_gte5, x_test_gte5, y_test_gte5) = load_mnist_data(
+       FLAGS.gte5_cutoff)
+
+  write_gte5_data(x_train_gte5, y_train_gte5, x_test_gte5, y_test_gte5,
+                  FLAGS.gte5_data_path_prefix)
+
+  train_and_save_model(FLAGS.filters,
+                       FLAGS.kernel_size,
+                       FLAGS.pool_size,
+                       FLAGS.batch_size,
+                       FLAGS.epochs,
+                       x_train_lt5,
+                       y_train_lt5,
+                       x_test_lt5,
+                       y_test_lt5,
+                       FLAGS.artifacts_dir,
+                       optimizer=FLAGS.optimizer)
+
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser('MNIST model training and serialization')
+  parser.add_argument(
+      '--epochs',
+      type=int,
+      default=5,
+      help='Number of epochs to train the Keras model for.')
+  parser.add_argument(
+      '--batch_size',
+      type=int,
+      default=128,
+      help='Batch size for training.')
+  parser.add_argument(
+      '--filters',
+      type=int,
+      default=32,
+      help='Number of convolutional filters to use.')
+  parser.add_argument(
+      '--pool_size',
+      type=int,
+      default=2,
+      help='Size of pooling area for max pooling.')
+  parser.add_argument(
+      '--kernel_size',
+      type=int,
+      default=3,
+      help='Convolutional kernel size.')
+  parser.add_argument(
+      '--artifacts_dir',
+      type=str,
+      default='/tmp/mnist.keras',
+      help='Local path for saving the TensorFlow.js artifacts.')
+  parser.add_argument(
+      '--optimizer',
+      type=str,
+      default='adam',
+      help='Name of the optimizer to use for training.')
+  parser.add_argument(
+      '--gte5_cutoff',
+      type=int,
+      default=1024,
+      help='If the value is > 0, only the first gte5_cutoff examples of the '
+      'gte5 (label >= 5) transfer-learning subset are written to '
+      'JSON (.json) files.')
+  parser.add_argument(
+      '--gte5_data_path_prefix',
+      type=str,
+      default='/tmp/mnist_transfer_cnn.gte5',
+      help='Prefix for the label >= 5 data for transfer learning. '
+      'For example, if the prefix is /tmp/foo.gte5, the train and test '
+      'data will be written to /tmp/foo.gte5.train.json and '
+      '/tmp/foo.gte5.test.json, respectively.')
+  # TODO(cais, soergel): Eventually we want to use the dataset API instead of
+  # writing the data to files.
+
+  FLAGS, _ = parser.parse_known_args()
+  main()
diff --git a/mnist-transfer-cnn/python/mnist_transfer_cnn_test.py b/mnist-transfer-cnn/python/mnist_transfer_cnn_test.py
new file mode 100644
index 000000000..fb6aaacdc
--- /dev/null
+++ b/mnist-transfer-cnn/python/mnist_transfer_cnn_test.py
@@ -0,0 +1,78 @@
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Test for the MNIST transfer learning CNN model.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import os +import shutil +import tempfile +import unittest + +import numpy as np + +from . import mnist_transfer_cnn + +class MnistTest(unittest.TestCase): + + def setUp(self): + self._tmp_dir = tempfile.mkdtemp() + super(MnistTest, self).setUp() + + def tearDown(self): + if os.path.isdir(self._tmp_dir): + shutil.rmtree(self._tmp_dir) + super(MnistTest, self).tearDown() + + def _getFakeMnistData(self): + """Generate some fake MNIST data for testing.""" + x_train = np.random.rand(4, 28, 28) + y_train = np.random.rand(4,) + x_test = np.random.rand(4, 28, 28) + y_test = np.random.rand(4,) + return x_train, y_train, x_test, y_test + + def testWriteGte5DataToJsFiles(self): + x_train, y_train, x_test, y_test = self._getFakeMnistData() + js_file_prefix = os.path.join(self._tmp_dir, 'gte5') + mnist_transfer_cnn.write_gte5_data(x_train, y_train, x_test, y_test, + js_file_prefix) + + for js_path_suffix in ('.train.json', '.test.json'): + with open(js_file_prefix + js_path_suffix, 'rt') as f: + json_string = f.read() + data = json.loads(json_string) + if 'train' in js_path_suffix: + self.assertEqual(x_train.shape[0], len(data)) + else: + self.assertEqual(x_test.shape[0], len(data)) + self.assertEqual(28, len(data[0]['x'])) + self.assertEqual(28, len(data[0]['x'][0])) + + def testTrainWithFakeDataAndSave(self): + x_train, y_train, x_test, y_test = self._getFakeMnistData() + mnist_transfer_cnn.train_and_save_model( + 2, 2, 2, 2, 1, x_train, y_train, x_test, y_test, + self._tmp_dir, optimizer='adam') + + # Check that the model json file is created. + json.load(open(os.path.join(self._tmp_dir, 'model.json'), 'rt')) + + +if __name__ == '__main__': + unittest.main() diff --git a/mnist-transfer-cnn/serve.sh b/mnist-transfer-cnn/serve.sh new file mode 100755 index 000000000..e5ede63ed --- /dev/null +++ b/mnist-transfer-cnn/serve.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash + +# Copyright 2018 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +# This script starts two HTTP servers on different ports: +# * Port 1234 (using parcel) serves HTML and JavaScript. +# * Port 1235 (using http-server) serves pretrained model resources. 
+# +# The reason for this arrangement is that Parcel currently has a limitation that +# prevents it from serving the pretrained models; see +# https://github.com/parcel-bundler/parcel/issues/1098. Once that issue is +# resolved, a single Parcel server will be sufficient. + +NODE_ENV=development +RESOURCE_PORT=1235 + +# Ensure that http-server is available +yarn + +echo Starting the pretrained model server... +node_modules/http-server/bin/http-server dist --cors -p "${RESOURCE_PORT}" > /dev/null & HTTP_SERVER_PID=$! + +echo Starting the example html/js server... +# This uses port 1234 by default. +node_modules/parcel-bundler/bin/cli.js serve -d dist --open --no-hmr --public-url / index.html + +# When the Parcel server exits, kill the http-server too. +kill $HTTP_SERVER_PID + diff --git a/mnist-transfer-cnn/ui.js b/mnist-transfer-cnn/ui.js index ed22ab1ed..d13d7e7ed 100644 --- a/mnist-transfer-cnn/ui.js +++ b/mnist-transfer-cnn/ui.js @@ -19,6 +19,7 @@ import * as tf from '@tensorflow/tfjs'; import * as util from './util'; export function status(statusText, statusColor) { + console.log(statusText); document.getElementById('status').textContent = statusText; document.getElementById('status').style.color = statusColor; } @@ -106,3 +107,8 @@ export function setPredictResults(predictOut, winner) { predictValues.innerHTML = valTds; document.getElementById('winner').textContent = winner; } + +export function disableLoadModelButtons() { + document.getElementById('load-pretrained-remote').style.display = 'none'; + document.getElementById('load-pretrained-local').style.display = 'none'; +} diff --git a/mnist-transfer-cnn/yarn.lock b/mnist-transfer-cnn/yarn.lock index 2e61d4f3d..87bfdaf3c 100644 --- a/mnist-transfer-cnn/yarn.lock +++ b/mnist-transfer-cnn/yarn.lock @@ -705,6 +705,10 @@ binary-extensions@^1.0.0: version "1.11.0" resolved "https://registry.yarnpkg.com/binary-extensions/-/binary-extensions-1.11.0.tgz#46aa1751fb6a2f93ee5e689bb1087d4b14c6c205" +bindings@~1.2.1: + version "1.2.1" + resolved "https://registry.yarnpkg.com/bindings/-/bindings-1.2.1.tgz#14ad6113812d2d37d72e67b4cacb4bb726505f11" + block-stream@*: version "0.0.9" resolved "https://registry.yarnpkg.com/block-stream/-/block-stream-0.0.9.tgz#13ebfe778a03205cfe03751481ebb4b3300c126a" @@ -766,12 +770,6 @@ brorand@^1.0.1: version "1.1.0" resolved "https://registry.yarnpkg.com/brorand/-/brorand-1.1.0.tgz#12c25efe40a45e3c323eb8675a0a0ce57b22371f" -browser-resolve@^1.11.2: - version "1.11.2" - resolved "https://registry.yarnpkg.com/browser-resolve/-/browser-resolve-1.11.2.tgz#8ff09b0a2c421718a1051c260b32e48f442938ce" - dependencies: - resolve "1.1.7" - browserify-aes@^1.0.0, browserify-aes@^1.0.4: version "1.1.1" resolved "https://registry.yarnpkg.com/browserify-aes/-/browserify-aes-1.1.1.tgz#38b7ab55edb806ff2dcda1a7f1620773a477c49f" @@ -1066,6 +1064,10 @@ colormin@^1.0.5: css-color-names "0.0.4" has "^1.0.1" +colors@1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/colors/-/colors-1.0.3.tgz#0433f44d809680fdeb60ed260f1b0c262e82a40b" + colors@~1.1.2: version "1.1.2" resolved "https://registry.yarnpkg.com/colors/-/colors-1.1.2.tgz#168a4701756b6a7f51a12ce0c97bfa28c084ed63" @@ -1145,6 +1147,10 @@ core-util-is@1.0.2, core-util-is@~1.0.0: version "1.0.2" resolved "https://registry.yarnpkg.com/core-util-is/-/core-util-is-1.0.2.tgz#b5fd54220aa2bc5ab57aab7140c940754503c1a7" +corser@~2.0.0: + version "2.0.1" + resolved "https://registry.yarnpkg.com/corser/-/corser-2.0.1.tgz#8eda252ecaab5840dcd975ceb90d9370c819ff87" + 
create-ecdh@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/create-ecdh/-/create-ecdh-4.0.0.tgz#888c723596cdf7612f6498233eebd7a35301737d" @@ -1433,6 +1439,13 @@ date-now@^0.1.4: version "0.1.4" resolved "https://registry.yarnpkg.com/date-now/-/date-now-0.1.4.tgz#eaf439fd4d4848ad74e5cc7dbef200672b9e345b" +deasync@^0.1.12: + version "0.1.12" + resolved "https://registry.yarnpkg.com/deasync/-/deasync-0.1.12.tgz#0159492a4133ab301d6c778cf01e74e63b10e549" + dependencies: + bindings "~1.2.1" + nan "^2.0.7" + debug@2.6.9, debug@^2.2.0, debug@^2.3.3, debug@^2.6.8: version "2.6.9" resolved "https://registry.yarnpkg.com/debug/-/debug-2.6.9.tgz#5d128515df134ff327e90a4c93f4e077a536341f" @@ -1585,6 +1598,15 @@ ecc-jsbn@~0.1.1: dependencies: jsbn "~0.1.0" +ecstatic@^2.0.0: + version "2.2.1" + resolved "https://registry.yarnpkg.com/ecstatic/-/ecstatic-2.2.1.tgz#b5087fad439dd9dd49d31e18131454817fe87769" + dependencies: + he "^1.1.1" + mime "^1.2.11" + minimist "^1.1.0" + url-join "^2.0.2" + editorconfig@^0.13.2: version "0.13.3" resolved "https://registry.yarnpkg.com/editorconfig/-/editorconfig-0.13.3.tgz#e5219e587951d60958fd94ea9a9a008cdeff1b34" @@ -1696,6 +1718,10 @@ etag@~1.8.1: version "1.8.1" resolved "https://registry.yarnpkg.com/etag/-/etag-1.8.1.tgz#41ae2eeb65efa62268aebfea83ac7d79299b0887" +eventemitter3@1.x.x: + version "1.2.0" + resolved "https://registry.yarnpkg.com/eventemitter3/-/eventemitter3-1.2.0.tgz#1c86991d816ad1e504750e73874224ecf3bec508" + events@^1.0.0: version "1.1.1" resolved "https://registry.yarnpkg.com/events/-/events-1.1.1.tgz#9ebdb7635ad099c70dcc4c2a1f5004288e8bd924" @@ -2033,6 +2059,10 @@ hawk@3.1.3, hawk@~3.1.3: hoek "2.x.x" sntp "1.x.x" +he@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/he/-/he-1.1.1.tgz#93410fd21b009735151f8868c2f271f3427e23fd" + hmac-drbg@^1.0.0: version "1.0.1" resolved "https://registry.yarnpkg.com/hmac-drbg/-/hmac-drbg-1.0.1.tgz#d2745701025a6c775a6c545793ed502fc0c649a1" @@ -2060,7 +2090,7 @@ html-comment-regex@^1.1.0: version "1.1.1" resolved "https://registry.yarnpkg.com/html-comment-regex/-/html-comment-regex-1.1.1.tgz#668b93776eaae55ebde8f3ad464b307a4963625e" -htmlnano@^0.1.6: +htmlnano@^0.1.7: version "0.1.7" resolved "https://registry.yarnpkg.com/htmlnano/-/htmlnano-0.1.7.tgz#1751937a05f122a3248dba1c63edf01f6d36cb84" dependencies: @@ -2091,6 +2121,26 @@ http-errors@~1.6.2: setprototypeof "1.0.3" statuses ">= 1.3.1 < 2" +http-proxy@^1.8.1: + version "1.16.2" + resolved "https://registry.yarnpkg.com/http-proxy/-/http-proxy-1.16.2.tgz#06dff292952bf64dbe8471fa9df73066d4f37742" + dependencies: + eventemitter3 "1.x.x" + requires-port "1.x.x" + +http-server@~0.10.0: + version "0.10.0" + resolved "https://registry.yarnpkg.com/http-server/-/http-server-0.10.0.tgz#b2a446b16a9db87ed3c622ba9beb1b085b1234a7" + dependencies: + colors "1.0.3" + corser "~2.0.0" + ecstatic "^2.0.0" + http-proxy "^1.8.1" + opener "~1.4.0" + optimist "0.6.x" + portfinder "^1.0.13" + union "~0.4.3" + http-signature@~1.1.0: version "1.1.1" resolved "https://registry.yarnpkg.com/http-signature/-/http-signature-1.1.1.tgz#df72e267066cd0ac67fb76adf8e134a8fbcf91bf" @@ -2610,6 +2660,10 @@ mime@1.4.1: version "1.4.1" resolved "https://registry.yarnpkg.com/mime/-/mime-1.4.1.tgz#121f9ebc49e3766f311a76e1fa1c8003c4b03aa6" +mime@^1.2.11: + version "1.6.0" + resolved "https://registry.yarnpkg.com/mime/-/mime-1.6.0.tgz#32cd9e5c64553bd58d19a568af452acff04981b1" + mimic-fn@^1.0.0: version "1.2.0" resolved 
"https://registry.yarnpkg.com/mimic-fn/-/mimic-fn-1.2.0.tgz#820c86a39334640e99516928bd03fca88057d022" @@ -2632,10 +2686,14 @@ minimist@0.0.8: version "0.0.8" resolved "https://registry.yarnpkg.com/minimist/-/minimist-0.0.8.tgz#857fcabfc3397d2625b8228262e86aa7a011b05d" -minimist@^1.1.3, minimist@^1.2.0: +minimist@^1.1.0, minimist@^1.1.3, minimist@^1.2.0: version "1.2.0" resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.0.tgz#a35008b20f41383eec1fb914f4cd5df79a264284" +minimist@~0.0.1: + version "0.0.10" + resolved "https://registry.yarnpkg.com/minimist/-/minimist-0.0.10.tgz#de3f98543dbf96082be48ad1a0c7cda836301dcf" + mixin-deep@^1.2.0: version "1.3.1" resolved "https://registry.yarnpkg.com/mixin-deep/-/mixin-deep-1.3.1.tgz#a49e7268dce1a0d9698e45326c5626df3543d0fe" @@ -2643,7 +2701,7 @@ mixin-deep@^1.2.0: for-in "^1.0.2" is-extendable "^1.0.1" -"mkdirp@>=0.5 0", mkdirp@^0.5.1, mkdirp@~0.5.0, mkdirp@~0.5.1: +mkdirp@0.5.x, "mkdirp@>=0.5 0", mkdirp@^0.5.1, mkdirp@~0.5.0, mkdirp@~0.5.1: version "0.5.1" resolved "https://registry.yarnpkg.com/mkdirp/-/mkdirp-0.5.1.tgz#30057438eac6cf7f8c4767f38648d6697d75c903" dependencies: @@ -2653,7 +2711,7 @@ ms@2.0.0: version "2.0.0" resolved "https://registry.yarnpkg.com/ms/-/ms-2.0.0.tgz#5608aeadfc00be6c2901df5f9861788de0d597c8" -nan@^2.3.0, nan@^2.4.0: +nan@^2.0.7, nan@^2.3.0, nan@^2.4.0: version "2.10.0" resolved "https://registry.yarnpkg.com/nan/-/nan-2.10.0.tgz#96d0cd610ebd58d4b4de9cc0c6828cda99c7548f" @@ -2860,12 +2918,23 @@ once@^1.3.0, once@^1.3.3: dependencies: wrappy "1" +opener@~1.4.0: + version "1.4.3" + resolved "https://registry.yarnpkg.com/opener/-/opener-1.4.3.tgz#5c6da2c5d7e5831e8ffa3964950f8d6674ac90b8" + opn@^5.1.0: version "5.3.0" resolved "https://registry.yarnpkg.com/opn/-/opn-5.3.0.tgz#64871565c863875f052cfdf53d3e3cb5adb53b1c" dependencies: is-wsl "^1.1.0" +optimist@0.6.x: + version "0.6.1" + resolved "https://registry.yarnpkg.com/optimist/-/optimist-0.6.1.tgz#da3ea74686fa21a19a111c326e90eb15a0196686" + dependencies: + minimist "~0.0.1" + wordwrap "~0.0.2" + optionator@^0.8.1: version "0.8.2" resolved "https://registry.yarnpkg.com/optionator/-/optionator-0.8.2.tgz#364c5e409d3f4d6301d6c0b4c05bba50180aeb64" @@ -2938,9 +3007,9 @@ pako@~1.0.5: version "1.0.6" resolved "https://registry.yarnpkg.com/pako/-/pako-1.0.6.tgz#0101211baa70c4bca4a0f63f2206e97b7dfaf258" -parcel-bundler@~1.6.2: - version "1.6.2" - resolved "https://registry.yarnpkg.com/parcel-bundler/-/parcel-bundler-1.6.2.tgz#e415c4993b6f4c48cd410427594ce80bbb7d90f1" +parcel-bundler@~1.7.0: + version "1.7.0" + resolved "https://registry.yarnpkg.com/parcel-bundler/-/parcel-bundler-1.7.0.tgz#8c2512615fd602d2f39bd97bfd128f8fe524b321" dependencies: babel-code-frame "^6.26.0" babel-core "^6.25.0" @@ -2953,7 +3022,6 @@ parcel-bundler@~1.6.2: babel-types "^6.26.0" babylon "^6.17.4" babylon-walk "^1.0.2" - browser-resolve "^1.11.2" browserslist "^2.11.2" chalk "^2.1.0" chokidar "^2.0.1" @@ -2961,12 +3029,13 @@ parcel-bundler@~1.6.2: commander "^2.11.0" cross-spawn "^6.0.4" cssnano "^3.10.0" + deasync "^0.1.12" dotenv "^5.0.0" filesize "^3.6.0" get-port "^3.2.0" glob "^7.1.2" grapheme-breaker "^0.3.2" - htmlnano "^0.1.6" + htmlnano "^0.1.7" is-url "^1.2.2" js-yaml "^3.10.0" json5 "^0.5.1" @@ -2976,13 +3045,12 @@ parcel-bundler@~1.6.2: node-libs-browser "^2.0.0" opn "^5.1.0" physical-cpu-count "^2.0.0" - postcss "^6.0.10" + postcss "^6.0.19" postcss-value-parser "^3.3.0" posthtml "^0.11.2" posthtml-parser "^0.4.0" posthtml-render "^1.1.0" resolve "^1.4.0" - sanitize-filename 
"^1.6.1" semver "^5.4.1" serialize-to-js "^1.1.1" serve-static "^1.12.4" @@ -3089,6 +3157,14 @@ pinkie@^2.0.0: version "2.0.4" resolved "https://registry.yarnpkg.com/pinkie/-/pinkie-2.0.4.tgz#72556b80cfa0d48a974e80e77248e80ed4f7f870" +portfinder@^1.0.13: + version "1.0.13" + resolved "https://registry.yarnpkg.com/portfinder/-/portfinder-1.0.13.tgz#bb32ecd87c27104ae6ee44b5a3ccbf0ebb1aede9" + dependencies: + async "^1.5.2" + debug "^2.2.0" + mkdirp "0.5.x" + posix-character-classes@^0.1.0: version "0.1.1" resolved "https://registry.yarnpkg.com/posix-character-classes/-/posix-character-classes-0.1.1.tgz#01eac0fe3b5af71a2a6c02feabb8c1fef7e00eab" @@ -3304,9 +3380,9 @@ postcss@^5.0.10, postcss@^5.0.11, postcss@^5.0.12, postcss@^5.0.13, postcss@^5.0 source-map "^0.5.6" supports-color "^3.2.3" -postcss@^6.0.10: - version "6.0.20" - resolved "https://registry.yarnpkg.com/postcss/-/postcss-6.0.20.tgz#686107e743a12d5530cb68438c590d5b2bf72c3c" +postcss@^6.0.19: + version "6.0.21" + resolved "https://registry.yarnpkg.com/postcss/-/postcss-6.0.21.tgz#8265662694eddf9e9a5960db6da33c39e4cd069d" dependencies: chalk "^2.3.2" source-map "^0.6.1" @@ -3397,6 +3473,10 @@ q@^1.1.2: version "1.5.1" resolved "https://registry.yarnpkg.com/q/-/q-1.5.1.tgz#7e32f75b41381291d04611f1bf14109ac00651d7" +qs@~2.3.3: + version "2.3.3" + resolved "https://registry.yarnpkg.com/qs/-/qs-2.3.3.tgz#e9e85adbe75da0bbe4c8e0476a086290f863b404" + qs@~6.4.0: version "6.4.0" resolved "https://registry.yarnpkg.com/qs/-/qs-6.4.0.tgz#13e26d28ad6b0ffaa91312cd3bf708ed351e7233" @@ -3598,14 +3678,14 @@ require-main-filename@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/require-main-filename/-/require-main-filename-1.0.1.tgz#97f717b69d48784f5f526a6c5aa8ffdda055a4d1" +requires-port@1.x.x: + version "1.0.0" + resolved "https://registry.yarnpkg.com/requires-port/-/requires-port-1.0.0.tgz#925d2601d39ac485e091cf0da5c6e694dc3dcaff" + resolve-url@^0.2.1: version "0.2.1" resolved "https://registry.yarnpkg.com/resolve-url/-/resolve-url-0.2.1.tgz#2c637fe77c893afd2a663fe21aa9080068e2052a" -resolve@1.1.7: - version "1.1.7" - resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.1.7.tgz#203114d82ad2c5ed9e8e0411b3932875e889e97b" - resolve@^1.1.5, resolve@^1.1.6, resolve@^1.4.0: version "1.6.0" resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.6.0.tgz#0fbd21278b27b4004481c395349e7aba60a9ff5c" @@ -3649,12 +3729,6 @@ safer-eval@^1.2.3: dependencies: clones "^1.1.0" -sanitize-filename@^1.6.1: - version "1.6.1" - resolved "https://registry.yarnpkg.com/sanitize-filename/-/sanitize-filename-1.6.1.tgz#612da1c96473fa02dccda92dcd5b4ab164a6772a" - dependencies: - truncate-utf8-bytes "^1.0.0" - sax@~1.2.1, sax@~1.2.4: version "1.2.4" resolved "https://registry.yarnpkg.com/sax/-/sax-1.2.4.tgz#2816234e2378bddc4e5354fab5caa895df7100d9" @@ -4138,12 +4212,6 @@ trim-right@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/trim-right/-/trim-right-1.0.1.tgz#cb2e1203067e0c8de1f614094b9fe45704ea6003" -truncate-utf8-bytes@^1.0.0: - version "1.0.2" - resolved "https://registry.yarnpkg.com/truncate-utf8-bytes/-/truncate-utf8-bytes-1.0.2.tgz#405923909592d56f78a5818434b0b78489ca5f2b" - dependencies: - utf8-byte-length "^1.0.1" - tslib@^1.9.0: version "1.9.0" resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.9.0.tgz#e37a86fda8cbbaf23a057f473c9f4dc64e5fc2e8" @@ -4210,6 +4278,12 @@ union-value@^1.0.0: is-extendable "^0.1.1" set-value "^0.4.3" +union@~0.4.3: + version "0.4.6" + resolved 
"https://registry.yarnpkg.com/union/-/union-0.4.6.tgz#198fbdaeba254e788b0efcb630bc11f24a2959e0" + dependencies: + qs "~2.3.3" + uniq@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/uniq/-/uniq-1.0.1.tgz#b31c5ae8254844a3a8281541ce2b04b865a734ff" @@ -4243,6 +4317,10 @@ urix@^0.1.0: version "0.1.0" resolved "https://registry.yarnpkg.com/urix/-/urix-0.1.0.tgz#da937f7a62e21fec1fd18d49b35c2935067a6c72" +url-join@^2.0.2: + version "2.0.5" + resolved "https://registry.yarnpkg.com/url-join/-/url-join-2.0.5.tgz#5af22f18c052a000a48d7b82c5e9c2e2feeda728" + url@^0.11.0: version "0.11.0" resolved "https://registry.yarnpkg.com/url/-/url-0.11.0.tgz#3838e97cfc60521eb73c525a8e55bfdd9e2e28f1" @@ -4256,10 +4334,6 @@ use@^3.1.0: dependencies: kind-of "^6.0.2" -utf8-byte-length@^1.0.1: - version "1.0.4" - resolved "https://registry.yarnpkg.com/utf8-byte-length/-/utf8-byte-length-1.0.4.tgz#f45f150c4c66eee968186505ab93fcbb8ad6bf61" - util-deprecate@~1.0.1: version "1.0.2" resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" @@ -4589,6 +4663,10 @@ window-size@^0.2.0: version "0.2.0" resolved "https://registry.yarnpkg.com/window-size/-/window-size-0.2.0.tgz#b4315bb4214a3d7058ebeee892e13fa24d98b075" +wordwrap@~0.0.2: + version "0.0.3" + resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-0.0.3.tgz#a3d5da6cd5c0bc0008d37234bbaf1bed63059107" + wordwrap@~1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-1.0.0.tgz#27584810891456a4171c8d0226441ade90cbcaeb" diff --git a/mobilenet/imagenet_classes.js b/mobilenet/imagenet_classes.js index e6b0777a6..09f705d79 100644 --- a/mobilenet/imagenet_classes.js +++ b/mobilenet/imagenet_classes.js @@ -1,6 +1,6 @@ /** * @license - * Copyright 2017 Google Inc. All Rights Reserved. + * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at diff --git a/sentiment/build-resources.sh b/sentiment/build-resources.sh new file mode 100755 index 000000000..2f0b2fd91 --- /dev/null +++ b/sentiment/build-resources.sh @@ -0,0 +1,73 @@ +#!/usr/bin/env bash + +# Copyright 2018 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +# Builds resources for the Sentiment demo. +# Note this is not necessary to run the demo, because we already provide hosted +# pre-built resources. 
+# Usage example: do this from the 'sentiment' directory:
+#   ./build-resources.sh lstm
+
+set -e
+
+DEMO_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+if [[ $# -lt 1 ]]; then
+  echo "Usage:"
+  echo "  build-resources.sh <MODEL_TYPE>"
+  echo
+  echo "MODEL_TYPE options: lstm | cnn"
+  exit 1
+fi
+MODEL_TYPE=$1
+shift
+echo "Using model type: ${MODEL_TYPE}"
+
+TRAIN_EPOCHS=5
+while true; do
+  if [[ "$1" == "--epochs" ]]; then
+    TRAIN_EPOCHS=$2
+    shift 2
+  elif [[ -z "$1" ]]; then
+    break
+  else
+    echo "ERROR: Unrecognized argument: $1"
+    exit 1
+  fi
+done
+
+RESOURCES_ROOT="${DEMO_DIR}/dist/resources"
+rm -rf "${RESOURCES_ROOT}"
+mkdir -p "${RESOURCES_ROOT}"
+
+# Run Python script to generate the pretrained model and weights files.
+# Make sure you install the tensorflowjs pip package first.
+
+python "${DEMO_DIR}/python/imdb.py" \
+    "${MODEL_TYPE}" \
+    --epochs "${TRAIN_EPOCHS}" \
+    --artifacts_dir "${RESOURCES_ROOT}"
+
+cd "${DEMO_DIR}"
+yarn
+yarn build
+
+echo
+echo "-----------------------------------------------------------"
+echo "Resources written to ${RESOURCES_ROOT}."
+echo "You can now run the demo with 'yarn watch'."
+echo "-----------------------------------------------------------"
+echo
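For clarity (not part of this patch), the bash argument handling above amounts to the following argparse sketch: one positional model type plus an optional --epochs flag, with anything else rejected:

    import argparse

    parser = argparse.ArgumentParser('build-resources.sh')
    parser.add_argument('model_type', choices=['lstm', 'cnn'])
    parser.add_argument('--epochs', type=int, default=5)

    # Equivalent of './build-resources.sh lstm --epochs 10':
    args = parser.parse_args(['lstm', '--epochs', '10'])
    print(args.model_type, args.epochs)  # prints: lstm 10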

diff --git a/sentiment/index.html b/sentiment/index.html
index fde239d6d..60bd040e4 100644
--- a/sentiment/index.html
+++ b/sentiment/index.html
@@ -32,6 +32,10 @@
   TensorFlow.js Layers: Sentiment Analysis Demo
 
 
+<div>
+  <button id="load-pretrained-remote" style="display:none">Load hosted pretrained model</button>
+  <button id="load-pretrained-local" style="display:none">Load locally hosted pretrained model</button>
+</div>
Model type: diff --git a/sentiment/index.js b/sentiment/index.js index ffcb08a07..66f47ef34 100644 --- a/sentiment/index.js +++ b/sentiment/index.js @@ -19,24 +19,33 @@ import * as tf from '@tensorflow/tfjs'; import * as loader from './loader'; import * as ui from './ui'; -const HOSTED_MODEL_JSON_URL = - 'https://storage.googleapis.com/tfjs-models/tfjs/sentiment_cnn_v1/model.json'; -const HOSTED_METADATA_JSON_URL = - 'https://storage.googleapis.com/tfjs-models/tfjs/sentiment_cnn_v1/metadata.json'; + +const HOSTED_URLS = { + model: + 'https://storage.googleapis.com/tfjs-models/tfjs/sentiment_cnn_v1/model.json', + metadata: + 'https://storage.googleapis.com/tfjs-models/tfjs/sentiment_cnn_v1/metadata.json' +}; + +const LOCAL_URLS = { + model: 'http://localhost:1235/resources/model.json', + metadata: 'http://localhost:1235/resources/metadata.json' +}; class SentimentPredictor { /** * Initializes the Sentiment demo. */ - async init() { - this.model = await loader.loadHostedPretrainedModel(HOSTED_MODEL_JSON_URL); + async init(urls) { + this.urls = urls; + this.model = await loader.loadHostedPretrainedModel(urls.model); await this.loadMetadata(); return this; } async loadMetadata() { const sentimentMetadata = - await loader.loadHostedMetadata(HOSTED_METADATA_JSON_URL); + await loader.loadHostedMetadata(this.urls.metadata); ui.showMetadata(sentimentMetadata); this.indexFrom = sentimentMetadata['index_from']; this.maxLen = sentimentMetadata['max_len']; @@ -50,13 +59,11 @@ class SentimentPredictor { // Convert to lower case and remove all punctuations. const inputText = text.trim().toLowerCase().replace(/(\.|\,|\!)/g, '').split(' '); - ui.status(inputText); // Look up word indices. const inputBuffer = tf.buffer([1, this.maxLen], 'float32'); for (let i = 0; i < inputText.length; ++i) { // TODO(cais): Deal with OOV words. const word = inputText[i]; - ui.status(word); inputBuffer.set(this.wordIndex[word] + this.indexFrom, 0, i); } const input = inputBuffer.toTensor(); @@ -78,9 +85,27 @@ class SentimentPredictor { * function with the UI. */ async function setupSentiment() { - const predictor = await new SentimentPredictor().init(); - ui.setPredictFunction(x => predictor.predict(x)); - ui.prepUI(x => predictor.predict(x)); + if (await loader.urlExists(HOSTED_URLS.model)) { + ui.status('Model available: ' + HOSTED_URLS.model); + const button = document.getElementById('load-pretrained-remote'); + button.addEventListener('click', async () => { + const predictor = await new SentimentPredictor().init(HOSTED_URLS); + ui.prepUI(x => predictor.predict(x)); + }); + button.style.display = 'inline-block'; + } + + if (await loader.urlExists(LOCAL_URLS.model)) { + ui.status('Model available: ' + LOCAL_URLS.model); + const button = document.getElementById('load-pretrained-local'); + button.addEventListener('click', async () => { + const predictor = await new SentimentPredictor().init(LOCAL_URLS); + ui.prepUI(x => predictor.predict(x)); + }); + button.style.display = 'inline-block'; + } + + ui.status('Standing by.'); } setupSentiment(); diff --git a/sentiment/loader.js b/sentiment/loader.js index 968bdffc5..263729b23 100644 --- a/sentiment/loader.js +++ b/sentiment/loader.js @@ -18,6 +18,19 @@ import * as tf from '@tensorflow/tfjs'; import * as ui from './ui'; +/** + * Test whether a given URL is retrievable. 
+ */ +export async function urlExists(url) { + ui.status('Testing url ' + url); + try { + const response = await fetch(url, {method: 'HEAD'}); + return response.ok; + } catch (err) { + return false; + } +} + /** * Load pretrained model stored at a remote URL. * @@ -28,9 +41,13 @@ export async function loadHostedPretrainedModel(url) { try { const model = await tf.loadModel(url); ui.status('Done loading pretrained model.'); + // We can't load a model twice due to + // https://github.com/tensorflow/tfjs/issues/34 + // Therefore we remove the load buttons to avoid user confusion. + ui.disableLoadModelButtons(); return model; } catch (err) { - console.log(err); + console.error(err); ui.status('Loading pretrained model failed.'); } } @@ -48,7 +65,7 @@ export async function loadHostedMetadata(url) { ui.status('Done loading metadata.'); return metadata; } catch (err) { - console.log(err); + console.error(err); ui.status('Loading metadata failed.'); } } diff --git a/sentiment/package.json b/sentiment/package.json index c3fd6a4e7..01eb77ae4 100644 --- a/sentiment/package.json +++ b/sentiment/package.json @@ -13,7 +13,7 @@ "vega-embed": "~3.0.0" }, "scripts": { - "watch": "NODE_ENV=development parcel --no-hmr --open index.html ", + "watch": "./serve.sh", "build": "NODE_ENV=production parcel build index.html --no-minify --public-url ./" }, "devDependencies": { @@ -21,7 +21,8 @@ "babel-polyfill": "~6.26.0", "babel-preset-env": "~1.6.1", "clang-format": "~1.2.2", - "parcel-bundler": "~1.6.2" + "http-server": "~0.10.0", + "parcel-bundler": "~1.7.0" }, "babel": { "presets": [ diff --git a/sentiment/python/__init__.py b/sentiment/python/__init__.py new file mode 100644 index 000000000..636b70f0d --- /dev/null +++ b/sentiment/python/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2018 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/sentiment/python/imdb.py b/sentiment/python/imdb.py new file mode 100644 index 000000000..2abe3a4e3 --- /dev/null +++ b/sentiment/python/imdb.py @@ -0,0 +1,258 @@ +# Copyright 2018 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""IMDB sentiment classification example. 
+
+Based on Python Keras examples:
+  https://github.com/keras-team/keras/blob/master/examples/imdb_cnn.py
+  https://github.com/keras-team/keras/blob/master/examples/imdb_lstm.py
+
+TODO(cais): Add
+  https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py
+  once b/74429960 is fixed.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import json
+import os
+
+import keras
+import tensorflowjs as tfjs
+
+
+INDEX_FROM = 3
+# Offset in word index. Used during word index lookup and reverse lookup.
+
+
+def get_word_index(reverse=False):
+  """Get word index.
+
+  Args:
+    reverse: Reverse the index, so that the returned index is from index values
+      to words.
+
+  Returns:
+    The word index as a `dict`.
+  """
+  word_index = keras.datasets.imdb.get_word_index()
+  if reverse:
+    word_index = dict((word_index[key], key) for key in word_index)
+  return word_index
+
+
+def indices_to_words(reverse_index, indices):
+  """Convert an iterable of word indices into words.
+
+  Args:
+    reverse_index: A `dict` mapping word index (as `int`) to word (as `str`).
+    indices: An iterable of word indices.
+
+  Returns:
+    Mapped words as a `list` of `str`s.
+  """
+  return [reverse_index[i - INDEX_FROM] if i >= INDEX_FROM else 'OOV'
+          for i in indices]
+
+
+def get_imdb_data(vocabulary_size, max_len):
+  """Get IMDB data for training and validation.
+
+  Args:
+    vocabulary_size: Size of the vocabulary, as an `int`.
+    max_len: Cut text after this number of words.
+
+  Returns:
+    x_train: An int array of shape `(num_examples, max_len)`: index-encoded
+      sentences.
+    y_train: An int array of shape `(num_examples,)`: labels for the sentences.
+    x_test: Same as `x_train`, but for test.
+    y_test: Same as `y_train`, but for test.
+  """
+  print("Getting IMDB data with vocabulary_size %d" % vocabulary_size)
+  (x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(
+      num_words=vocabulary_size)
+  x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=max_len)
+  x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=max_len)
+  return x_train, y_train, x_test, y_test
+
+
+def train_model(model_type,
+                vocabulary_size,
+                embedding_size,
+                x_train,
+                y_train,
+                x_test,
+                y_test,
+                epochs,
+                batch_size):
+  """Train a model for IMDB sentiment classification.
+
+  Args:
+    model_type: Type of the model to train, as a `str`.
+    vocabulary_size: Vocabulary size.
+    embedding_size: Embedding dimensions.
+    x_train: An int array of shape `(num_examples, max_len)`: index-encoded
+      sentences.
+    y_train: An int array of shape `(num_examples,)`: labels for the sentences.
+    x_test: Same as `x_train`, but for test.
+    y_test: Same as `y_train`, but for test.
+    epochs: Number of epochs to train the model for.
+    batch_size: Batch size to use during training.
+
+  Returns:
+    The trained model instance.
+
+  Raises:
+    ValueError: on invalid model type.
+  """
+
+  model = keras.Sequential()
+  model.add(keras.layers.Embedding(vocabulary_size, embedding_size))
+  if model_type == 'bidirectional_lstm':
+    # TODO(cais): Uncomment the following once bug b/74429960 is fixed.
+    # model.add(keras.layers.Embedding(
+    #     vocabulary_size, 128, input_length=maxlen))
+    # model.add(keras.layers.Bidirectional(
+    #     keras.layers.LSTM(64,
+    #                       kernel_initializer='glorot_normal',
+    #                       recurrent_initializer='glorot_normal')))
+    # model.add(keras.layers.Dropout(0.5))
+    raise NotImplementedError()
+  elif model_type == 'cnn':
+    model.add(keras.layers.Dropout(0.2))
+    model.add(keras.layers.Conv1D(250,
+                                  3,
+                                  padding='valid',
+                                  activation='relu',
+                                  strides=1))
+    model.add(keras.layers.GlobalMaxPooling1D())
+    model.add(keras.layers.Dense(250, activation='relu'))
+  elif model_type == 'lstm':
+    model.add(
+        keras.layers.LSTM(
+            128,
+            kernel_initializer='glorot_normal',
+            recurrent_initializer='glorot_normal'))
+    # TODO(cais): Remove glorot_normal and use the default orthogonal once
+    # SVD is available.
+  else:
+    raise ValueError("Invalid model type: '%s'" % model_type)
+  model.add(keras.layers.Dense(1, activation='sigmoid'))
+
+  model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
+  model.fit(x_train, y_train,
+            batch_size=batch_size,
+            epochs=epochs,
+            validation_data=(x_test, y_test))
+  return model
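As a quick smoke test (not part of this patch), train_model above can be exercised on tiny random data, much as the accompanying imdb_test.py does; the shapes and hyperparameters below are illustrative only:

    import numpy as np

    # Ten index-encoded 'sentences' of length 20, with binary labels.
    x = np.random.randint(0, 100, (10, 20))
    y = np.random.randint(0, 2, (10,))
    model = train_model('cnn', vocabulary_size=100, embedding_size=8,
                        x_train=x, y_train=y, x_test=x, y_test=y,
                        epochs=1, batch_size=5)
    model.summary()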
+
+
+def main():
+  x_train, y_train, x_test, y_test = (
+      get_imdb_data(FLAGS.vocabulary_size, FLAGS.max_len))
+
+  model = train_model(FLAGS.model_type,
+                      FLAGS.vocabulary_size,
+                      FLAGS.embedding_size,
+                      x_train,
+                      y_train,
+                      x_test,
+                      y_test,
+                      FLAGS.epochs,
+                      FLAGS.batch_size)
+
+  # Display a number of test phrases and their final classification.
+  forward_index = get_word_index()
+  reverse_index = get_word_index(reverse=True)
+  print('\n')
+  for i in range(FLAGS.num_show):
+    print('--- Test Case %d ---' % (i + 1))
+    print('Sentence: "' +
+          ' '.join(indices_to_words(reverse_index, x_test[i, :])) + '"')
+    print('Truth: %d' % y_test[i])
+    print('Prediction: %s\n' % model.predict(x_test[i : i + 1, :])[0][0])
+
+  # Save the metadata, including the word index, INDEX_FROM, max_len, and the
+  # model hyperparameters.
+  metadata = {
+      'word_index': forward_index,
+      'index_from': INDEX_FROM,
+      'max_len': FLAGS.max_len,
+      'model_type': FLAGS.model_type,
+      'vocabulary_size': FLAGS.vocabulary_size,
+      'embedding_size': FLAGS.embedding_size,
+      'epochs': FLAGS.epochs,
+      'batch_size': FLAGS.batch_size,
+  }
+
+  if not os.path.isdir(FLAGS.artifacts_dir):
+    os.makedirs(FLAGS.artifacts_dir)
+  metadata_json_path = os.path.join(FLAGS.artifacts_dir, 'metadata.json')
+  with open(metadata_json_path, 'wt') as f:
+    json.dump(metadata, f)
+  print('\nSaved model metadata at: %s' % metadata_json_path)
+
+  tfjs.converters.save_keras_model(model, FLAGS.artifacts_dir)
+  print('\nSaved model artifacts in directory: %s' % FLAGS.artifacts_dir)
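The metadata that main() saves above is what the browser demo loads at startup. Here is a sketch (not part of this patch) of consuming it in Python to encode a raw sentence the same way the demo's index.js does; the metadata path is hypothetical:

    import json

    import keras

    with open('/tmp/imdb.keras/metadata.json', 'rt') as f:
      metadata = json.load(f)

    words = 'this movie was great'.split()
    # Look up each in-vocabulary word, apply the index_from offset, then
    # pad/truncate to the max_len the model was trained with.
    indices = [[metadata['word_index'][w] + metadata['index_from']
                for w in words if w in metadata['word_index']]]
    x = keras.preprocessing.sequence.pad_sequences(
        indices, maxlen=metadata['max_len'])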
+
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser('IMDB sentiment classification model')
+  parser.add_argument(
+      'model_type',
+      type=str,
+      help='Type of model to train for the IMDB sentiment classification '
+      'task: (cnn | lstm)')
+  parser.add_argument(
+      '--vocabulary_size',
+      type=int,
+      default=20000,
+      help='Vocabulary size.')
+  parser.add_argument(
+      '--embedding_size',
+      type=int,
+      default=128,
+      help='Embedding size.')
+  parser.add_argument(
+      '--max_len',
+      type=int,
+      default=100,
+      help='Cut text after this number of words.')
+  parser.add_argument(
+      '--epochs',
+      type=int,
+      default=5,
+      help='Number of epochs to train the model for.')
+  parser.add_argument(
+      '--batch_size',
+      type=int,
+      default=32,
+      help='Batch size used during training.')
+  parser.add_argument(
+      '--num_show',
+      type=int,
+      default=5,
+      help='Number of sentences to show prediction score on after training.')
+  parser.add_argument(
+      '--artifacts_dir',
+      type=str,
+      default='/tmp/imdb.keras',
+      help='Local path for saving the TensorFlow.js artifacts.')
+
+  FLAGS, _ = parser.parse_known_args()
+  main()
diff --git a/sentiment/python/imdb_test.py b/sentiment/python/imdb_test.py
new file mode 100644
index 000000000..003acf95d
--- /dev/null
+++ b/sentiment/python/imdb_test.py
@@ -0,0 +1,96 @@
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Test for the IMDB model and supporting functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+import keras
+
+from . import imdb
+
+
+class IMDBTest(unittest.TestCase):
+
+  def setUp(self):
+    self._tmp_dir = tempfile.mkdtemp()
+    super(IMDBTest, self).setUp()
+
+  def tearDown(self):
+    if os.path.isdir(self._tmp_dir):
+      shutil.rmtree(self._tmp_dir)
+    super(IMDBTest, self).tearDown()
+
+  def testGetWordIndexForward(self):
+    word_index = imdb.get_word_index()
+    self.assertGreater(word_index['bar'], 0)
+
+  def testIndicesToWordsReverse(self):
+    forward_index = imdb.get_word_index()
+    reverse_index = imdb.get_word_index(reverse=True)
+    self.assertEqual('bar', reverse_index[forward_index['bar']])
+
+  def testIndicesToWords(self):
+    forward_index = imdb.get_word_index()
+    reverse_index = imdb.get_word_index(reverse=True)
+    indices = [forward_index[word] + imdb.INDEX_FROM
+               for word in ['one', 'two', 'three']]
+    self.assertEqual(['one', 'two', 'three'],
+                     imdb.indices_to_words(reverse_index, indices))
+
+  def testTrainLSTMModel(self):
+    data_size = 10
+    x_train = np.random.randint(0, 100, (data_size,))
+    y_train = np.random.randint(0, 2, (data_size,))
+    x_test = np.random.randint(0, 100, (data_size,))
+    y_test = np.random.randint(0, 2, (data_size,))
+
+    vocabulary_size = 100
+    embedding_size = 32
+    epochs = 1
+    batch_size = data_size
+    model = imdb.train_model(
+        'lstm', vocabulary_size, embedding_size,
+        x_train, y_train, x_test, y_test,
+        epochs, batch_size)
+    self.assertTrue(model.layers)
+
+  def testTrainModelWithInvalidModelTypeRaisesError(self):
+    data_size = 10
+    x_train = np.random.randint(0, 100, (data_size,))
+    y_train = np.random.randint(0, 2, (data_size,))
+    x_test = np.random.randint(0, 100, (data_size,))
+    y_test = np.random.randint(0, 2, (data_size,))
+
+    vocabulary_size = 100
+    embedding_size = 32
+    epochs = 1
+    batch_size = data_size
+    with self.assertRaises(ValueError):
+      imdb.train_model(
+          'nonsensical_model_type', vocabulary_size, embedding_size,
+          x_train, y_train, x_test, y_test,
+          epochs, batch_size)
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sentiment/serve.sh b/sentiment/serve.sh
new file mode 100755
index 000000000..e5ede63ed
--- /dev/null
+++ b/sentiment/serve.sh
@@ -0,0 +1,42 @@
+#!/usr/bin/env bash
+
+# Copyright 2018 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+# This script starts two HTTP servers on different ports:
+# * Port 1234 (using parcel) serves HTML and JavaScript.
+# * Port 1235 (using http-server) serves pretrained model resources.
+#
+# The reason for this arrangement is that Parcel currently has a limitation that
+# prevents it from serving the pretrained models; see
+# https://github.com/parcel-bundler/parcel/issues/1098. Once that issue is
+# resolved, a single Parcel server will be sufficient.
+
+NODE_ENV=development
+RESOURCE_PORT=1235
+
+# Ensure that http-server is available.
+yarn
+
+echo Starting the pretrained model server...
+node_modules/http-server/bin/http-server dist --cors -p "${RESOURCE_PORT}" > /dev/null & HTTP_SERVER_PID=$! + +echo Starting the example html/js server... +# This uses port 1234 by default. +node_modules/parcel-bundler/bin/cli.js serve -d dist --open --no-hmr --public-url / index.html + +# When the Parcel server exits, kill the http-server too. +kill $HTTP_SERVER_PID + diff --git a/sentiment/ui.js b/sentiment/ui.js index d037cc57c..52158e29d 100644 --- a/sentiment/ui.js +++ b/sentiment/ui.js @@ -23,6 +23,7 @@ const exampleReviews = { }; export function status(statusText) { + console.log(statusText); document.getElementById('status').textContent = statusText; } @@ -36,6 +37,7 @@ export function showMetadata(sentimentMetadataJSON) { } export function prepUI(predict) { + setPredictFunction(predict); const testExampleSelect = document.getElementById('test-example-select'); testExampleSelect.addEventListener('change', () => { setReviewText(exampleReviews[testExampleSelect.value], predict); @@ -62,7 +64,12 @@ function setReviewText(text, predict) { doPredict(predict); } -export function setPredictFunction(predict) { +function setPredictFunction(predict) { const reviewText = document.getElementById('review-text'); reviewText.addEventListener('input', () => doPredict(predict)); } + +export function disableLoadModelButtons() { + document.getElementById('load-pretrained-remote').style.display = 'none'; + document.getElementById('load-pretrained-local').style.display = 'none'; +} diff --git a/sentiment/yarn.lock b/sentiment/yarn.lock index 2e61d4f3d..87bfdaf3c 100644 --- a/sentiment/yarn.lock +++ b/sentiment/yarn.lock @@ -705,6 +705,10 @@ binary-extensions@^1.0.0: version "1.11.0" resolved "https://registry.yarnpkg.com/binary-extensions/-/binary-extensions-1.11.0.tgz#46aa1751fb6a2f93ee5e689bb1087d4b14c6c205" +bindings@~1.2.1: + version "1.2.1" + resolved "https://registry.yarnpkg.com/bindings/-/bindings-1.2.1.tgz#14ad6113812d2d37d72e67b4cacb4bb726505f11" + block-stream@*: version "0.0.9" resolved "https://registry.yarnpkg.com/block-stream/-/block-stream-0.0.9.tgz#13ebfe778a03205cfe03751481ebb4b3300c126a" @@ -766,12 +770,6 @@ brorand@^1.0.1: version "1.1.0" resolved "https://registry.yarnpkg.com/brorand/-/brorand-1.1.0.tgz#12c25efe40a45e3c323eb8675a0a0ce57b22371f" -browser-resolve@^1.11.2: - version "1.11.2" - resolved "https://registry.yarnpkg.com/browser-resolve/-/browser-resolve-1.11.2.tgz#8ff09b0a2c421718a1051c260b32e48f442938ce" - dependencies: - resolve "1.1.7" - browserify-aes@^1.0.0, browserify-aes@^1.0.4: version "1.1.1" resolved "https://registry.yarnpkg.com/browserify-aes/-/browserify-aes-1.1.1.tgz#38b7ab55edb806ff2dcda1a7f1620773a477c49f" @@ -1066,6 +1064,10 @@ colormin@^1.0.5: css-color-names "0.0.4" has "^1.0.1" +colors@1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/colors/-/colors-1.0.3.tgz#0433f44d809680fdeb60ed260f1b0c262e82a40b" + colors@~1.1.2: version "1.1.2" resolved "https://registry.yarnpkg.com/colors/-/colors-1.1.2.tgz#168a4701756b6a7f51a12ce0c97bfa28c084ed63" @@ -1145,6 +1147,10 @@ core-util-is@1.0.2, core-util-is@~1.0.0: version "1.0.2" resolved "https://registry.yarnpkg.com/core-util-is/-/core-util-is-1.0.2.tgz#b5fd54220aa2bc5ab57aab7140c940754503c1a7" +corser@~2.0.0: + version "2.0.1" + resolved "https://registry.yarnpkg.com/corser/-/corser-2.0.1.tgz#8eda252ecaab5840dcd975ceb90d9370c819ff87" + create-ecdh@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/create-ecdh/-/create-ecdh-4.0.0.tgz#888c723596cdf7612f6498233eebd7a35301737d" @@ -1433,6 
+1439,13 @@ date-now@^0.1.4: version "0.1.4" resolved "https://registry.yarnpkg.com/date-now/-/date-now-0.1.4.tgz#eaf439fd4d4848ad74e5cc7dbef200672b9e345b" +deasync@^0.1.12: + version "0.1.12" + resolved "https://registry.yarnpkg.com/deasync/-/deasync-0.1.12.tgz#0159492a4133ab301d6c778cf01e74e63b10e549" + dependencies: + bindings "~1.2.1" + nan "^2.0.7" + debug@2.6.9, debug@^2.2.0, debug@^2.3.3, debug@^2.6.8: version "2.6.9" resolved "https://registry.yarnpkg.com/debug/-/debug-2.6.9.tgz#5d128515df134ff327e90a4c93f4e077a536341f" @@ -1585,6 +1598,15 @@ ecc-jsbn@~0.1.1: dependencies: jsbn "~0.1.0" +ecstatic@^2.0.0: + version "2.2.1" + resolved "https://registry.yarnpkg.com/ecstatic/-/ecstatic-2.2.1.tgz#b5087fad439dd9dd49d31e18131454817fe87769" + dependencies: + he "^1.1.1" + mime "^1.2.11" + minimist "^1.1.0" + url-join "^2.0.2" + editorconfig@^0.13.2: version "0.13.3" resolved "https://registry.yarnpkg.com/editorconfig/-/editorconfig-0.13.3.tgz#e5219e587951d60958fd94ea9a9a008cdeff1b34" @@ -1696,6 +1718,10 @@ etag@~1.8.1: version "1.8.1" resolved "https://registry.yarnpkg.com/etag/-/etag-1.8.1.tgz#41ae2eeb65efa62268aebfea83ac7d79299b0887" +eventemitter3@1.x.x: + version "1.2.0" + resolved "https://registry.yarnpkg.com/eventemitter3/-/eventemitter3-1.2.0.tgz#1c86991d816ad1e504750e73874224ecf3bec508" + events@^1.0.0: version "1.1.1" resolved "https://registry.yarnpkg.com/events/-/events-1.1.1.tgz#9ebdb7635ad099c70dcc4c2a1f5004288e8bd924" @@ -2033,6 +2059,10 @@ hawk@3.1.3, hawk@~3.1.3: hoek "2.x.x" sntp "1.x.x" +he@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/he/-/he-1.1.1.tgz#93410fd21b009735151f8868c2f271f3427e23fd" + hmac-drbg@^1.0.0: version "1.0.1" resolved "https://registry.yarnpkg.com/hmac-drbg/-/hmac-drbg-1.0.1.tgz#d2745701025a6c775a6c545793ed502fc0c649a1" @@ -2060,7 +2090,7 @@ html-comment-regex@^1.1.0: version "1.1.1" resolved "https://registry.yarnpkg.com/html-comment-regex/-/html-comment-regex-1.1.1.tgz#668b93776eaae55ebde8f3ad464b307a4963625e" -htmlnano@^0.1.6: +htmlnano@^0.1.7: version "0.1.7" resolved "https://registry.yarnpkg.com/htmlnano/-/htmlnano-0.1.7.tgz#1751937a05f122a3248dba1c63edf01f6d36cb84" dependencies: @@ -2091,6 +2121,26 @@ http-errors@~1.6.2: setprototypeof "1.0.3" statuses ">= 1.3.1 < 2" +http-proxy@^1.8.1: + version "1.16.2" + resolved "https://registry.yarnpkg.com/http-proxy/-/http-proxy-1.16.2.tgz#06dff292952bf64dbe8471fa9df73066d4f37742" + dependencies: + eventemitter3 "1.x.x" + requires-port "1.x.x" + +http-server@~0.10.0: + version "0.10.0" + resolved "https://registry.yarnpkg.com/http-server/-/http-server-0.10.0.tgz#b2a446b16a9db87ed3c622ba9beb1b085b1234a7" + dependencies: + colors "1.0.3" + corser "~2.0.0" + ecstatic "^2.0.0" + http-proxy "^1.8.1" + opener "~1.4.0" + optimist "0.6.x" + portfinder "^1.0.13" + union "~0.4.3" + http-signature@~1.1.0: version "1.1.1" resolved "https://registry.yarnpkg.com/http-signature/-/http-signature-1.1.1.tgz#df72e267066cd0ac67fb76adf8e134a8fbcf91bf" @@ -2610,6 +2660,10 @@ mime@1.4.1: version "1.4.1" resolved "https://registry.yarnpkg.com/mime/-/mime-1.4.1.tgz#121f9ebc49e3766f311a76e1fa1c8003c4b03aa6" +mime@^1.2.11: + version "1.6.0" + resolved "https://registry.yarnpkg.com/mime/-/mime-1.6.0.tgz#32cd9e5c64553bd58d19a568af452acff04981b1" + mimic-fn@^1.0.0: version "1.2.0" resolved "https://registry.yarnpkg.com/mimic-fn/-/mimic-fn-1.2.0.tgz#820c86a39334640e99516928bd03fca88057d022" @@ -2632,10 +2686,14 @@ minimist@0.0.8: version "0.0.8" resolved 
"https://registry.yarnpkg.com/minimist/-/minimist-0.0.8.tgz#857fcabfc3397d2625b8228262e86aa7a011b05d" -minimist@^1.1.3, minimist@^1.2.0: +minimist@^1.1.0, minimist@^1.1.3, minimist@^1.2.0: version "1.2.0" resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.0.tgz#a35008b20f41383eec1fb914f4cd5df79a264284" +minimist@~0.0.1: + version "0.0.10" + resolved "https://registry.yarnpkg.com/minimist/-/minimist-0.0.10.tgz#de3f98543dbf96082be48ad1a0c7cda836301dcf" + mixin-deep@^1.2.0: version "1.3.1" resolved "https://registry.yarnpkg.com/mixin-deep/-/mixin-deep-1.3.1.tgz#a49e7268dce1a0d9698e45326c5626df3543d0fe" @@ -2643,7 +2701,7 @@ mixin-deep@^1.2.0: for-in "^1.0.2" is-extendable "^1.0.1" -"mkdirp@>=0.5 0", mkdirp@^0.5.1, mkdirp@~0.5.0, mkdirp@~0.5.1: +mkdirp@0.5.x, "mkdirp@>=0.5 0", mkdirp@^0.5.1, mkdirp@~0.5.0, mkdirp@~0.5.1: version "0.5.1" resolved "https://registry.yarnpkg.com/mkdirp/-/mkdirp-0.5.1.tgz#30057438eac6cf7f8c4767f38648d6697d75c903" dependencies: @@ -2653,7 +2711,7 @@ ms@2.0.0: version "2.0.0" resolved "https://registry.yarnpkg.com/ms/-/ms-2.0.0.tgz#5608aeadfc00be6c2901df5f9861788de0d597c8" -nan@^2.3.0, nan@^2.4.0: +nan@^2.0.7, nan@^2.3.0, nan@^2.4.0: version "2.10.0" resolved "https://registry.yarnpkg.com/nan/-/nan-2.10.0.tgz#96d0cd610ebd58d4b4de9cc0c6828cda99c7548f" @@ -2860,12 +2918,23 @@ once@^1.3.0, once@^1.3.3: dependencies: wrappy "1" +opener@~1.4.0: + version "1.4.3" + resolved "https://registry.yarnpkg.com/opener/-/opener-1.4.3.tgz#5c6da2c5d7e5831e8ffa3964950f8d6674ac90b8" + opn@^5.1.0: version "5.3.0" resolved "https://registry.yarnpkg.com/opn/-/opn-5.3.0.tgz#64871565c863875f052cfdf53d3e3cb5adb53b1c" dependencies: is-wsl "^1.1.0" +optimist@0.6.x: + version "0.6.1" + resolved "https://registry.yarnpkg.com/optimist/-/optimist-0.6.1.tgz#da3ea74686fa21a19a111c326e90eb15a0196686" + dependencies: + minimist "~0.0.1" + wordwrap "~0.0.2" + optionator@^0.8.1: version "0.8.2" resolved "https://registry.yarnpkg.com/optionator/-/optionator-0.8.2.tgz#364c5e409d3f4d6301d6c0b4c05bba50180aeb64" @@ -2938,9 +3007,9 @@ pako@~1.0.5: version "1.0.6" resolved "https://registry.yarnpkg.com/pako/-/pako-1.0.6.tgz#0101211baa70c4bca4a0f63f2206e97b7dfaf258" -parcel-bundler@~1.6.2: - version "1.6.2" - resolved "https://registry.yarnpkg.com/parcel-bundler/-/parcel-bundler-1.6.2.tgz#e415c4993b6f4c48cd410427594ce80bbb7d90f1" +parcel-bundler@~1.7.0: + version "1.7.0" + resolved "https://registry.yarnpkg.com/parcel-bundler/-/parcel-bundler-1.7.0.tgz#8c2512615fd602d2f39bd97bfd128f8fe524b321" dependencies: babel-code-frame "^6.26.0" babel-core "^6.25.0" @@ -2953,7 +3022,6 @@ parcel-bundler@~1.6.2: babel-types "^6.26.0" babylon "^6.17.4" babylon-walk "^1.0.2" - browser-resolve "^1.11.2" browserslist "^2.11.2" chalk "^2.1.0" chokidar "^2.0.1" @@ -2961,12 +3029,13 @@ parcel-bundler@~1.6.2: commander "^2.11.0" cross-spawn "^6.0.4" cssnano "^3.10.0" + deasync "^0.1.12" dotenv "^5.0.0" filesize "^3.6.0" get-port "^3.2.0" glob "^7.1.2" grapheme-breaker "^0.3.2" - htmlnano "^0.1.6" + htmlnano "^0.1.7" is-url "^1.2.2" js-yaml "^3.10.0" json5 "^0.5.1" @@ -2976,13 +3045,12 @@ parcel-bundler@~1.6.2: node-libs-browser "^2.0.0" opn "^5.1.0" physical-cpu-count "^2.0.0" - postcss "^6.0.10" + postcss "^6.0.19" postcss-value-parser "^3.3.0" posthtml "^0.11.2" posthtml-parser "^0.4.0" posthtml-render "^1.1.0" resolve "^1.4.0" - sanitize-filename "^1.6.1" semver "^5.4.1" serialize-to-js "^1.1.1" serve-static "^1.12.4" @@ -3089,6 +3157,14 @@ pinkie@^2.0.0: version "2.0.4" resolved 
"https://registry.yarnpkg.com/pinkie/-/pinkie-2.0.4.tgz#72556b80cfa0d48a974e80e77248e80ed4f7f870" +portfinder@^1.0.13: + version "1.0.13" + resolved "https://registry.yarnpkg.com/portfinder/-/portfinder-1.0.13.tgz#bb32ecd87c27104ae6ee44b5a3ccbf0ebb1aede9" + dependencies: + async "^1.5.2" + debug "^2.2.0" + mkdirp "0.5.x" + posix-character-classes@^0.1.0: version "0.1.1" resolved "https://registry.yarnpkg.com/posix-character-classes/-/posix-character-classes-0.1.1.tgz#01eac0fe3b5af71a2a6c02feabb8c1fef7e00eab" @@ -3304,9 +3380,9 @@ postcss@^5.0.10, postcss@^5.0.11, postcss@^5.0.12, postcss@^5.0.13, postcss@^5.0 source-map "^0.5.6" supports-color "^3.2.3" -postcss@^6.0.10: - version "6.0.20" - resolved "https://registry.yarnpkg.com/postcss/-/postcss-6.0.20.tgz#686107e743a12d5530cb68438c590d5b2bf72c3c" +postcss@^6.0.19: + version "6.0.21" + resolved "https://registry.yarnpkg.com/postcss/-/postcss-6.0.21.tgz#8265662694eddf9e9a5960db6da33c39e4cd069d" dependencies: chalk "^2.3.2" source-map "^0.6.1" @@ -3397,6 +3473,10 @@ q@^1.1.2: version "1.5.1" resolved "https://registry.yarnpkg.com/q/-/q-1.5.1.tgz#7e32f75b41381291d04611f1bf14109ac00651d7" +qs@~2.3.3: + version "2.3.3" + resolved "https://registry.yarnpkg.com/qs/-/qs-2.3.3.tgz#e9e85adbe75da0bbe4c8e0476a086290f863b404" + qs@~6.4.0: version "6.4.0" resolved "https://registry.yarnpkg.com/qs/-/qs-6.4.0.tgz#13e26d28ad6b0ffaa91312cd3bf708ed351e7233" @@ -3598,14 +3678,14 @@ require-main-filename@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/require-main-filename/-/require-main-filename-1.0.1.tgz#97f717b69d48784f5f526a6c5aa8ffdda055a4d1" +requires-port@1.x.x: + version "1.0.0" + resolved "https://registry.yarnpkg.com/requires-port/-/requires-port-1.0.0.tgz#925d2601d39ac485e091cf0da5c6e694dc3dcaff" + resolve-url@^0.2.1: version "0.2.1" resolved "https://registry.yarnpkg.com/resolve-url/-/resolve-url-0.2.1.tgz#2c637fe77c893afd2a663fe21aa9080068e2052a" -resolve@1.1.7: - version "1.1.7" - resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.1.7.tgz#203114d82ad2c5ed9e8e0411b3932875e889e97b" - resolve@^1.1.5, resolve@^1.1.6, resolve@^1.4.0: version "1.6.0" resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.6.0.tgz#0fbd21278b27b4004481c395349e7aba60a9ff5c" @@ -3649,12 +3729,6 @@ safer-eval@^1.2.3: dependencies: clones "^1.1.0" -sanitize-filename@^1.6.1: - version "1.6.1" - resolved "https://registry.yarnpkg.com/sanitize-filename/-/sanitize-filename-1.6.1.tgz#612da1c96473fa02dccda92dcd5b4ab164a6772a" - dependencies: - truncate-utf8-bytes "^1.0.0" - sax@~1.2.1, sax@~1.2.4: version "1.2.4" resolved "https://registry.yarnpkg.com/sax/-/sax-1.2.4.tgz#2816234e2378bddc4e5354fab5caa895df7100d9" @@ -4138,12 +4212,6 @@ trim-right@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/trim-right/-/trim-right-1.0.1.tgz#cb2e1203067e0c8de1f614094b9fe45704ea6003" -truncate-utf8-bytes@^1.0.0: - version "1.0.2" - resolved "https://registry.yarnpkg.com/truncate-utf8-bytes/-/truncate-utf8-bytes-1.0.2.tgz#405923909592d56f78a5818434b0b78489ca5f2b" - dependencies: - utf8-byte-length "^1.0.1" - tslib@^1.9.0: version "1.9.0" resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.9.0.tgz#e37a86fda8cbbaf23a057f473c9f4dc64e5fc2e8" @@ -4210,6 +4278,12 @@ union-value@^1.0.0: is-extendable "^0.1.1" set-value "^0.4.3" +union@~0.4.3: + version "0.4.6" + resolved "https://registry.yarnpkg.com/union/-/union-0.4.6.tgz#198fbdaeba254e788b0efcb630bc11f24a2959e0" + dependencies: + qs "~2.3.3" + uniq@^1.0.1: version "1.0.1" resolved 
"https://registry.yarnpkg.com/uniq/-/uniq-1.0.1.tgz#b31c5ae8254844a3a8281541ce2b04b865a734ff" @@ -4243,6 +4317,10 @@ urix@^0.1.0: version "0.1.0" resolved "https://registry.yarnpkg.com/urix/-/urix-0.1.0.tgz#da937f7a62e21fec1fd18d49b35c2935067a6c72" +url-join@^2.0.2: + version "2.0.5" + resolved "https://registry.yarnpkg.com/url-join/-/url-join-2.0.5.tgz#5af22f18c052a000a48d7b82c5e9c2e2feeda728" + url@^0.11.0: version "0.11.0" resolved "https://registry.yarnpkg.com/url/-/url-0.11.0.tgz#3838e97cfc60521eb73c525a8e55bfdd9e2e28f1" @@ -4256,10 +4334,6 @@ use@^3.1.0: dependencies: kind-of "^6.0.2" -utf8-byte-length@^1.0.1: - version "1.0.4" - resolved "https://registry.yarnpkg.com/utf8-byte-length/-/utf8-byte-length-1.0.4.tgz#f45f150c4c66eee968186505ab93fcbb8ad6bf61" - util-deprecate@~1.0.1: version "1.0.2" resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" @@ -4589,6 +4663,10 @@ window-size@^0.2.0: version "0.2.0" resolved "https://registry.yarnpkg.com/window-size/-/window-size-0.2.0.tgz#b4315bb4214a3d7058ebeee892e13fa24d98b075" +wordwrap@~0.0.2: + version "0.0.3" + resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-0.0.3.tgz#a3d5da6cd5c0bc0008d37234bbaf1bed63059107" + wordwrap@~1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-1.0.0.tgz#27584810891456a4171c8d0226441ade90cbcaeb" diff --git a/translation/build-resources.sh b/translation/build-resources.sh new file mode 100755 index 000000000..162a363b4 --- /dev/null +++ b/translation/build-resources.sh @@ -0,0 +1,84 @@ +#!/usr/bin/env bash + +# Copyright 2018 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +# Builds resources for the sequence-to-sequence English-French translation demo. +# Note this is not necessary to run the demo, because we already provide hosted +# pre-built resources. +# Usage example: do this in the 'translation' directory: +# ./build.sh ~/ml-data/fra-eng/fra.txt +# +# You can specify the number of training epochs by using the --epochs flag. +# For example: +# ./build-resources.sh ~/ml-data/fra-eng/fra.txt --epochs 10 + +set -e + +DEMO_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +TRAIN_DATA_PATH="$1" +if [[ -z "${TRAIN_DATA_PATH}" ]]; then + echo "ERROR: TRAIN_DATA_PATH is not specified." + echo "You can download the training data with a command such as:" + echo " wget http://www.manythings.org/anki/fra-eng.zip" + exit 1 +fi +shift 1 + +if [[ ! 
-f ${TRAIN_DATA_PATH} ]]; then + echo "ERROR: Cannot find training data at path '${TRAIN_DATA_PATH}'" + exit 1 +fi + +TRAIN_EPOCHS=100 +while true; do + if [[ "$1" == "--epochs" ]]; then + TRAIN_EPOCHS=$2 + shift 2 + elif [[ -z "$1" ]]; then + break + else + echo "ERROR: Unrecognized argument: $1" + exit 1 + fi +done + +RESOURCES_ROOT="${DEMO_DIR}/dist/resources" +rm -rf "${RESOURCES_ROOT}" +mkdir -p "${RESOURCES_ROOT}" + +# Run Python script to generate the pretrained model and weights files. +# Make sure you install the tensorflowjs pip package first. + +python "${DEMO_DIR}/python/translation.py" \ + "${TRAIN_DATA_PATH}" \ + --recurrent_initializer glorot_uniform \ + --artifacts_dir "${RESOURCES_ROOT}" \ + --epochs "${TRAIN_EPOCHS}" +# TODO(cais): This --recurrent_initializer is a workaround for the limitation +# in TensorFlow.js Layers that the default recurrent initializer "Orthogonal" is +# currently not supported. Remove this once "Orthogonal" becomes available. + +cd ${DEMO_DIR} +yarn +yarn build + +echo +echo "-----------------------------------------------------------" +echo "Resources written to ${RESOURCES_ROOT}." +echo "You can now run the demo with 'yarn watch'." +echo "-----------------------------------------------------------" +echo diff --git a/translation/index.html b/translation/index.html index 2e0cb35f5..e42756c8e 100644 --- a/translation/index.html +++ b/translation/index.html @@ -36,6 +36,10 @@

TensorFlow.js Layers: Sequence-to-Sequence (English-French Translation) Demo

+<div>
+  <button id="load-pretrained-remote" style="display:none">Load hosted pretrained model</button>
+  <button id="load-pretrained-local" style="display:none">Load locally hosted pretrained model</button>
+</div>
Standing by.
diff --git a/translation/index.js b/translation/index.js index 0cf4c4a1d..810305259 100644 --- a/translation/index.js +++ b/translation/index.js @@ -19,18 +19,26 @@ import * as tf from '@tensorflow/tfjs'; import * as loader from './loader'; import * as ui from './ui'; -const HOSTED_MODEL_JSON_URL = - 'https://storage.googleapis.com/tfjs-models/tfjs/translation_en_fr_v1/model.json'; -const HOSTED_METADATA_JSON_URL = - 'https://storage.googleapis.com/tfjs-models/tfjs/translation_en_fr_v1/metadata.json'; +const HOSTED_URLS = { + model: + 'https://storage.googleapis.com/tfjs-models/tfjs/translation_en_fr_v1/model.json', + metadata: + 'https://storage.googleapis.com/tfjs-models/tfjs/translation_en_fr_v1/metadata.json' +}; + +const LOCAL_URLS = { + model: 'http://localhost:1235/resources/model.json', + metadata: 'http://localhost:1235/resources/metadata.json' +}; class Translator { /** * Initializes the Translation demo. */ - async init() { - const model = await loader.loadHostedPretrainedModel(HOSTED_MODEL_JSON_URL); + async init(urls) { + this.urls = urls; + const model = await loader.loadHostedPretrainedModel(urls.model); await this.loadMetadata(); this.prepareEncoderModel(model); this.prepareDecoderModel(model); @@ -39,7 +47,7 @@ class Translator { async loadMetadata() { const translationMetadata = - await loader.loadHostedMetadata(HOSTED_METADATA_JSON_URL); + await loader.loadHostedMetadata(this.urls.metadata); this.maxDecoderSeqLength = translationMetadata['max_decoder_seq_length']; this.maxEncoderSeqLength = translationMetadata['max_encoder_seq_length']; console.log('maxDecoderSeqLength = ' + this.maxDecoderSeqLength); @@ -180,9 +188,29 @@ class Translator { * function with the UI. */ async function setupTranslator() { - const translator = await new Translator().init(); - ui.setTranslationFunction(x => translator.translate(x)); - ui.setEnglish('Go.', x => translator.translate(x)); + if (await loader.urlExists(HOSTED_URLS.model)) { + ui.status('Model available: ' + HOSTED_URLS.model); + const button = document.getElementById('load-pretrained-remote'); + button.addEventListener('click', async () => { + const translator = await new Translator().init(HOSTED_URLS); + ui.setTranslationFunction(x => translator.translate(x)); + ui.setEnglish('Go.', x => translator.translate(x)); + }); + button.style.display = 'inline-block'; + } + + if (await loader.urlExists(LOCAL_URLS.model)) { + ui.status('Model available: ' + LOCAL_URLS.model); + const button = document.getElementById('load-pretrained-local'); + button.addEventListener('click', async () => { + const translator = await new Translator().init(LOCAL_URLS); + ui.setTranslationFunction(x => translator.translate(x)); + ui.setEnglish('Go.', x => translator.translate(x)); + }); + button.style.display = 'inline-block'; + } + + ui.status('Standing by.'); } setupTranslator(); diff --git a/translation/loader.js b/translation/loader.js index 02390e5ca..263729b23 100644 --- a/translation/loader.js +++ b/translation/loader.js @@ -16,7 +16,20 @@ */ import * as tf from '@tensorflow/tfjs'; -import {status} from './ui'; +import * as ui from './ui'; + +/** + * Test whether a given URL is retrievable. + */ +export async function urlExists(url) { + ui.status('Testing url ' + url); + try { + const response = await fetch(url, {method: 'HEAD'}); + return response.ok; + } catch (err) { + return false; + } +} /** * Load pretrained model stored at a remote URL. 
@@ -24,14 +37,18 @@ import {status} from './ui'; * @return An instance of `tf.Model` with model topology and weights loaded. */ export async function loadHostedPretrainedModel(url) { - status('Loading pretrained model from ' + url); + ui.status('Loading pretrained model from ' + url); try { const model = await tf.loadModel(url); - status('Done loading pretrained model.'); + ui.status('Done loading pretrained model.'); + // We can't load a model twice due to + // https://github.com/tensorflow/tfjs/issues/34 + // Therefore we remove the load buttons to avoid user confusion. + ui.disableLoadModelButtons(); return model; } catch (err) { - console.log(err); - status('Loading pretrained model failed.'); + console.error(err); + ui.status('Loading pretrained model failed.'); } } @@ -41,14 +58,14 @@ export async function loadHostedPretrainedModel(url) { * @return An object containing metadata as key-value pairs. */ export async function loadHostedMetadata(url) { - status('Loading metadata from ' + url); + ui.status('Loading metadata from ' + url); try { const metadataJson = await fetch(url); const metadata = await metadataJson.json(); - status('Done loading metadata.'); + ui.status('Done loading metadata.'); return metadata; } catch (err) { - console.log(err); - status('Loading metadata failed.'); + console.error(err); + ui.status('Loading metadata failed.'); } } diff --git a/translation/package.json b/translation/package.json index 094621de0..ff5bd5f57 100644 --- a/translation/package.json +++ b/translation/package.json @@ -13,7 +13,7 @@ "vega-embed": "^3.0.0" }, "scripts": { - "watch": "NODE_ENV=development parcel --no-hmr --open index.html ", + "watch": "./serve.sh", "build": "NODE_ENV=production parcel build index.html --no-minify --public-url ./" }, "devDependencies": { @@ -21,7 +21,8 @@ "babel-polyfill": "~6.26.0", "babel-preset-env": "~1.6.1", "clang-format": "~1.2.2", - "parcel-bundler": "~1.6.2" + "http-server": "~0.10.0", + "parcel-bundler": "~1.7.0" }, "babel": { "presets": [ diff --git a/translation/python/__init__.py b/translation/python/__init__.py new file mode 100644 index 000000000..636b70f0d --- /dev/null +++ b/translation/python/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2018 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/translation/python/translation.py b/translation/python/translation.py new file mode 100644 index 000000000..650d6c174 --- /dev/null +++ b/translation/python/translation.py @@ -0,0 +1,344 @@ +# Copyright 2018 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Train a simple LSTM model for character-level language translation. + +This is based on the Keras example at: + https://github.com/keras-team/keras/blob/master/examples/lstm_seq2seq.py + +The training data can be downloaded with a command like the following: + wget http://www.manythings.org/anki/fra-eng.zip +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import io +import json +import os + +from keras.models import Model +from keras.layers import Input, LSTM, Dense +import numpy as np +import tensorflowjs as tfjs + + +def read_data(): + # Vectorize the data. + input_texts = [] + target_texts = [] + input_characters = set() + target_characters = set() + lines = io.open(FLAGS.data_path, 'r', encoding='utf-8').read().split('\n') + for line in lines[: min(FLAGS.num_samples, len(lines) - 1)]: + input_text, target_text = line.split('\t') + # We use "tab" as the "start sequence" character for the targets, and "\n" + # as the "end sequence" character. + target_text = '\t' + target_text + '\n' + input_texts.append(input_text) + target_texts.append(target_text) + for char in input_text: + if char not in input_characters: + input_characters.add(char) + for char in target_text: + if char not in target_characters: + target_characters.add(char) + + input_characters = sorted(list(input_characters)) + target_characters = sorted(list(target_characters)) + num_encoder_tokens = len(input_characters) + num_decoder_tokens = len(target_characters) + max_encoder_seq_length = max([len(txt) for txt in input_texts]) + max_decoder_seq_length = max([len(txt) for txt in target_texts]) + + print('Number of samples:', len(input_texts)) + print('Number of unique input tokens:', num_encoder_tokens) + print('Number of unique output tokens:', num_decoder_tokens) + print('Max sequence length for inputs:', max_encoder_seq_length) + print('Max sequence length for outputs:', max_decoder_seq_length) + + input_token_index = dict( + [(char, i) for i, char in enumerate(input_characters)]) + target_token_index = dict( + [(char, i) for i, char in enumerate(target_characters)]) + + # Save the token indices to file. 
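+  # The browser demo fetches this metadata.json via loader.loadHostedMetadata()
+  # alongside the converted model, and uses the token indices and max sequence
+  # lengths at inference time.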
+ metadata_json_path = os.path.join( + FLAGS.artifacts_dir, 'metadata.json') + if not os.path.isdir(os.path.dirname(metadata_json_path)): + os.makedirs(os.path.dirname(metadata_json_path)) + with io.open(metadata_json_path, 'w', encoding='utf-8') as f: + metadata = { + 'input_token_index': input_token_index, + 'target_token_index': target_token_index, + 'max_encoder_seq_length': max_encoder_seq_length, + 'max_decoder_seq_length': max_decoder_seq_length + } + f.write(json.dumps(metadata, ensure_ascii=False)) + print('Saved metadata at: %s' % metadata_json_path) + + encoder_input_data = np.zeros( + (len(input_texts), max_encoder_seq_length, num_encoder_tokens), + dtype='float32') + decoder_input_data = np.zeros( + (len(input_texts), max_decoder_seq_length, num_decoder_tokens), + dtype='float32') + decoder_target_data = np.zeros( + (len(input_texts), max_decoder_seq_length, num_decoder_tokens), + dtype='float32') + + for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)): + for t, char in enumerate(input_text): + encoder_input_data[i, t, input_token_index[char]] = 1. + for t, char in enumerate(target_text): + # decoder_target_data is ahead of decoder_input_data by one timestep + decoder_input_data[i, t, target_token_index[char]] = 1. + if t > 0: + # decoder_target_data will be ahead by one timestep + # and will not include the start character. + decoder_target_data[i, t - 1, target_token_index[char]] = 1. + + return (input_texts, max_encoder_seq_length, max_decoder_seq_length, + num_encoder_tokens, num_decoder_tokens, + input_token_index, target_token_index, + encoder_input_data, decoder_input_data, decoder_target_data) + + +def seq2seq_model(num_encoder_tokens, num_decoder_tokens, latent_dim): + """Create a Keras model for the seq2seq translation. + + Args: + num_encoder_tokens: Total number of distinct tokens in the inputs + to the encoder. + num_decoder_tokens: Total number of distinct tokens in the outputs + to/from the decoder. + latent_dim: Number of latent dimensions in the LSTMs. + + Returns: + encoder_inputs: Instance of `keras.Input`, symbolic tensor as input to + the encoder LSTM. + encoder_states: List of symbolic tensors for the output states (h and c) + of the encoder LSTM. + decoder_inputs: Instance of `keras.Input`, symbolic tensor as input to + the decoder LSTM. + decoder_lstm: `keras.Layer` instance, the decoder LSTM. + decoder_dense: `keras.Layer` instance, the Dense layer in the decoder. + model: `keras.Model` instance, the entire translation model that can be + used in training. + """ + # Define an input sequence and process it. + encoder_inputs = Input(shape=(None, num_encoder_tokens)) + encoder = LSTM(latent_dim, + return_state=True, + recurrent_initializer=FLAGS.recurrent_initializer) + _, state_h, state_c = encoder(encoder_inputs) + # We discard `encoder_outputs` and only keep the states. + encoder_states = [state_h, state_c] + + # Set up the decoder, using `encoder_states` as initial state. + decoder_inputs = Input(shape=(None, num_decoder_tokens)) + # We set up our decoder to return full output sequences, + # and to return internal states as well. We don't use the + # return states in the training model, but we will use them in inference. 
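+  # With return_sequences=True, the decoder LSTM outputs one vector per input
+  # timestep, i.e. a (batch, timesteps, latent_dim) tensor; the softmax Dense
+  # layer below maps each timestep to num_decoder_tokens probabilities.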
+ decoder_lstm = LSTM(latent_dim, + return_sequences=True, + return_state=True, + recurrent_initializer=FLAGS.recurrent_initializer) + decoder_outputs, _, _ = decoder_lstm(decoder_inputs, + initial_state=encoder_states) + decoder_dense = Dense(num_decoder_tokens, activation='softmax') + decoder_outputs = decoder_dense(decoder_outputs) + + # Define the model that will turn + # `encoder_input_data` & `decoder_input_data` into `decoder_target_data`. + model = Model([encoder_inputs, decoder_inputs], decoder_outputs) + return (encoder_inputs, encoder_states, decoder_inputs, decoder_lstm, + decoder_dense, model) + + +def decode_sequence(input_seq, + encoder_model, + decoder_model, + num_decoder_tokens, + target_begin_index, + reverse_target_char_index, + max_decoder_seq_length): + """Decode (i.e., translate) an encoded sentence. + + Args: + input_seq: A `numpy.ndarray` of shape + `(1, max_encoder_seq_length, num_encoder_tokens)`. + encoder_model: A `keras.Model` instance for the encoder. + decoder_model: A `keras.Model` instance for the decoder. + num_decoder_tokens: Number of unique tokens for the decoder. + target_begin_index: An `int`, the index for the beginning token of the + decoder. + reverse_target_char_index: A lookup table for the target characters, i.e., + a map from `int` index to target character. + max_decoder_seq_length: Maximum allowed sequence length output by the + decoder. + + Returns: + The result of the decoding (i.e., translation) as a string. + """ + + # Encode the input as state vectors. + states_value = encoder_model.predict(input_seq) + + # Generate empty target sequence of length 1. + target_seq = np.zeros((1, 1, num_decoder_tokens)) + # Populate the first character of the target sequence with the start character. + target_seq[0, 0, target_begin_index] = 1. + + # Sampling loop for a batch of sequences + # (to simplify, here we assume a batch of size 1). + stop_condition = False + decoded_sentence = '' + while not stop_condition: + output_tokens, h, c = decoder_model.predict( + [target_seq] + states_value) + + # Sample a token + sampled_token_index = np.argmax(output_tokens[0, -1, :]) + sampled_char = reverse_target_char_index[sampled_token_index] + decoded_sentence += sampled_char + + # Exit condition: either hit max length + # or find stop character. + if (sampled_char == '\n' or + len(decoded_sentence) > max_decoder_seq_length): + stop_condition = True + + # Update the target sequence (of length 1). + target_seq = np.zeros((1, 1, num_decoder_tokens)) + target_seq[0, 0, sampled_token_index] = 1. + + # Update states + states_value = [h, c] + + return decoded_sentence + + +def main(): + (input_texts, _, max_decoder_seq_length, + num_encoder_tokens, num_decoder_tokens, + __, target_token_index, + encoder_input_data, decoder_input_data, decoder_target_data) = read_data() + + (encoder_inputs, encoder_states, decoder_inputs, decoder_lstm, + decoder_dense, model) = seq2seq_model( + num_encoder_tokens, num_decoder_tokens, FLAGS.latent_dim) + + # Run training. + model.compile(optimizer='rmsprop', loss='categorical_crossentropy') + model.fit([encoder_input_data, decoder_input_data], decoder_target_data, + batch_size=FLAGS.batch_size, + epochs=FLAGS.epochs, + validation_split=0.2) + + tfjs.converters.save_keras_model(model, FLAGS.artifacts_dir) + + # Next: inference mode (sampling). + # Here's the drill: + # 1) encode input and retrieve initial decoder state + # 2) run one step of decoder with this initial state + # and a "start of sequence" token as target. 
+ # Output will be the next target token. + # 3) Repeat with the current target token and current states. + + # Define sampling models + encoder_model = Model(encoder_inputs, encoder_states) + + decoder_state_input_h = Input(shape=(FLAGS.latent_dim,)) + decoder_state_input_c = Input(shape=(FLAGS.latent_dim,)) + decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c] + decoder_outputs, state_h, state_c = decoder_lstm( + decoder_inputs, initial_state=decoder_states_inputs) + decoder_states = [state_h, state_c] + decoder_outputs = decoder_dense(decoder_outputs) + decoder_model = Model( + [decoder_inputs] + decoder_states_inputs, + [decoder_outputs] + decoder_states) + + # Reverse-lookup token index to decode sequences back to + # something readable. + reverse_target_char_index = dict( + (i, char) for char, i in target_token_index.items()) + + target_begin_index = target_token_index['\t'] + + for seq_index in range(FLAGS.num_test_sentences): + # Take one sequence (part of the training set) + # for trying out decoding. + input_seq = encoder_input_data[seq_index: seq_index + 1] + decoded_sentence = decode_sequence( + input_seq, encoder_model, decoder_model, num_decoder_tokens, + target_begin_index, reverse_target_char_index, max_decoder_seq_length) + print('-') + print('Input sentence:', input_texts[seq_index]) + print('Decoded sentence:', decoded_sentence) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Keras seq2seq translation model training and serialization') + parser.add_argument( + 'data_path', + type=str, + help='Path to the training data, e.g., ~/ml-data/fra-eng/fra.txt') + parser.add_argument( + '--batch_size', + type=int, + default=64, + help='Training batch size.') + parser.add_argument( + '--epochs', + type=int, + default=100, + help='Number of training epochs.') + parser.add_argument( + '--latent_dim', + type=int, + default=256, + help='Latent dimensionality of the encoding space.') + parser.add_argument( + '--num_samples', + type=int, + default=10000, + help='Number of samples to train on.') + parser.add_argument( + '--num_test_sentences', + type=int, + default=100, + help='Number of example sentences to test at the end of the training.') + # TODO(cais): This is a workaround for the limitation in TF.js Layers that the + # default recurrent initializer "Orthogonal" is currently not supported. + # Remove this once "Orthogonal" becomes available. + parser.add_argument( + '--recurrent_initializer', + type=str, + default='orthogonal', + help='Custom initializer for recurrent kernels of LSTMs (e.g., ' + 'glorot_uniform).') + parser.add_argument( + '--artifacts_dir', + type=str, + default='/tmp/translation.keras', + help='Local path for saving the TensorFlow.js artifacts.') + + FLAGS, _ = parser.parse_known_args() + main() diff --git a/translation/serve.sh b/translation/serve.sh new file mode 100755 index 000000000..e5ede63ed --- /dev/null +++ b/translation/serve.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash + +# Copyright 2018 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +# This script starts two HTTP servers on different ports: +# * Port 1234 (using parcel) serves HTML and JavaScript. +# * Port 1235 (using http-server) serves pretrained model resources. +# +# The reason for this arrangement is that Parcel currently has a limitation that +# prevents it from serving the pretrained models; see +# https://github.com/parcel-bundler/parcel/issues/1098. Once that issue is +# resolved, a single Parcel server will be sufficient. + +export NODE_ENV=development +RESOURCE_PORT=1235 + +# Ensure that http-server is available. +yarn + +echo Starting the pretrained model server... +node_modules/http-server/bin/http-server dist --cors -p "${RESOURCE_PORT}" > /dev/null & HTTP_SERVER_PID=$! + +echo Starting the example html/js server... +# This uses port 1234 by default. +node_modules/parcel-bundler/bin/cli.js serve -d dist --open --no-hmr --public-url / index.html + +# When the Parcel server exits, kill the http-server too. +kill $HTTP_SERVER_PID + diff --git a/translation/ui.js b/translation/ui.js index 47b8199aa..f696defb9 100644 --- a/translation/ui.js +++ b/translation/ui.js @@ -16,6 +16,7 @@ */ export function status(statusText) { + console.log(statusText); document.getElementById('status').textContent = statusText; } @@ -32,3 +33,8 @@ export function setTranslationFunction(translate) { translate(inputSentence); }); } + +export function disableLoadModelButtons() { + document.getElementById('load-pretrained-remote').style.display = 'none'; + document.getElementById('load-pretrained-local').style.display = 'none'; +} diff --git a/translation/yarn.lock b/translation/yarn.lock index 094ed0409..e20f9abce 100644 --- a/translation/yarn.lock +++ b/translation/yarn.lock @@ -705,6 +705,10 @@ binary-extensions@^1.0.0: version "1.11.0" resolved "https://registry.yarnpkg.com/binary-extensions/-/binary-extensions-1.11.0.tgz#46aa1751fb6a2f93ee5e689bb1087d4b14c6c205" +bindings@~1.2.1: + version "1.2.1" + resolved "https://registry.yarnpkg.com/bindings/-/bindings-1.2.1.tgz#14ad6113812d2d37d72e67b4cacb4bb726505f11" + block-stream@*: version "0.0.9" resolved "https://registry.yarnpkg.com/block-stream/-/block-stream-0.0.9.tgz#13ebfe778a03205cfe03751481ebb4b3300c126a" @@ -766,12 +770,6 @@ brorand@^1.0.1: version "1.1.0" resolved "https://registry.yarnpkg.com/brorand/-/brorand-1.1.0.tgz#12c25efe40a45e3c323eb8675a0a0ce57b22371f" -browser-resolve@^1.11.2: - version "1.11.2" - resolved "https://registry.yarnpkg.com/browser-resolve/-/browser-resolve-1.11.2.tgz#8ff09b0a2c421718a1051c260b32e48f442938ce" - dependencies: - resolve "1.1.7" - browserify-aes@^1.0.0, browserify-aes@^1.0.4: version "1.1.1" resolved "https://registry.yarnpkg.com/browserify-aes/-/browserify-aes-1.1.1.tgz#38b7ab55edb806ff2dcda1a7f1620773a477c49f" @@ -1042,6 +1040,10 @@ colormin@^1.0.5: css-color-names "0.0.4" has "^1.0.1" +colors@1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/colors/-/colors-1.0.3.tgz#0433f44d809680fdeb60ed260f1b0c262e82a40b" + colors@~1.1.2: version "1.1.2" resolved "https://registry.yarnpkg.com/colors/-/colors-1.1.2.tgz#168a4701756b6a7f51a12ce0c97bfa28c084ed63" @@ -1122,6 +1124,10 @@ core-util-is@1.0.2, core-util-is@~1.0.0: version "1.0.2" resolved "https://registry.yarnpkg.com/core-util-is/-/core-util-is-1.0.2.tgz#b5fd54220aa2bc5ab57aab7140c940754503c1a7" +corser@~2.0.0: + version "2.0.1" + resolved 
"https://registry.yarnpkg.com/corser/-/corser-2.0.1.tgz#8eda252ecaab5840dcd975ceb90d9370c819ff87" + create-ecdh@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/create-ecdh/-/create-ecdh-4.0.0.tgz#888c723596cdf7612f6498233eebd7a35301737d" @@ -1410,6 +1416,13 @@ date-now@^0.1.4: version "0.1.4" resolved "https://registry.yarnpkg.com/date-now/-/date-now-0.1.4.tgz#eaf439fd4d4848ad74e5cc7dbef200672b9e345b" +deasync@^0.1.12: + version "0.1.12" + resolved "https://registry.yarnpkg.com/deasync/-/deasync-0.1.12.tgz#0159492a4133ab301d6c778cf01e74e63b10e549" + dependencies: + bindings "~1.2.1" + nan "^2.0.7" + debug@2.6.9, debug@^2.2.0, debug@^2.3.3, debug@^2.6.8: version "2.6.9" resolved "https://registry.yarnpkg.com/debug/-/debug-2.6.9.tgz#5d128515df134ff327e90a4c93f4e077a536341f" @@ -1562,6 +1575,15 @@ ecc-jsbn@~0.1.1: dependencies: jsbn "~0.1.0" +ecstatic@^2.0.0: + version "2.2.1" + resolved "https://registry.yarnpkg.com/ecstatic/-/ecstatic-2.2.1.tgz#b5087fad439dd9dd49d31e18131454817fe87769" + dependencies: + he "^1.1.1" + mime "^1.2.11" + minimist "^1.1.0" + url-join "^2.0.2" + editorconfig@^0.13.2: version "0.13.3" resolved "https://registry.yarnpkg.com/editorconfig/-/editorconfig-0.13.3.tgz#e5219e587951d60958fd94ea9a9a008cdeff1b34" @@ -1667,6 +1689,10 @@ etag@~1.8.1: version "1.8.1" resolved "https://registry.yarnpkg.com/etag/-/etag-1.8.1.tgz#41ae2eeb65efa62268aebfea83ac7d79299b0887" +eventemitter3@1.x.x: + version "1.2.0" + resolved "https://registry.yarnpkg.com/eventemitter3/-/eventemitter3-1.2.0.tgz#1c86991d816ad1e504750e73874224ecf3bec508" + events@^1.0.0: version "1.1.1" resolved "https://registry.yarnpkg.com/events/-/events-1.1.1.tgz#9ebdb7635ad099c70dcc4c2a1f5004288e8bd924" @@ -1997,6 +2023,10 @@ hawk@3.1.3, hawk@~3.1.3: hoek "2.x.x" sntp "1.x.x" +he@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/he/-/he-1.1.1.tgz#93410fd21b009735151f8868c2f271f3427e23fd" + hmac-drbg@^1.0.0: version "1.0.1" resolved "https://registry.yarnpkg.com/hmac-drbg/-/hmac-drbg-1.0.1.tgz#d2745701025a6c775a6c545793ed502fc0c649a1" @@ -2020,7 +2050,7 @@ html-comment-regex@^1.1.0: version "1.1.1" resolved "https://registry.yarnpkg.com/html-comment-regex/-/html-comment-regex-1.1.1.tgz#668b93776eaae55ebde8f3ad464b307a4963625e" -htmlnano@^0.1.6: +htmlnano@^0.1.7: version "0.1.7" resolved "https://registry.yarnpkg.com/htmlnano/-/htmlnano-0.1.7.tgz#1751937a05f122a3248dba1c63edf01f6d36cb84" dependencies: @@ -2051,6 +2081,26 @@ http-errors@~1.6.2: setprototypeof "1.0.3" statuses ">= 1.3.1 < 2" +http-proxy@^1.8.1: + version "1.16.2" + resolved "https://registry.yarnpkg.com/http-proxy/-/http-proxy-1.16.2.tgz#06dff292952bf64dbe8471fa9df73066d4f37742" + dependencies: + eventemitter3 "1.x.x" + requires-port "1.x.x" + +http-server@~0.10.0: + version "0.10.0" + resolved "https://registry.yarnpkg.com/http-server/-/http-server-0.10.0.tgz#b2a446b16a9db87ed3c622ba9beb1b085b1234a7" + dependencies: + colors "1.0.3" + corser "~2.0.0" + ecstatic "^2.0.0" + http-proxy "^1.8.1" + opener "~1.4.0" + optimist "0.6.x" + portfinder "^1.0.13" + union "~0.4.3" + http-signature@~1.1.0: version "1.1.1" resolved "https://registry.yarnpkg.com/http-signature/-/http-signature-1.1.1.tgz#df72e267066cd0ac67fb76adf8e134a8fbcf91bf" @@ -2542,6 +2592,10 @@ mime@1.4.1: version "1.4.1" resolved "https://registry.yarnpkg.com/mime/-/mime-1.4.1.tgz#121f9ebc49e3766f311a76e1fa1c8003c4b03aa6" +mime@^1.2.11: + version "1.6.0" + resolved "https://registry.yarnpkg.com/mime/-/mime-1.6.0.tgz#32cd9e5c64553bd58d19a568af452acff04981b1" + 
mimic-fn@^1.0.0: version "1.2.0" resolved "https://registry.yarnpkg.com/mimic-fn/-/mimic-fn-1.2.0.tgz#820c86a39334640e99516928bd03fca88057d022" @@ -2564,10 +2618,14 @@ minimist@0.0.8: version "0.0.8" resolved "https://registry.yarnpkg.com/minimist/-/minimist-0.0.8.tgz#857fcabfc3397d2625b8228262e86aa7a011b05d" -minimist@^1.1.3, minimist@^1.2.0: +minimist@^1.1.0, minimist@^1.1.3, minimist@^1.2.0: version "1.2.0" resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.0.tgz#a35008b20f41383eec1fb914f4cd5df79a264284" +minimist@~0.0.1: + version "0.0.10" + resolved "https://registry.yarnpkg.com/minimist/-/minimist-0.0.10.tgz#de3f98543dbf96082be48ad1a0c7cda836301dcf" + mixin-deep@^1.2.0: version "1.3.1" resolved "https://registry.yarnpkg.com/mixin-deep/-/mixin-deep-1.3.1.tgz#a49e7268dce1a0d9698e45326c5626df3543d0fe" @@ -2575,7 +2633,7 @@ mixin-deep@^1.2.0: for-in "^1.0.2" is-extendable "^1.0.1" -"mkdirp@>=0.5 0", mkdirp@^0.5.1, mkdirp@~0.5.0, mkdirp@~0.5.1: +mkdirp@0.5.x, "mkdirp@>=0.5 0", mkdirp@^0.5.1, mkdirp@~0.5.0, mkdirp@~0.5.1: version "0.5.1" resolved "https://registry.yarnpkg.com/mkdirp/-/mkdirp-0.5.1.tgz#30057438eac6cf7f8c4767f38648d6697d75c903" dependencies: @@ -2585,7 +2643,7 @@ ms@2.0.0: version "2.0.0" resolved "https://registry.yarnpkg.com/ms/-/ms-2.0.0.tgz#5608aeadfc00be6c2901df5f9861788de0d597c8" -nan@^2.3.0: +nan@^2.0.7, nan@^2.3.0: version "2.10.0" resolved "https://registry.yarnpkg.com/nan/-/nan-2.10.0.tgz#96d0cd610ebd58d4b4de9cc0c6828cda99c7548f" @@ -2783,12 +2841,23 @@ once@^1.3.0, once@^1.3.3: dependencies: wrappy "1" +opener@~1.4.0: + version "1.4.3" + resolved "https://registry.yarnpkg.com/opener/-/opener-1.4.3.tgz#5c6da2c5d7e5831e8ffa3964950f8d6674ac90b8" + opn@^5.1.0: version "5.3.0" resolved "https://registry.yarnpkg.com/opn/-/opn-5.3.0.tgz#64871565c863875f052cfdf53d3e3cb5adb53b1c" dependencies: is-wsl "^1.1.0" +optimist@0.6.x: + version "0.6.1" + resolved "https://registry.yarnpkg.com/optimist/-/optimist-0.6.1.tgz#da3ea74686fa21a19a111c326e90eb15a0196686" + dependencies: + minimist "~0.0.1" + wordwrap "~0.0.2" + optionator@^0.8.1: version "0.8.2" resolved "https://registry.yarnpkg.com/optionator/-/optionator-0.8.2.tgz#364c5e409d3f4d6301d6c0b4c05bba50180aeb64" @@ -2855,9 +2924,9 @@ pako@~1.0.5: version "1.0.6" resolved "https://registry.yarnpkg.com/pako/-/pako-1.0.6.tgz#0101211baa70c4bca4a0f63f2206e97b7dfaf258" -parcel-bundler@~1.6.2: - version "1.6.2" - resolved "https://registry.yarnpkg.com/parcel-bundler/-/parcel-bundler-1.6.2.tgz#e415c4993b6f4c48cd410427594ce80bbb7d90f1" +parcel-bundler@~1.7.0: + version "1.7.0" + resolved "https://registry.yarnpkg.com/parcel-bundler/-/parcel-bundler-1.7.0.tgz#8c2512615fd602d2f39bd97bfd128f8fe524b321" dependencies: babel-code-frame "^6.26.0" babel-core "^6.25.0" @@ -2870,7 +2939,6 @@ parcel-bundler@~1.6.2: babel-types "^6.26.0" babylon "^6.17.4" babylon-walk "^1.0.2" - browser-resolve "^1.11.2" browserslist "^2.11.2" chalk "^2.1.0" chokidar "^2.0.1" @@ -2878,12 +2946,13 @@ parcel-bundler@~1.6.2: commander "^2.11.0" cross-spawn "^6.0.4" cssnano "^3.10.0" + deasync "^0.1.12" dotenv "^5.0.0" filesize "^3.6.0" get-port "^3.2.0" glob "^7.1.2" grapheme-breaker "^0.3.2" - htmlnano "^0.1.6" + htmlnano "^0.1.7" is-url "^1.2.2" js-yaml "^3.10.0" json5 "^0.5.1" @@ -2893,13 +2962,12 @@ parcel-bundler@~1.6.2: node-libs-browser "^2.0.0" opn "^5.1.0" physical-cpu-count "^2.0.0" - postcss "^6.0.10" + postcss "^6.0.19" postcss-value-parser "^3.3.0" posthtml "^0.11.2" posthtml-parser "^0.4.0" posthtml-render "^1.1.0" resolve "^1.4.0" - 
sanitize-filename "^1.6.1" semver "^5.4.1" serialize-to-js "^1.1.1" serve-static "^1.12.4" @@ -2972,6 +3040,14 @@ physical-cpu-count@^2.0.0: version "2.0.0" resolved "https://registry.yarnpkg.com/physical-cpu-count/-/physical-cpu-count-2.0.0.tgz#18de2f97e4bf7a9551ad7511942b5496f7aba660" +portfinder@^1.0.13: + version "1.0.13" + resolved "https://registry.yarnpkg.com/portfinder/-/portfinder-1.0.13.tgz#bb32ecd87c27104ae6ee44b5a3ccbf0ebb1aede9" + dependencies: + async "^1.5.2" + debug "^2.2.0" + mkdirp "0.5.x" + posix-character-classes@^0.1.0: version "0.1.1" resolved "https://registry.yarnpkg.com/posix-character-classes/-/posix-character-classes-0.1.1.tgz#01eac0fe3b5af71a2a6c02feabb8c1fef7e00eab" @@ -3187,9 +3263,9 @@ postcss@^5.0.10, postcss@^5.0.11, postcss@^5.0.12, postcss@^5.0.13, postcss@^5.0 source-map "^0.5.6" supports-color "^3.2.3" -postcss@^6.0.10: - version "6.0.20" - resolved "https://registry.yarnpkg.com/postcss/-/postcss-6.0.20.tgz#686107e743a12d5530cb68438c590d5b2bf72c3c" +postcss@^6.0.19: + version "6.0.21" + resolved "https://registry.yarnpkg.com/postcss/-/postcss-6.0.21.tgz#8265662694eddf9e9a5960db6da33c39e4cd069d" dependencies: chalk "^2.3.2" source-map "^0.6.1" @@ -3280,6 +3356,10 @@ q@^1.1.2: version "1.5.1" resolved "https://registry.yarnpkg.com/q/-/q-1.5.1.tgz#7e32f75b41381291d04611f1bf14109ac00651d7" +qs@~2.3.3: + version "2.3.3" + resolved "https://registry.yarnpkg.com/qs/-/qs-2.3.3.tgz#e9e85adbe75da0bbe4c8e0476a086290f863b404" + qs@~6.4.0: version "6.4.0" resolved "https://registry.yarnpkg.com/qs/-/qs-6.4.0.tgz#13e26d28ad6b0ffaa91312cd3bf708ed351e7233" @@ -3466,14 +3546,14 @@ require-main-filename@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/require-main-filename/-/require-main-filename-1.0.1.tgz#97f717b69d48784f5f526a6c5aa8ffdda055a4d1" +requires-port@1.x.x: + version "1.0.0" + resolved "https://registry.yarnpkg.com/requires-port/-/requires-port-1.0.0.tgz#925d2601d39ac485e091cf0da5c6e694dc3dcaff" + resolve-url@^0.2.1: version "0.2.1" resolved "https://registry.yarnpkg.com/resolve-url/-/resolve-url-0.2.1.tgz#2c637fe77c893afd2a663fe21aa9080068e2052a" -resolve@1.1.7: - version "1.1.7" - resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.1.7.tgz#203114d82ad2c5ed9e8e0411b3932875e889e97b" - resolve@^1.1.5, resolve@^1.1.6, resolve@^1.4.0: version "1.6.0" resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.6.0.tgz#0fbd21278b27b4004481c395349e7aba60a9ff5c" @@ -3517,12 +3597,6 @@ safer-eval@^1.2.3: dependencies: clones "^1.1.0" -sanitize-filename@^1.6.1: - version "1.6.1" - resolved "https://registry.yarnpkg.com/sanitize-filename/-/sanitize-filename-1.6.1.tgz#612da1c96473fa02dccda92dcd5b4ab164a6772a" - dependencies: - truncate-utf8-bytes "^1.0.0" - sax@~1.2.1, sax@~1.2.4: version "1.2.4" resolved "https://registry.yarnpkg.com/sax/-/sax-1.2.4.tgz#2816234e2378bddc4e5354fab5caa895df7100d9" @@ -3978,12 +4052,6 @@ trim-right@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/trim-right/-/trim-right-1.0.1.tgz#cb2e1203067e0c8de1f614094b9fe45704ea6003" -truncate-utf8-bytes@^1.0.0: - version "1.0.2" - resolved "https://registry.yarnpkg.com/truncate-utf8-bytes/-/truncate-utf8-bytes-1.0.2.tgz#405923909592d56f78a5818434b0b78489ca5f2b" - dependencies: - utf8-byte-length "^1.0.1" - tslib@^1.9.0: version "1.9.0" resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.9.0.tgz#e37a86fda8cbbaf23a057f473c9f4dc64e5fc2e8" @@ -4050,6 +4118,12 @@ union-value@^1.0.0: is-extendable "^0.1.1" set-value "^0.4.3" +union@~0.4.3: + version "0.4.6" + resolved 
"https://registry.yarnpkg.com/union/-/union-0.4.6.tgz#198fbdaeba254e788b0efcb630bc11f24a2959e0" + dependencies: + qs "~2.3.3" + uniq@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/uniq/-/uniq-1.0.1.tgz#b31c5ae8254844a3a8281541ce2b04b865a734ff" @@ -4083,6 +4157,10 @@ urix@^0.1.0: version "0.1.0" resolved "https://registry.yarnpkg.com/urix/-/urix-0.1.0.tgz#da937f7a62e21fec1fd18d49b35c2935067a6c72" +url-join@^2.0.2: + version "2.0.5" + resolved "https://registry.yarnpkg.com/url-join/-/url-join-2.0.5.tgz#5af22f18c052a000a48d7b82c5e9c2e2feeda728" + url@^0.11.0: version "0.11.0" resolved "https://registry.yarnpkg.com/url/-/url-0.11.0.tgz#3838e97cfc60521eb73c525a8e55bfdd9e2e28f1" @@ -4096,10 +4174,6 @@ use@^3.1.0: dependencies: kind-of "^6.0.2" -utf8-byte-length@^1.0.1: - version "1.0.4" - resolved "https://registry.yarnpkg.com/utf8-byte-length/-/utf8-byte-length-1.0.4.tgz#f45f150c4c66eee968186505ab93fcbb8ad6bf61" - util-deprecate@~1.0.1: version "1.0.2" resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" @@ -4411,6 +4485,10 @@ wide-align@^1.1.0: dependencies: string-width "^1.0.2" +wordwrap@~0.0.2: + version "0.0.3" + resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-0.0.3.tgz#a3d5da6cd5c0bc0008d37234bbaf1bed63059107" + wordwrap@~1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-1.0.0.tgz#27584810891456a4171c8d0226441ade90cbcaeb"