Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement loadGraphModelSync #6428

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 58 additions & 17 deletions tfjs-converter/src/executor/graph_model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

import {InferenceModel, io, ModelPredictConfig, NamedTensorMap, Tensor} from '@tensorflow/tfjs-core';
import {InferenceModel, io, ModelPredictConfig, NamedTensorMap, Tensor, util} from '@tensorflow/tfjs-core';

import * as tensorflow from '../data/compiled_api';
import {NamedTensorsMap, TensorInfo} from '../data/types';
Expand All @@ -26,6 +26,9 @@ import {ResourceManager} from './resource_manager';

export const TFHUB_SEARCH_PARAM = '?tfjs-format=file';
export const DEFAULT_MODEL_NAME = 'model.json';
type Url = string | io.IOHandler | io.IOHandlerSync;
type UrlIOHandler<T extends Url> = T extends string ? io.IOHandler : T;

/**
* A `tf.GraphModel` is a directed, acyclic graph built from a
* SavedModel GraphDef and allows inference execution.
Expand All @@ -36,10 +39,11 @@ export const DEFAULT_MODEL_NAME = 'model.json';
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
export class GraphModel implements InferenceModel {
export class GraphModel<ModelURL extends Url = string|io.IOHandler>
implements InferenceModel {
private executor: GraphExecutor;
private version = 'n/a';
private handler: io.IOHandler;
private handler: UrlIOHandler<ModelURL>;
private artifacts: io.ModelArtifacts;
private initializer: GraphExecutor;
private resourceManager: ResourceManager;
Expand Down Expand Up @@ -88,7 +92,7 @@ export class GraphModel implements InferenceModel {
* before the load is completed.
*/
constructor(
private modelUrl: string|io.IOHandler,
private modelUrl: ModelURL,
private loadOptions: io.LoadOptions = {}) {
if (loadOptions == null) {
this.loadOptions = {};
Expand All @@ -97,12 +101,14 @@ export class GraphModel implements InferenceModel {
}

private findIOHandler() {
type IOHandler = UrlIOHandler<ModelURL>;
const path = this.modelUrl;
if ((path as io.IOHandler).load != null) {
// Path is an IO Handler.
this.handler = path as io.IOHandler;
this.handler = path as IOHandler;
} else if (this.loadOptions.requestInit != null) {
this.handler = io.browserHTTPRequest(path as string, this.loadOptions);
this.handler = io.browserHTTPRequest(path as string, this.loadOptions) as
IOHandler;
} else {
const handlers = io.getLoadHandlers(path as string, this.loadOptions);
if (handlers.length === 0) {
Expand All @@ -114,24 +120,33 @@ export class GraphModel implements InferenceModel {
`Found more than one (${handlers.length}) load handlers for ` +
`URL '${[path]}'`);
}
this.handler = handlers[0];
this.handler = handlers[0] as IOHandler;
}
}

/**
* Loads the model and weight files, construct the in memory weight map and
* compile the inference graph.
*/
async load(): Promise<boolean> {
load(): UrlIOHandler<ModelURL> extends io.IOHandlerSync ? boolean
: Promise<boolean> {
type IOHandler = UrlIOHandler<ModelURL>;
this.findIOHandler();
if (this.handler.load == null) {
throw new Error(
'Cannot proceed with model loading because the IOHandler provided ' +
'does not have the `load` method implemented.');
}
const artifacts = await this.handler.load();

return this.loadSync(artifacts);
type Result = IOHandler extends io.IOHandlerSync ? boolean
: Promise<boolean>;

const loadResult = this.handler.load() as ReturnType<IOHandler['load']>;
if (util.isPromise(loadResult)) {
return loadResult.then(artifacts => this.loadSync(artifacts)) as Result;
}

return this.loadSync(loadResult) as Result;
}

/**
Expand Down Expand Up @@ -448,15 +463,41 @@ export async function loadGraphModel(
options = {};
}

if (options.fromTFHub) {
if ((modelUrl as io.IOHandler).load == null) {
if (!(modelUrl as string).endsWith('/')) {
modelUrl = (modelUrl as string) + '/';
}
modelUrl = `${modelUrl}${DEFAULT_MODEL_NAME}${TFHUB_SEARCH_PARAM}`;
}
if (options.fromTFHub && typeof modelUrl === 'string') {
modelUrl = getTFHubUrl(modelUrl);
}
const model = new GraphModel(modelUrl, options);
await model.load();
return model;
}

/**
* Load a graph model given a synchronous IO handler with a 'load' method.
*
* @param modelSource The `io.IOHandlerSync` that loads the model.
*
* @doc {heading: 'Models', subheading: 'Loading'}
*/

export function loadGraphModelSync(
modelSource: io.IOHandlerSync): GraphModel<io.IOHandlerSync> {
if (modelSource == null) {
throw new Error(
'modelUrl in loadGraphModelSync() cannot be null. Please provide a ' +
'url or an IOHandler that loads the model');
}
if (!modelSource.load) {
throw new Error(`modelUrl IO Handler ${modelSource} has no load function`);
}
const model = new GraphModel(modelSource);

model.load();
return model;
}

function getTFHubUrl(modelUrl: string): string {
if (!modelUrl.endsWith('/')) {
modelUrl = (modelUrl) + '/';
}
return `${modelUrl}${DEFAULT_MODEL_NAME}${TFHUB_SEARCH_PARAM}`;
}
32 changes: 31 additions & 1 deletion tfjs-converter/src/executor/graph_model_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import * as tensorflow from '../data/compiled_api';
import {deregisterOp, registerOp} from '../operations/custom_op/register';
import {GraphNode} from '../operations/types';

import {GraphModel, loadGraphModel} from './graph_model';
import {GraphModel, loadGraphModel, loadGraphModelSync} from './graph_model';

const HOST = 'http://example.org';
const MODEL_URL = `${HOST}/model.json`;
Expand Down Expand Up @@ -405,6 +405,36 @@ describe('loadGraphModel', () => {
});
});

describe('loadGraphModelSync', () => {
it('Pass a custom io handler', () => {
const customLoader: tfc.io.IOHandlerSync = {
load: () => {
return {
modelTopology: SIMPLE_MODEL,
weightSpecs: weightsManifest,
weightData: new Int32Array([5]).buffer,
};
}
};
const model = loadGraphModelSync(customLoader);
expect(model).toBeDefined();
const bias = model.weights['Const'][0];
expect(bias.dtype).toBe('int32');
expect(bias.dataSync()).toEqual(new Int32Array([5]));
});

it('Expect an error when moderUrl is null', () => {
let errorMsg = 'no error';
try {
loadGraphModelSync(null);
} catch (err) {
errorMsg = err.message;
}
expect(errorMsg)
.toMatch(/modelUrl in loadGraphModelSync\(\) cannot be null/);
});
});

describe('Model', () => {
beforeEach(() => {
model = new GraphModel(MODEL_URL);
Expand Down
2 changes: 1 addition & 1 deletion tfjs-converter/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import './flags';

export {IAttrValue, INameAttrList, INodeDef, ITensor, ITensorShape} from './data/compiled_api';
export {GraphModel, loadGraphModel} from './executor/graph_model';
export {GraphModel, loadGraphModel, loadGraphModelSync} from './executor/graph_model';
export {deregisterOp, registerOp} from './operations/custom_op/register';
export {GraphNode, OpExecutor} from './operations/types';
export {version as version_converter} from './version';
36 changes: 33 additions & 3 deletions tfjs-converter/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,11 @@
estree-walker "^1.0.1"
picomatch "^2.2.2"

"@tensorflow/tfjs-backend-cpu@link:../link-package-core/node_modules/@tensorflow/tfjs-backend-cpu":
"@tensorflow/tfjs-backend-cpu@link:../link-package/node_modules/@tensorflow/tfjs-backend-cpu":
version "0.0.0"
uid ""

"@tensorflow/tfjs-core@link:../link-package-core/node_modules/@tensorflow/tfjs-core":
"@tensorflow/tfjs-core@link:../link-package/node_modules/@tensorflow/tfjs-core":
version "0.0.0"
uid ""

Expand Down Expand Up @@ -214,6 +214,11 @@
resolved "https://registry.yarnpkg.com/@types/long/-/long-4.0.1.tgz#459c65fa1867dafe6a8f322c4c51695663cc55e9"
integrity sha512-5tXH6Bx/kNGd3MgffdmP4dy2Z+G4eaXw0SE81Tq3BNadtnMR5/ySMzX4SLEzHJzSmPNn4HIdpQsBvXMUykr58w==

"@types/long@^4.0.1":
version "4.0.2"
resolved "https://registry.yarnpkg.com/@types/long/-/long-4.0.2.tgz#b74129719fc8d11c01868010082d483b7545591a"
integrity sha512-MqTGEo5bj5t157U6fA/BiDynNkn0YknVdh48CMPkTSpFTVmvao5UQmm7uEF6xBEo7qIMAlY/JSleYaE6VOdpaA==

"@types/long@~3.0.32":
version "3.0.32"
resolved "https://registry.yarnpkg.com/@types/long/-/long-3.0.32.tgz#f4e5af31e9e9b196d8e5fca8a5e2e20aa3d60b69"
Expand Down Expand Up @@ -241,13 +246,33 @@
resolved "https://registry.yarnpkg.com/@types/node/-/node-10.17.55.tgz#a147f282edec679b894d4694edb5abeb595fecbd"
integrity sha512-koZJ89uLZufDvToeWO5BrC4CR4OUfHnUz2qoPs/daQH6qq3IN62QFxCTZ+bKaCE0xaoCAJYE4AXre8AbghCrhg==

"@types/offscreencanvas@~2019.3.0":
version "2019.3.0"
resolved "https://registry.yarnpkg.com/@types/offscreencanvas/-/offscreencanvas-2019.3.0.tgz#3336428ec7e9180cf4566dfea5da04eb586a6553"
integrity sha512-esIJx9bQg+QYF0ra8GnvfianIY8qWB0GBx54PK5Eps6m+xTj86KLavHv6qDhzKcu5UUOgNfJ2pWaIIV7TRUd9Q==

"@types/resolve@0.0.8":
version "0.0.8"
resolved "https://registry.yarnpkg.com/@types/resolve/-/resolve-0.0.8.tgz#f26074d238e02659e323ce1a13d041eee280e194"
integrity sha512-auApPaJf3NPfe18hSoJkp8EbZzer2ISk7o8mCC3M9he/a04+gbMF97NkpD2S8riMGvm4BMRI59/SZQSaLTKpsQ==
dependencies:
"@types/node" "*"

"@types/seedrandom@2.4.27":
version "2.4.27"
resolved "https://registry.yarnpkg.com/@types/seedrandom/-/seedrandom-2.4.27.tgz#9db563937dd86915f69092bc43259d2f48578e41"
integrity sha512-YvMLqFak/7rt//lPBtEHv3M4sRNA+HGxrhFZ+DQs9K2IkYJbNwVIb8avtJfhDiuaUBX/AW0jnjv48FV8h3u9bQ==

"@types/webgl-ext@0.0.30":
version "0.0.30"
resolved "https://registry.yarnpkg.com/@types/webgl-ext/-/webgl-ext-0.0.30.tgz#0ce498c16a41a23d15289e0b844d945b25f0fb9d"
integrity sha512-LKVgNmBxN0BbljJrVUwkxwRYqzsAEPcZOe6S2T6ZaBDIrFp0qu4FNlpc5sM1tGbXUYFgdVQIoeLk1Y1UoblyEg==

"@webgpu/types@^0.1.16":
version "0.1.17"
resolved "https://registry.yarnpkg.com/@webgpu/types/-/types-0.1.17.tgz#91e8ec9fd6a1e63945ef12bff11394949ea1a583"
integrity sha512-M8INbXsMdkWtVsSHRPEiTXHe0S4gxMhYA/Kz4pNoUF9IXd3PHMi6/2n8EAsqkAEdna+aeCm2RmscWV0hsmIf0Q==

ajv@~6.12.3:
version "6.12.3"
resolved "https://registry.yarnpkg.com/ajv/-/ajv-6.12.3.tgz#18c5af38a111ddeb4f2697bd78d68abc1cabd706"
Expand Down Expand Up @@ -1489,7 +1514,7 @@ lodash@^4.17.4:
resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.21.tgz#679591c564c3bffaae8454cf0b3df370c3d6911c"
integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==

long@^4.0.0:
long@4.0.0, long@^4.0.0:
version "4.0.0"
resolved "https://registry.yarnpkg.com/long/-/long-4.0.0.tgz#9a7b71cfb7d361a194ea555241c92f7468d5bf28"
integrity sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA==
Expand Down Expand Up @@ -1902,6 +1927,11 @@ safe-buffer@~5.1.0, safe-buffer@~5.1.1:
resolved "https://registry.yarnpkg.com/safe-buffer/-/safe-buffer-5.1.2.tgz#991ec69d296e0313747d59bdfd2b745c35f8828d"
integrity sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==

seedrandom@2.4.3:
version "2.4.3"
resolved "https://registry.yarnpkg.com/seedrandom/-/seedrandom-2.4.3.tgz#2438504dad33917314bff18ac4d794f16d6aaecc"
integrity sha1-JDhQTa0zkXMUv/GKxNeU8W1qrsw=

semver@^5.3.0:
version "5.7.1"
resolved "https://registry.yarnpkg.com/semver/-/semver-5.7.1.tgz#a954f931aeba508d307bbf069eff0c01c96116f7"
Expand Down
3 changes: 2 additions & 1 deletion tfjs-core/src/io/io.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import {browserHTTPRequest, http, isHTTPScheme} from './http';
import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsForJSON, getModelArtifactsInfoForJSON} from './io_utils';
import {fromMemory, fromMemorySync, withSaveHandler, withSaveHandlerSync} from './passthrough';
import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry';
import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelArtifactsInfo, ModelJSON, ModelStoreManager, OnProgressCallback, RequestDetails, SaveConfig, SaveHandler, SaveResult, TrainingConfig, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {IOHandler, IOHandlerSync, LoadHandler, LoadOptions, ModelArtifacts, ModelArtifactsInfo, ModelJSON, ModelStoreManager, OnProgressCallback, RequestDetails, SaveConfig, SaveHandler, SaveResult, TrainingConfig, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {loadWeights, weightsLoaderFactory} from './weights_loader';

export {copyModel, listModels, moveModel, removeModel} from './model_management';
Expand All @@ -43,6 +43,7 @@ export {
getSaveHandlers,
http,
IOHandler,
IOHandlerSync,
isHTTPScheme,
LoadHandler,
LoadOptions,
Expand Down