Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transform url fix #858

Merged
merged 2 commits into from
May 31, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 18 additions & 21 deletions dist/webdnn.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -495,17 +495,21 @@ declare module 'webdnn/descriptor_runner/descriptor_runner' {
import PlaceholderContext from 'webdnn/placeholder';
import SymbolicFloat32Array from 'webdnn/symbolic_typed_array/symbolic_float32array';
import { BackendName } from 'webdnn/webdnn';
export interface DescriptorRunnerOptions {
transformUrlDelegate?: (base: string) => string;
}
/**
* @protected
*/
export interface DescriptorRunnerConstructor<D extends GraphDescriptor, P> {
new (option?: any): DescriptorRunner<D, P>;
new (option: DescriptorRunnerOptions): DescriptorRunner<D, P>;
checkAvailability(): boolean;
}
/**
* `DescriptorRunner` provides interface to execute DNN model and access input and output buffers.
*/
export abstract class DescriptorRunner<D extends GraphDescriptor, P> {
constructor(option?: DescriptorRunnerOptions);
/**
* For Developer:
*
Expand Down Expand Up @@ -539,6 +543,7 @@ declare module 'webdnn/descriptor_runner/descriptor_runner' {
* The backend name
*/
readonly backendName: BackendName;
readonly transformUrlDelegate: (base: string) => string;
/**
* The descriptor
*/
Expand Down Expand Up @@ -735,28 +740,17 @@ declare module 'webdnn/fetch' {
ignoreCache: boolean;
progressCallback?: (loaded: number, total: number) => any;
}
/**
* Transform url generated based on current active backend
* @param url transformed url
* @protected
*/
export function transformUrl(url: string): string;
/**
* Register delegate function for transform url.
* @param delegate Delegate function which will be called with original url, and must return converted url strings.
* @protected
*/
export function registerTransformUrlDelegate(delegate: (base: string) => string): void;
/**
* Fetch function. WebDNN API use this function instead of original `fetch` function.
* FIXME
* @param input Requested url
* @param init Additional information about webdnnFetch
* @param init.ignoreCache If true, cache is ignored by appending '?t=(timestamp)' to the end of request url.
* @param transformUrlDelegate url transform function
* @param init? Additional information about webdnnFetch
* @param init?.ignoreCache If true, cache is ignored by appending '?t=(timestamp)' to the end of request url.
* @returns Response
* @protected
*/
export default function webdnnFetch(input: RequestInfo, init?: WebDNNRequestInit): Promise<any>;
export default function webdnnFetch(input: RequestInfo, transformUrlDelegate: (base: string) => string, init?: WebDNNRequestInit): Promise<any>;
/**
* Read `Response.body` stream as ArrayBuffer. This function provide progress information by callback.
* @param res Response object
Expand Down Expand Up @@ -796,7 +790,7 @@ declare module 'webdnn/descriptor_runner/descriptor_runner_fallback' {
import { GraphDescriptorFallback } from 'webdnn/graph_descriptor/graph_descriptor_fallback';
import SymbolicFloat32Array from 'webdnn/symbolic_typed_array/symbolic_float32array';
import { BackendName } from 'webdnn/webdnn';
import { DescriptorRunner } from 'webdnn/descriptor_runner/descriptor_runner';
import { DescriptorRunner, DescriptorRunnerOptions } from 'webdnn/descriptor_runner/descriptor_runner';
/**
* @protected
*/
Expand All @@ -808,6 +802,7 @@ declare module 'webdnn/descriptor_runner/descriptor_runner_fallback' {
private dynamicBuffer;
private directory;
static checkAvailability(): boolean;
constructor(options?: DescriptorRunnerOptions);
init(): Promise<void>;
setDescriptorAndParameters(descriptor: GraphDescriptorFallback, parameters: ArrayBuffer): Promise<void>;
fetchDescriptor(directory: string): Promise<any>;
Expand Down Expand Up @@ -861,7 +856,7 @@ declare module 'webdnn/descriptor_runner/descriptor_runner_webassembly' {
import { GraphDescriptorWebassembly } from 'webdnn/graph_descriptor/graph_descriptor_webassembly';
import SymbolicFloat32Array from 'webdnn/symbolic_typed_array/symbolic_float32array';
import { BackendName } from 'webdnn/webdnn';
import { DescriptorRunner } from 'webdnn/descriptor_runner/descriptor_runner';
import { DescriptorRunner, DescriptorRunnerOptions } from 'webdnn/descriptor_runner/descriptor_runner';
/**
* @protected
*/
Expand All @@ -873,7 +868,7 @@ declare module 'webdnn/descriptor_runner/descriptor_runner_webassembly' {
private worker_initial_error;
private directory;
static checkAvailability(): boolean;
constructor();
constructor(options?: DescriptorRunnerOptions);
init(): Promise<void>;
private absolutePath(path);
setDescriptorAndParameters(descriptor: GraphDescriptorWebassembly, parameters: ArrayBuffer): Promise<void>;
Expand Down Expand Up @@ -1260,7 +1255,7 @@ declare module 'webdnn/descriptor_runner/descriptor_runner_webgl' {
import { GraphDescriptorWebGL } from 'webdnn/graph_descriptor/graph_descriptor_webgl';
import SymbolicFloat32Array from 'webdnn/symbolic_typed_array/symbolic_float32array';
import { BackendName } from 'webdnn/webdnn';
import { DescriptorRunner } from 'webdnn/descriptor_runner/descriptor_runner';
import { DescriptorRunner, DescriptorRunnerOptions } from 'webdnn/descriptor_runner/descriptor_runner';
/**
* @protected
*/
Expand All @@ -1272,6 +1267,7 @@ declare module 'webdnn/descriptor_runner/descriptor_runner_webgl' {
private programs;
private buffers;
static checkAvailability(): boolean;
constructor(options?: DescriptorRunnerOptions);
init(): Promise<void>;
fetchDescriptor(directory: string): Promise<any>;
fetchParameters(directory: string, progressCallback?: (loaded: number, total: number) => any): Promise<ArrayBuffer>;
Expand Down Expand Up @@ -1395,12 +1391,13 @@ declare module 'webdnn/descriptor_runner/descriptor_runner_webgpu' {
import { GraphDescriptorWebGPU } from 'webdnn/graph_descriptor/graph_descriptor_webgpu';
import SymbolicFloat32Array from 'webdnn/symbolic_typed_array/symbolic_float32array';
import { BackendName } from 'webdnn/webdnn';
import { DescriptorRunner } from 'webdnn/descriptor_runner/descriptor_runner';
import { DescriptorRunner, DescriptorRunnerOptions } from 'webdnn/descriptor_runner/descriptor_runner';
/**
* DescriptorRunner for WebGPU
* @protected
*/
export default class DescriptorRunnerWebGPU extends DescriptorRunner<GraphDescriptorWebGPU, ArrayBuffer> {
constructor(options?: DescriptorRunnerOptions);
/**
* backend name
*/
Expand Down
4 changes: 2 additions & 2 deletions dist/webdnn.es5.js

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dist/webdnn.es5.js.map

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions dist/webdnn.js

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dist/webdnn.js.map

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"name": "webdnn",
"version": "1.2.5",
"description": "Deep Neural Network Execution Framework for Web Browsers",
"main": "src/descriptor_runner/webdnn.js",
"main": "dist/webdnn.js",
"types": "dist/webdnn.d.ts",
"directories": {
"doc": "docs",
Expand Down
16 changes: 15 additions & 1 deletion src/descriptor_runner/descriptor_runner/descriptor_runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,16 @@ import PlaceholderContext from "../placeholder";
import SymbolicFloat32Array from "../symbolic_typed_array/symbolic_float32array";
import { BackendName } from "../webdnn";

export interface DescriptorRunnerOptions {
transformUrlDelegate? : (base: string) => string
}


/**
* @protected
*/
export interface DescriptorRunnerConstructor<D extends GraphDescriptor, P> {
new(option?: any): DescriptorRunner<D, P>
new(option: DescriptorRunnerOptions): DescriptorRunner<D, P>

checkAvailability(): boolean;
}
Expand All @@ -21,6 +26,13 @@ export interface DescriptorRunnerConstructor<D extends GraphDescriptor, P> {
* `DescriptorRunner` provides interface to execute DNN model and access input and output buffers.
*/
export abstract class DescriptorRunner<D extends GraphDescriptor, P> {

constructor(option : DescriptorRunnerOptions = {}) {
let {
transformUrlDelegate = function(url){return url;}
} = option;
this.transformUrlDelegate = transformUrlDelegate;
}
/**
* For Developer:
*
Expand Down Expand Up @@ -56,6 +68,8 @@ export abstract class DescriptorRunner<D extends GraphDescriptor, P> {
*/
readonly backendName: BackendName;

readonly transformUrlDelegate: (base : string) => string;

/**
* The descriptor
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

import * as localforage from "localforage";
import get_weight_decoder from "../decoder/get_weight_decoder";
import webdnnFetch, { readArrayBufferProgressively, transformUrl } from "../fetch"
import webdnnFetch, { readArrayBufferProgressively } from "../fetch"
import { GraphDescriptorFallback } from "../graph_descriptor/graph_descriptor_fallback";
import { Allocation, ResolvedAllocation } from "../graph_descriptor/memory_layout";
import PlaceholderContext from "../placeholder";
import SymbolicFloat32Array from "../symbolic_typed_array/symbolic_float32array";
import { BackendName } from "../webdnn";
import { DescriptorRunner } from "./descriptor_runner";
import { DescriptorRunner, DescriptorRunnerOptions } from "./descriptor_runner";

/**
* @private
Expand All @@ -37,6 +37,10 @@ export default class DescriptorRunnerFallback extends DescriptorRunner<GraphDesc
return true;
}

constructor(options: DescriptorRunnerOptions = {}) {
super(options);
}

async init(): Promise<void> {
//nothing to do
}
Expand All @@ -51,12 +55,12 @@ export default class DescriptorRunnerFallback extends DescriptorRunner<GraphDesc

async fetchDescriptor(directory: string) {
this.directory = directory;
let res = await webdnnFetch(`${directory}/graph_${this.backendName}.json`);
let res = await webdnnFetch(`${directory}/graph_${this.backendName}.json`, this.transformUrlDelegate);
return res.json();
}

async fetchParameters(directory: string, progressCallback?: (loaded: number, total: number) => any) {
let res = await webdnnFetch(`${directory}/weight_${this.backendName}.bin`);
let res = await webdnnFetch(`${directory}/weight_${this.backendName}.bin`, this.transformUrlDelegate);
return readArrayBufferProgressively(res, progressCallback);
}

Expand Down Expand Up @@ -118,7 +122,7 @@ export default class DescriptorRunnerFallback extends DescriptorRunner<GraphDesc
script.onload = resolve;
}

script.src = transformUrl(`${this.directory}/kernels_fallback.js`);
script.src = this.transformUrlDelegate(`${this.directory}/kernels_fallback.js`);
document.getElementsByTagName("head")[0].appendChild(script);
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

import * as localforage from "localforage";
import get_weight_decoder from "../decoder/get_weight_decoder";
import webDNNFetch, { readArrayBufferProgressively, transformUrl } from "../fetch";
import webDNNFetch, { readArrayBufferProgressively } from "../fetch";
import { GraphDescriptorWebassembly } from "../graph_descriptor/graph_descriptor_webassembly";
import PlaceholderContext from "../placeholder";
import SymbolicFloat32Array from "../symbolic_typed_array/symbolic_float32array";
import { BackendName } from "../webdnn";
import { DescriptorRunner } from "./descriptor_runner";
import { DescriptorRunner, DescriptorRunnerOptions } from "./descriptor_runner";

/**
* @private
Expand All @@ -33,8 +33,8 @@ export default class DescriptorRunnerWebassembly extends DescriptorRunner<GraphD
return 'Worker' in window;
}

constructor() {
super();
constructor(options: DescriptorRunnerOptions = {}) {
super(options);
if (typeof Worker === 'undefined') throw new Error('WebWorker is needed for WebAssembly backend');
if (typeof WebAssembly !== 'object') {
console.warn('WebAssembly is not supported on this browser, trying to use asm.js code');
Expand Down Expand Up @@ -69,7 +69,7 @@ export default class DescriptorRunnerWebassembly extends DescriptorRunner<GraphD
kernel_backend = 'asmjs';
}
let worker_entry_js_path = `${this.directory}/kernels_${kernel_backend}.js`;
worker_entry_js_path = transformUrl(worker_entry_js_path);
worker_entry_js_path = this.transformUrlDelegate(worker_entry_js_path);
this.worker_entry_js_path = worker_entry_js_path;

let worker_src_fetch = await fetch(this.worker_entry_js_path);
Expand All @@ -82,7 +82,7 @@ export default class DescriptorRunnerWebassembly extends DescriptorRunner<GraphD
*/
let map_aux_file_src = (basename, key) => {
let file_abs = this.absolutePath(`${this.directory}/${basename}`);
let file_abs_transformed = transformUrl(file_abs);// absolute path is given
let file_abs_transformed = this.transformUrlDelegate(file_abs);// absolute path is given
worker_src = worker_src.replace(key, file_abs_transformed);
}

Expand Down Expand Up @@ -126,7 +126,7 @@ export default class DescriptorRunnerWebassembly extends DescriptorRunner<GraphD
*/
async fetchDescriptor(directory: string): Promise<GraphDescriptorWebassembly> {
this.directory = directory;
let res = await webDNNFetch(`${directory}/graph_${this.backendName}.json`);
let res = await webDNNFetch(`${directory}/graph_${this.backendName}.json`, this.transformUrlDelegate);
return res.json();
}

Expand All @@ -147,7 +147,7 @@ export default class DescriptorRunnerWebassembly extends DescriptorRunner<GraphD
*/
async fetchParameters(directory: string, progressCallback?: (loaded: number, total: number) => any): Promise<ArrayBuffer> {
let weight_url = `${directory}/weight_${this.backendName}.bin`;
let weight_fetch = await webDNNFetch(weight_url);
let weight_fetch = await webDNNFetch(weight_url, this.transformUrlDelegate);
return readArrayBufferProgressively(weight_fetch, progressCallback);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import PlaceholderContext from "../placeholder";
import SymbolicFloat32Array from "../symbolic_typed_array/symbolic_float32array";
import { BackendName, getConfiguration } from "../webdnn";
import WebGLHandler from "../webgl_handler";
import { DescriptorRunner } from "./descriptor_runner";
import {DescriptorRunner, DescriptorRunnerOptions} from "./descriptor_runner";

/**
* @protected
Expand Down Expand Up @@ -72,6 +72,10 @@ export default class DescriptorRunnerWebGL extends DescriptorRunner<GraphDescrip
return WebGLHandler.checkAvailability();
}

constructor(options: DescriptorRunnerOptions = {}) {
super(options);
}

async init() {
if (!DescriptorRunnerWebGL.checkAvailability()) throw Error('WebGL backend is not supported in this browser.');

Expand All @@ -84,12 +88,12 @@ export default class DescriptorRunnerWebGL extends DescriptorRunner<GraphDescrip
}

async fetchDescriptor(directory: string) {
let res = await webdnnFetch(`${directory}/graph_${this.backendName}_${this.handler.MAX_TEXTURE_SIZE}.json`);
let res = await webdnnFetch(`${directory}/graph_${this.backendName}_${this.handler.MAX_TEXTURE_SIZE}.json`, this.transformUrlDelegate);
return res.json();
}

async fetchParameters(directory: string, progressCallback?: (loaded: number, total: number) => any) {
let res = await webdnnFetch(`${directory}/weight_${this.backendName}_${this.handler.MAX_TEXTURE_SIZE}.bin`);
let res = await webdnnFetch(`${directory}/weight_${this.backendName}_${this.handler.MAX_TEXTURE_SIZE}.bin`, this.transformUrlDelegate);
return readArrayBufferProgressively(res, progressCallback);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import PlaceholderContext from "../placeholder";
import SymbolicFloat32Array from "../symbolic_typed_array/symbolic_float32array";
import { BackendName, getConfiguration } from "../webdnn";
import WebGPUHandler, { IS_WEBGPU_SUPPORTED } from "../webgpu_handler";
import { DescriptorRunner } from "./descriptor_runner";
import {DescriptorRunner, DescriptorRunnerOptions} from "./descriptor_runner";

/**
* Check this device is iOS devices or not.
Expand All @@ -25,6 +25,12 @@ const IS_IOS = navigator.userAgent.includes('iPhone') || navigator.userAgent.inc
* @protected
*/
export default class DescriptorRunnerWebGPU extends DescriptorRunner<GraphDescriptorWebGPU, ArrayBuffer> {


constructor(options: DescriptorRunnerOptions = {}) {
super(options);
}

/**
* backend name
*/
Expand Down Expand Up @@ -129,7 +135,7 @@ using namespace metal;
* @protected
*/
async fetchDescriptor(directory: string): Promise<GraphDescriptorWebGPU> {
let res = await webdnnFetch(`${directory}/graph_${this.backendName}.json`);
let res = await webdnnFetch(`${directory}/graph_${this.backendName}.json`, this.transformUrlDelegate);
return res.json();
}

Expand All @@ -149,7 +155,7 @@ using namespace metal;
* @protected
*/
async fetchParameters(directory: string, progressCallback?: (loaded: number, total: number) => any): Promise<ArrayBuffer> {
let res = await webdnnFetch(`${directory}/weight_${this.backendName}.bin`);
let res = await webdnnFetch(`${directory}/weight_${this.backendName}.bin`, this.transformUrlDelegate);
return readArrayBufferProgressively(res, progressCallback);
}

Expand Down
Loading