Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Commit

Permalink
Update models to deeplearnjs 0.5.0 (#769)
Browse files Browse the repository at this point in the history
  • Loading branch information
HalfdanJ authored and dsmilkov committed Feb 22, 2018
1 parent 4d71d85 commit fa67bbc
Show file tree
Hide file tree
Showing 14 changed files with 160 additions and 103 deletions.
2 changes: 1 addition & 1 deletion demos/imagenet/imagenet.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ export class ImagenetDemo extends ImagenetDemoPolymer {
imagenet_util.getRenderGrayscaleChannelsCollageShader(this.gpgpu);

const cameraSetup = this.setupCameraInput();
this.squeezeNet = new SqueezeNet(this.math);
this.squeezeNet = new SqueezeNet();

await Promise.all([this.squeezeNet.load(), cameraSetup]);

Expand Down
4 changes: 2 additions & 2 deletions demos/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@
"@bower_components/webcomponentsjs": "webcomponents/webcomponentsjs#^0.7.24",
"chart.js": "~2.7.1",
"deeplearn": "file:../dist/",
"deeplearn-knn-image-classifier": "~0.2.3",
"deeplearn-squeezenet": "~0.1.9",
"deeplearn-knn-image-classifier": "~0.3.0",
"deeplearn-squeezenet": "~0.2.0",
"tslint": "~5.8.0",
"vue": "~2.5.9"
},
Expand Down
3 changes: 1 addition & 2 deletions demos/teachable_gaming/teachable_gaming.ts
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,7 @@ export class TeachableGamingDemo extends TeachableGamingDemoPolymer {
});
}
this.classifier = new KNNImageClassifier(
TeachableGamingDemo.maxControls, TeachableGamingDemo.knnKValue,
dl.ENV.math);
TeachableGamingDemo.maxControls, TeachableGamingDemo.knnKValue);
this.classifier.load();
this.predictedIndex = -1;
this.selectedIndex = -1;
Expand Down
42 changes: 42 additions & 0 deletions models/knn_image_classifier/demo.html
Original file line number Diff line number Diff line change
@@ -1,2 +1,44 @@
<script src="https://unpkg.com/deeplearn"></script>
<script src="dist/bundle.js"></script>

<h2>Training images</h2>
<img id="cat" src="images/cat.jpg"></img>
<img id="dog1" src="images/dog1.jpg"></img>
<h2>Input image</h2>
<img id="dog2" src="images/dog2.jpg"></img>
<div id="result"></div>

<script>
// Image elements used by the demo: two training examples and one query image.
const cat = document.getElementById('cat');
const dog1 = document.getElementById('dog1');
const dog2 = document.getElementById('dog2');
// Element that displays progress messages and the final prediction.
const resultElement = document.getElementById('result');

// Demo entry point: loads the KNN image classifier, adds one example image to
// each of two classes (cat -> class 0, dog1 -> class 1), classifies the dog2
// image and writes the predicted class number into the result element.
window.onload = async () => {
  resultElement.innerText = 'Loading classifier...';

  // If dl isn't loaded, wait 1 second. `typeof` is required here: reading an
  // undeclared global directly (`dl == null`) throws a ReferenceError instead
  // of evaluating to undefined, so the guard would never get to wait.
  if (typeof dl === 'undefined' || dl == null) {
    await new Promise(resolve => setTimeout(resolve, 1000));
  }

  // Two classes, one nearest neighbor.
  const knn = new knn_image_classifier.KNNImageClassifier(2, 1);
  await knn.load();

  resultElement.innerText = 'Training...';

  const catPixels = dl.fromPixels(cat);
  const dog1Pixels = dl.fromPixels(dog1);
  const dog2Pixels = dl.fromPixels(dog2);

  try {
    knn.addImage(catPixels, 0);
    knn.addImage(dog1Pixels, 1);

    resultElement.innerText = 'Predicting...';

    const prediction = await knn.predictClass(dog2Pixels);

    // classIndex is 0-based; show it as a 1-based class number.
    resultElement.innerText =
        'Predicted to be class ' + (prediction.classIndex + 1);
  } finally {
    // Pixel tensors are created outside of any dl.tidy() scope, so they must
    // be released explicitly to avoid leaking memory.
    catPixels.dispose();
    dog1Pixels.dispose();
    dog2Pixels.dispose();
  }
};

</script>
Binary file added models/knn_image_classifier/images/cat.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added models/knn_image_classifier/images/dog1.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added models/knn_image_classifier/images/dog2.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
67 changes: 32 additions & 35 deletions models/knn_image_classifier/knn_image_classifier.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,23 @@
* =============================================================================
*/
// tslint:disable-next-line:max-line-length
import {Array1D, Array2D, Array3D, Model, NDArrayMath, Scalar} from 'deeplearn';
import * as dl from 'deeplearn';
import {Tensor1D, Tensor2D, Tensor3D} from 'deeplearn';
import {SqueezeNet} from 'deeplearn-squeezenet';
import * as model_util from '../util';

export class KNNImageClassifier implements Model {
export class KNNImageClassifier {
private squeezeNet: SqueezeNet;

// A concatenated matrix of all class logits matrices, lazily created and
// used during prediction.
private trainLogitsMatrix: Array2D;
private trainLogitsMatrix: Tensor2D;

private classLogitsMatrices: Array2D[] = [];
private classLogitsMatrices: Tensor2D[] = [];
private classExampleCount: number[] = [];

private varsLoaded = false;
private squashLogitsDenominator = Scalar.new(300);
private squashLogitsDenominator = dl.scalar(300);

/**
* Constructor for the class.
Expand All @@ -39,15 +40,13 @@ export class KNNImageClassifier implements Model {
* @param k The number of nearest neighbors to look at when predicting.
* @param math A math implementation for performing the calculations.
*/
constructor(
private numClasses: number, private k: number,
private math: NDArrayMath) {
constructor(private numClasses: number, private k: number) {
for (let i = 0; i < this.numClasses; i++) {
this.classLogitsMatrices.push(null);
this.classExampleCount.push(0);
}

this.squeezeNet = new SqueezeNet(this.math);
this.squeezeNet = new SqueezeNet();
}

/**
Expand Down Expand Up @@ -75,7 +74,7 @@ export class KNNImageClassifier implements Model {
/**
* Adds the provided image to the specified class.
*/
addImage(image: Array3D, classIndex: number): void {
addImage(image: Tensor3D, classIndex: number): void {
if (!this.varsLoaded) {
console.warn('Cannot add images until vars have been loaded.');
return;
Expand All @@ -85,7 +84,7 @@ export class KNNImageClassifier implements Model {
}
this.clearTrainLogitsMatrix();

this.math.scope((keep, track) => {
dl.tidy(() => {
// Add the squeezenet logits for the image to the appropriate class
// logits matrix.
const logits = this.squeezeNet.predict(image);
Expand All @@ -95,14 +94,16 @@ export class KNNImageClassifier implements Model {
if (this.classLogitsMatrices[classIndex] == null) {
this.classLogitsMatrices[classIndex] = imageLogits.as2D(1, logitsSize);
} else {
const newTrainLogitsMatrix = this.math.concat2D(
this.classLogitsMatrices[classIndex].as2D(
this.classExampleCount[classIndex], logitsSize),
imageLogits.as2D(1, logitsSize), 0);
const newTrainLogitsMatrix =
this.classLogitsMatrices[classIndex]
.as2D(this.classExampleCount[classIndex], logitsSize)
.concat(imageLogits.as2D(1, logitsSize), 0);

this.classLogitsMatrices[classIndex].dispose();
this.classLogitsMatrices[classIndex] = newTrainLogitsMatrix;
}
keep(this.classLogitsMatrices[classIndex]);

dl.keep(this.classLogitsMatrices[classIndex]);

this.classExampleCount[classIndex]++;
});
Expand All @@ -123,12 +124,12 @@ export class KNNImageClassifier implements Model {
* @param image The input image.
* @returns cosine distances for each entry in the database.
*/
predict(image: Array3D): Array1D {
predict(image: Tensor3D): Tensor1D {
if (!this.varsLoaded) {
throw new Error('Cannot predict until vars have been loaded.');
}

return this.math.scope((keep) => {
return dl.tidy(() => {
const logits = this.squeezeNet.predict(image);
const imageLogits = this.normalizeVector(logits);
const logitsSize = imageLogits.shape[0];
Expand All @@ -149,13 +150,11 @@ export class KNNImageClassifier implements Model {
return null;
}

keep(this.trainLogitsMatrix);
dl.keep(this.trainLogitsMatrix);

const numExamples = this.getNumExamples();
return this.math
.matMul(
this.trainLogitsMatrix.as2D(numExamples, logitsSize),
imageLogits.as2D(logitsSize, 1))
return this.trainLogitsMatrix.as2D(numExamples, logitsSize)
.matMul(imageLogits.as2D(logitsSize, 1))
.as1D();
});
}
Expand All @@ -168,7 +167,7 @@ export class KNNImageClassifier implements Model {
* @returns A dict of the top class for the image and an array of confidence
* values for all possible classes.
*/
async predictClass(image: Array3D):
async predictClass(image: Tensor3D):
Promise<{classIndex: number, confidences: number[]}> {
let imageClass = -1;
const confidences = new Array<number>(this.numClasses);
Expand All @@ -179,7 +178,7 @@ export class KNNImageClassifier implements Model {
const knn = this.predict(image).asType('float32');
const numExamples = this.getNumExamples();
const kVal = Math.min(this.k, numExamples);
const topK = model_util.topK(await knn.data(), kVal);
const topK = model_util.topK(await knn.data() as Float32Array, kVal);
knn.dispose();
const topKIndices = topK.indices;

Expand Down Expand Up @@ -236,32 +235,30 @@ export class KNNImageClassifier implements Model {
}
}

private concatWithNulls(ndarray1: Array2D, ndarray2: Array2D): Array2D {
private concatWithNulls(ndarray1: Tensor2D, ndarray2: Tensor2D): Tensor2D {
if (ndarray1 == null && ndarray2 == null) {
return null;
}
if (ndarray1 == null) {
return this.math.clone(ndarray2);
return ndarray2.clone();
} else if (ndarray2 === null) {
return this.math.clone(ndarray1);
return ndarray1.clone();
}
return this.math.concat2D(ndarray1, ndarray2, 0);
return ndarray1.concat(ndarray2, 0);
}

/**
* Normalize the provided vector to unit length.
*/
private normalizeVector(vec: Array1D) {
private normalizeVector(vec: Tensor1D) {
// This hack is here for numerical stability on devices without floating
// point textures. We divide by a constant so that the sum doesn't overflow
// our fixed point precision. Remove this once we use floating point
// intermediates with proper dynamic range quantization.
const squashedVec = this.math.divide(vec, this.squashLogitsDenominator);
const squashedVec = dl.div(vec, this.squashLogitsDenominator);
const sqrtSum = squashedVec.square().sum().sqrt();

const squared = this.math.multiplyStrict(squashedVec, squashedVec);
const sum = this.math.sum(squared);
const sqrtSum = this.math.sqrt(sum);
return this.math.divide(squashedVec, sqrtSum);
return dl.div(squashedVec, sqrtSum);
}

private getNumExamples() {
Expand Down
8 changes: 4 additions & 4 deletions models/knn_image_classifier/package.json
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
{
"name": "deeplearn-knn-image-classifier",
"version": "0.2.3",
"version": "0.3.0",
"description": "A KNN Image Classifier model in deeplearn.js",
"main": "dist/knn_image_classifier/index.js",
"unpkg": "dist/bundle.js",
"types": "dist/knn_image_classifier/index.d.ts",
"peerDependencies": {
"deeplearn": "0.4.0"
"deeplearn": "0.5.0"
},
"dependencies": {
"deeplearn-squeezenet": "~0.1.9"
"deeplearn-squeezenet": "~0.2.0"
},
"repository": {
"type": "git",
Expand All @@ -18,7 +18,7 @@
"devDependencies": {
"browserify": "~14.5.0",
"browserify-shim": "~3.8.14",
"deeplearn": "0.4.0",
"deeplearn": "0.5.0",
"mkdirp": "~0.5.1",
"tsify": "~3.0.3",
"tslint": "~5.8.0",
Expand Down
11 changes: 8 additions & 3 deletions models/knn_image_classifier/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -564,11 +564,12 @@ deeplearn-squeezenet@~0.1.9:
version "0.1.9"
resolved "https://registry.yarnpkg.com/deeplearn-squeezenet/-/deeplearn-squeezenet-0.1.9.tgz#a4194c31156cd2fa7421aaffcc6e3d4b3e3ef1c1"

deeplearn@0.4.0:
version "0.4.0"
resolved "https://registry.yarnpkg.com/deeplearn/-/deeplearn-0.4.0.tgz#3181789795b28c1c8d0311fb2209cea8dfe5ad67"
deeplearn@0.5.0:
version "0.5.0"
resolved "https://registry.yarnpkg.com/deeplearn/-/deeplearn-0.5.0.tgz#015c25fb35fcbf40cf9f31d9fff7760af7a436ec"
dependencies:
seedrandom "~2.4.3"
utf8 "~2.1.2"

defined@^1.0.0:
version "1.0.0"
Expand Down Expand Up @@ -1993,6 +1994,10 @@ url@~0.11.0:
punycode "1.3.2"
querystring "0.2.0"

utf8@~2.1.2:
version "2.1.2"
resolved "https://registry.yarnpkg.com/utf8/-/utf8-2.1.2.tgz#1fa0d9270e9be850d9b05027f63519bf46457d96"

util-deprecate@~1.0.1:
version "1.0.2"
resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf"
Expand Down
2 changes: 1 addition & 1 deletion models/squeezenet/demo.html
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

resultElement.innerText = 'Predicting...';

const pixels = dl.Array3D.fromPixels(cat);
const pixels = dl.fromPixels(cat);

const result = squeezeNet.predict(pixels);

Expand Down
6 changes: 3 additions & 3 deletions models/squeezenet/package.json
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
{
"name": "deeplearn-squeezenet",
"version": "0.1.9",
"version": "0.2.0",
"description": "Pretrained SqueezeNet model in deeplearn.js",
"main": "dist/squeezenet/index.js",
"unpkg": "dist/bundle.js",
"types": "dist/squeezenet/index.d.ts",
"peerDependencies": {
"deeplearn": "0.4.0"
"deeplearn": "0.5.0"
},
"repository": {
"type": "git",
Expand All @@ -15,7 +15,7 @@
"devDependencies": {
"browserify": "~14.5.0",
"browserify-shim": "~3.8.14",
"deeplearn": "0.4.0",
"deeplearn": "0.5.0",
"mkdirp": "~0.5.1",
"tsify": "~3.0.3",
"tslint": "~5.8.0",
Expand Down
Loading

0 comments on commit fa67bbc

Please sign in to comment.