Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move Python scripts for building examples into tfjs-examples #38

Merged
merged 11 commits into from
Mar 30, 2018
2 changes: 1 addition & 1 deletion deploy.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env bash
# Copyright 2018 Google Inc. All Rights Reserved.
# Copyright 2018 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
61 changes: 61 additions & 0 deletions iris/build-resources.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#!/usr/bin/env bash

# Copyright 2018 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

# Builds resources for the Iris demo.
# Note this is not necessary to run the demo, because we already provide hosted
# pre-built resources.
# Usage example: do this from the 'iris' directory:
# ./build-resources.sh

set -e

DEMO_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

TRAIN_EPOCHS=100
while true; do
if [[ "$1" == "--epochs" ]]; then
TRAIN_EPOCHS=$2
shift 2
elif [[ -z "$1" ]]; then
break
else
echo "ERROR: Unrecognized argument: $1"
exit 1
fi
done

RESOURCES_ROOT="${DEMO_DIR}/dist/resources"
rm -rf "${RESOURCES_ROOT}"
mkdir -p "${RESOURCES_ROOT}"

# Run Python script to generate the pretrained model and weights files.
# Make sure you install the tensorflowjs pip package first.

python "${DEMO_DIR}/python/iris.py" \
--epochs "${TRAIN_EPOCHS}" \
--artifacts_dir "${RESOURCES_ROOT}"

cd ${DEMO_DIR}
yarn
yarn build

echo
echo "-----------------------------------------------------------"
echo "Resources written to ${RESOURCES_ROOT}."
echo "You can now run the demo with 'yarn watch'."
echo "-----------------------------------------------------------"
echo
3 changes: 2 additions & 1 deletion iris/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ <h1>TensorFlow.js Layers: Iris Demo</h1>
</div>

<div class="create-model">
<button id="load-pretrained">Load pretrained model</button>
<button id="load-pretrained-remote" style="display:none">Load hosted pretrained model</button>
<button id="load-pretrained-local" style="display:none">Load local pretrained model</button>
</div>

<div>
Expand Down
85 changes: 46 additions & 39 deletions iris/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,12 @@

import * as tf from '@tensorflow/tfjs';

import {getIrisData, IRIS_CLASSES, IRIS_NUM_CLASSES} from './data';
import {clearEvaluateTable, getManualInputData, loadTrainParametersFromUI, plotAccuracies, plotLosses, renderEvaluateTable, renderLogitsForManualInput, setManualInputWinnerMessage, status, wireUpEvaluateTableCallbacks} from './ui';
import * as data from './data';
import * as loader from './loader';
import * as ui from './ui';

let model;

/**
* Load pretrained model stored at a remote URL.
*
* @return An instance of `tf.Model` with model topology and weights loaded.
*/
async function loadHostedPretrainedModel() {
const HOSTED_MODEL_JSON_URL =
'https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json';
status('Loading pretrained model from ' + HOSTED_MODEL_JSON_URL);
try {
model = await tf.loadModel(HOSTED_MODEL_JSON_URL);
status('Done loading pretrained model.');
} catch (err) {
status('Loading pretrained model failed.');
}
}

/**
* Train a `tf.Model` to recognize Iris flower type.
*
Expand All @@ -53,9 +37,9 @@ async function loadHostedPretrainedModel() {
* @returns The trained `tf.Model` instance.
*/
async function trainModel(xTrain, yTrain, xTest, yTest) {
status('Training model... Please wait.');
ui.status('Training model... Please wait.');

const params = loadTrainParametersFromUI();
const params = ui.loadTrainParametersFromUI();

// Define the topology of the model: two dense layers.
const model = tf.sequential();
Expand All @@ -79,16 +63,16 @@ async function trainModel(xTrain, yTrain, xTest, yTest) {
callbacks: {
onEpochEnd: async (epoch, logs) => {
// Plot the loss and accuracy values at the end of every training epoch.
plotLosses(lossValues, epoch, logs.loss, logs.val_loss);
plotAccuracies(accuracyValues, epoch, logs.acc, logs.val_acc);
ui.plotLosses(lossValues, epoch, logs.loss, logs.val_loss);
ui.plotAccuracies(accuracyValues, epoch, logs.acc, logs.val_acc);

// Await web page DOM to refresh for the most recently plotted values.
await tf.nextFrame();
},
}
});

status('Model training complete.');
ui.status('Model training complete.');
return model;
}

Expand All @@ -99,25 +83,25 @@ async function trainModel(xTrain, yTrain, xTest, yTest) {
*/
async function predictOnManualInput(model) {
if (model == null) {
setManualInputWinnerMessage('ERROR: Please load or train model first.');
ui.setManualInputWinnerMessage('ERROR: Please load or train model first.');
return;
}

// Use a `tf.tidy` scope to make sure that WebGL memory allocated for the
// `predict` call is released at the end.
tf.tidy(() => {
// Prepare input data as a 2D `tf.Tensor`.
const inputData = getManualInputData();
const inputData = ui.getManualInputData();
const input = tf.tensor2d([inputData], [1, 4]);

// Call `model.predict` to get the prediction output as probabilities for
// the Iris flower categories.

const predictOut = model.predict(input);
const logits = Array.from(predictOut.dataSync());
const winner = IRIS_CLASSES[predictOut.argMax(-1).dataSync()[0]];
setManualInputWinnerMessage(winner);
renderLogitsForManualInput(logits);
const winner = data.IRIS_CLASSES[predictOut.argMax(-1).dataSync()[0]];
ui.setManualInputWinnerMessage(winner);
ui.renderLogitsForManualInput(logits);
});
}

Expand All @@ -130,39 +114,62 @@ async function predictOnManualInput(model) {
* [numTestExamples, 3].
*/
async function evaluateModelOnTestData(model, xTest, yTest) {
clearEvaluateTable();
ui.clearEvaluateTable();

tf.tidy(() => {
const xData = xTest.dataSync();
const yTrue = yTest.argMax(-1).dataSync();
const predictOut = model.predict(xTest);
const yPred = predictOut.argMax(-1);
renderEvaluateTable(xData, yTrue, yPred.dataSync(), predictOut.dataSync());
ui.renderEvaluateTable(
xData, yTrue, yPred.dataSync(), predictOut.dataSync());
});

predictOnManualInput(model);
}

const LOCAL_MODEL_JSON_URL = 'http://localhost:1235/resources/model.json';
const HOSTED_MODEL_JSON_URL =
'https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json';

/**
* The main function of the Iris demo.
*/
async function iris() {
const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15);
const [xTrain, yTrain, xTest, yTest] = data.getIrisData(0.15);

document.getElementById('train-from-scratch')
.addEventListener('click', async () => {
model = await trainModel(xTrain, yTrain, xTest, yTest);
evaluateModelOnTestData(model, xTest, yTest);
});

document.getElementById('load-pretrained')
.addEventListener('click', async () => {
clearEvaluateTable();
await loadHostedPretrainedModel();
predictOnManualInput(model);
});
if (await loader.urlExists(HOSTED_MODEL_JSON_URL)) {
ui.status('Model available: ' + HOSTED_MODEL_JSON_URL);
const button = document.getElementById('load-pretrained-remote');
button.addEventListener('click', async () => {
ui.clearEvaluateTable();
model = await loader.loadHostedPretrainedModel(HOSTED_MODEL_JSON_URL);
predictOnManualInput(model);
});
// button.style.visibility = 'visible';
button.style.display = 'inline-block';
}

if (await loader.urlExists(LOCAL_MODEL_JSON_URL)) {
ui.status('Model available: ' + LOCAL_MODEL_JSON_URL);
const button = document.getElementById('load-pretrained-local');
button.addEventListener('click', async () => {
ui.clearEvaluateTable();
model = await loader.loadHostedPretrainedModel(LOCAL_MODEL_JSON_URL);
predictOnManualInput(model);
});
// button.style.visibility = 'visible';
button.style.display = 'inline-block';
}

wireUpEvaluateTableCallbacks(() => predictOnManualInput(model));
ui.status('Standing by.');
ui.wireUpEvaluateTableCallbacks(() => predictOnManualInput(model));
}

iris();
53 changes: 53 additions & 0 deletions iris/loader.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '@tensorflow/tfjs';
import * as ui from './ui';

/**
* Test whether a given URL is retrievable.
*/
export async function urlExists(url) {
ui.status('Testing url ' + url);
try {
const response = await fetch(url, {method: 'HEAD'});
return response.ok;
} catch (err) {
return false;
}
}

/**
* Load pretrained model stored at a remote URL.
*
* @return An instance of `tf.Model` with model topology and weights loaded.
*/
export async function loadHostedPretrainedModel(url) {
ui.status('Loading pretrained model from ' + url);
try {
const model = await tf.loadModel(url);
ui.status('Done loading pretrained model.');
// We can't load a model twice due to
// https://github.com/tensorflow/tfjs/issues/34
// Therefore we remove the load buttons to avoid user confusion.
ui.disableLoadModelButtons();
return model;
} catch (err) {
console.error(err);
ui.status('Loading pretrained model failed.');
}
}
7 changes: 4 additions & 3 deletions iris/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
"vega-embed": "^3.0.0"
},
"scripts": {
"watch": "NODE_ENV=development parcel --no-hmr --open index.html ",
"build": "NODE_ENV=production parcel build index.html --no-minify --public-url ./"
"watch": "./serve.sh",
"build": "NODE_ENV=production parcel build index.html --no-minify --public-url /"
},
"devDependencies": {
"babel-plugin-transform-runtime": "~6.23.0",
"babel-polyfill": "~6.26.0",
"babel-preset-env": "~1.6.1",
"clang-format": "~1.2.2",
"parcel-bundler": "~1.6.2"
"http-server": "~0.10.0",
"parcel-bundler": "~1.7.0"
},
"babel": {
"presets": [
Expand Down
18 changes: 18 additions & 0 deletions iris/python/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2018 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
Loading