diff --git a/src/io/browser_http.ts b/src/io/browser_http.ts index 9e667d486a..7fb32934c4 100644 --- a/src/io/browser_http.ts +++ b/src/io/browser_http.ts @@ -27,6 +27,9 @@ import {IORouter, IORouterRegistry} from './router_registry'; import {IOHandler, ModelArtifacts, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types'; import {loadWeightsAsArrayBuffer} from './weights_loader'; +const OCTET_STREAM_MIME_TYPE = 'application/octet-stream'; +const JSON_TYPE = 'application/json'; + export class BrowserHTTPRequest implements IOHandler { protected readonly path: string|string[]; protected readonly requestInit: RequestInit; @@ -108,14 +111,13 @@ export class BrowserHTTPRequest implements IOHandler { 'model.json', new Blob( [JSON.stringify(modelTopologyAndWeightManifest)], - {type: 'application/json'}), + {type: JSON_TYPE}), 'model.json'); if (modelArtifacts.weightData != null) { init.body.append( 'model.weights.bin', - new Blob( - [modelArtifacts.weightData], {type: 'application/octet-stream'}), + new Blob([modelArtifacts.weightData], {type: OCTET_STREAM_MIME_TYPE}), 'model.weights.bin'); } @@ -150,10 +152,7 @@ export class BrowserHTTPRequest implements IOHandler { * Loads the model topology file and build the in memory graph of the model. */ private async loadBinaryTopology(): Promise { - const response = await this.getFetchFunc()( - this.path[0], this.addAcceptHeader('application/octet-stream')); - this.verifyContentType( - response, 'model topology', 'application/octet-stream'); + const response = await this.getFetchFunc()(this.path[0], this.requestInit); if (!response.ok) { throw new Error(`Request to ${this.path[0]} failed with error: ${ @@ -162,30 +161,10 @@ export class BrowserHTTPRequest implements IOHandler { return await response.arrayBuffer(); } - private addAcceptHeader(mimeType: string): RequestInit { - const requestOptions = Object.assign({}, this.requestInit || {}); - const headers = Object.assign({}, requestOptions.headers || {}); - // tslint:disable-next-line:no-any - (headers as any)['Accept'] = mimeType; - requestOptions.headers = headers; - return requestOptions; - } - - private verifyContentType(response: Response, target: string, type: string) { - const contentType = response.headers.get('content-type'); - if (!contentType || contentType.indexOf(type) === -1) { - throw new Error(`Request to ${response.url} for ${ - target} failed. Expected content type ${type} but got ${ - contentType}.`); - } - } - protected async loadBinaryModel(): Promise { const graphPromise = this.loadBinaryTopology(); - const manifestPromise = await this.getFetchFunc()( - this.path[1], this.addAcceptHeader('application/json')); - this.verifyContentType( - manifestPromise, 'weights manifest', 'application/json'); + const manifestPromise = + await this.getFetchFunc()(this.path[1], this.requestInit); if (!manifestPromise.ok) { throw new Error(`Request to ${this.path[1]} failed with error: ${ manifestPromise.statusText}`); @@ -208,10 +187,8 @@ export class BrowserHTTPRequest implements IOHandler { } protected async loadJSONModel(): Promise { - const modelConfigRequest = await this.getFetchFunc()( - this.path as string, this.addAcceptHeader('application/json')); - this.verifyContentType( - modelConfigRequest, 'model topology', 'application/json'); + const modelConfigRequest = + await this.getFetchFunc()(this.path as string, this.requestInit); if (!modelConfigRequest.ok) { throw new Error(`Request to ${this.path} failed with error: ${ diff --git a/src/io/browser_http_test.ts b/src/io/browser_http_test.ts index 1f05aa7b68..bb46c5efa6 100644 --- a/src/io/browser_http_test.ts +++ b/src/io/browser_http_test.ts @@ -156,10 +156,6 @@ describeWithFlags('browserHTTPRequest-load fetch', NODE_ENVS, () => { expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); - expect(requestInits['./model.json'].headers['Accept']) - .toEqual('application/json'); - expect(requestInits['./weightfile0'].headers['Accept']) - .toEqual('application/octet-stream'); }); it('throw exception if no fetch polyfill', () => { @@ -487,10 +483,6 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => { expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); expect(Object.keys(requestInits).length).toEqual(2); - expect(requestInits['./model.json'].headers['Accept']) - .toEqual('application/json'); - expect(requestInits['./weightfile0'].headers['Accept']) - .toEqual('application/octet-stream'); // Assert that fetch is invoked with `window` as the context. expect(windowFetchSpy.calls.mostRecent().object).toEqual(window); }); @@ -533,10 +525,6 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => { expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); expect(Object.keys(requestInits).length).toEqual(2); - expect(requestInits['./model.json'].headers['Accept']) - .toEqual('application/json'); - expect(requestInits['./weightfile0'].headers['Accept']) - .toEqual('application/octet-stream'); expect(Object.keys(requestInits).length).toEqual(2); expect(requestInits['./model.json'].headers['header_key_1']) .toEqual('header_value_1'); @@ -765,29 +753,6 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => { }); }); - it('with wrong content type leads to error', async (done) => { - setupFakeWeightFiles( - { - 'path1/model.json': - {data: JSON.stringify({}), contentType: 'text/html'} - }, - requestInits); - const handler = tf.io.browserHTTPRequest('path1/model.json'); - handler.load() - .then(modelTopology1 => { - done.fail( - 'Loading with wrong content-type succeeded unexpectedly.'); - }) - .catch(err => { - expect(err.message) - .toEqual( - 'Request to path1/model.json for model topology failed. ' + - 'Expected content type application/json ' + - 'but got text/html.'); - done(); - }); - }); - it('with fetch rejection leads to error', async (done) => { setupFakeWeightFiles( { @@ -856,12 +821,6 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => { expect(new Float32Array(modelArtifacts.weightData)) .toEqual(floatData); expect(Object.keys(requestInits).length).toEqual(3); - expect(requestInits['./model.pb'].headers['Accept']) - .toEqual('application/octet-stream'); - expect(requestInits['./weights_manifest.json'].headers['Accept']) - .toEqual('application/json'); - expect(requestInits['./weightfile0'].headers['Accept']) - .toEqual('application/octet-stream'); done(); }) .catch(err => done.fail(err.stack)); @@ -906,13 +865,6 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => { expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); expect(Object.keys(requestInits).length).toEqual(3); - expect(requestInits['./model.pb?tfjs-format=file'].headers['Accept']) - .toEqual('application/octet-stream'); - expect(requestInits['./weights_manifest.json?tfjs-format=file'] - .headers['Accept']) - .toEqual('application/json'); - expect(requestInits['./weightfile0?tfjs-format=file'].headers['Accept']) - .toEqual('application/octet-stream'); }); it('1 group, 2 weights, 1 path, with requestInit', async () => { @@ -954,16 +906,10 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => { expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); expect(Object.keys(requestInits).length).toEqual(3); - expect(requestInits['./model.pb'].headers['Accept']) - .toEqual('application/octet-stream'); expect(requestInits['./model.pb'].headers['header_key_1']) .toEqual('header_value_1'); - expect(requestInits['./weights_manifest.json'].headers['Accept']) - .toEqual('application/json'); expect(requestInits['./weights_manifest.json'].headers['header_key_1']) .toEqual('header_value_1'); - expect(requestInits['./weightfile0'].headers['Accept']) - .toEqual('application/octet-stream'); expect(requestInits['./weightfile0'].headers['header_key_1']) .toEqual('header_value_1'); }); @@ -1159,161 +1105,6 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => { expect(() => tf.io.browserHTTPRequest(['path1/model.pb'])).toThrow(); }); - it('with wrong model content type leads to error', async (done) => { - const weightsManifest: tf.io.WeightsManifestConfig = [ - { - paths: ['weightfile0'], - weights: [{ - name: 'fooWeight', - shape: [3, 1], - dtype: 'int32', - }] - }, - { - paths: ['weightfile1'], - weights: [{ - name: 'barWeight', - shape: [2], - dtype: 'bool', - }], - } - ]; - const floatData1 = new Int32Array([1, 3, 3]); - const floatData2 = new Uint8Array([7, 4]); - setupFakeWeightFiles( - { - 'path1/model.pb': {data: modelData, contentType: 'text/html'}, - 'path2/weights_manifest.json': { - data: JSON.stringify(weightsManifest), - contentType: 'application/json' - }, - 'path3/weightfile0': - {data: floatData1, contentType: 'application/octet-stream'}, - 'path3/weightfile1': - {data: floatData2, contentType: 'application/octet-stream'}, - }, - requestInits); - const handler = tf.io.browserHTTPRequest( - ['path1/model.pb', 'path2/weights_manifest.json'], {}, 'path3/'); - handler.load() - .then(modelTopology1 => { - done.fail( - 'Loading with wrong content-type succeeded unexpectedly.'); - }) - .catch(err => { - expect(err.message) - .toEqual( - 'Request to path1/model.pb for model topology failed. ' + - 'Expected content type application/octet-stream ' + - 'but got text/html.'); - done(); - }); - }); - - it('with wrong manifest content type leads to error', async (done) => { - const weightsManifest: tf.io.WeightsManifestConfig = [ - { - paths: ['weightfile0'], - weights: [{ - name: 'fooWeight', - shape: [3, 1], - dtype: 'int32', - }] - }, - { - paths: ['weightfile1'], - weights: [{ - name: 'barWeight', - shape: [2], - dtype: 'bool', - }], - } - ]; - const floatData1 = new Int32Array([1, 3, 3]); - const floatData2 = new Uint8Array([7, 4]); - setupFakeWeightFiles( - { - 'path1/model.pb': - {data: modelData, contentType: 'application/octet-stream'}, - 'path2/weights_manifest.json': { - data: JSON.stringify(weightsManifest), - contentType: 'application/octet-stream' - }, - 'path3/weightfile0': - {data: floatData1, contentType: 'application/octet-stream'}, - 'path3/weightfile1': - {data: floatData2, contentType: 'application/octet-stream'}, - }, - requestInits); - const handler = tf.io.browserHTTPRequest( - ['path1/model.pb', 'path2/weights_manifest.json'], {}, 'path3/'); - handler.load() - .then(modelTopology1 => { - done.fail( - 'Loading with wrong content-type succeeded unexpectedly.'); - }) - .catch(err => { - expect(err.message) - .toEqual( - 'Request to path2/weights_manifest.json for weights ' + - 'manifest failed. Expected content type application/json' + - ' but got application/octet-stream.'); - done(); - }); - }); - - it('with wrong weight content type leads to error', async (done) => { - const weightsManifest: tf.io.WeightsManifestConfig = [ - { - paths: ['weightfile0'], - weights: [{ - name: 'fooWeight', - shape: [3, 1], - dtype: 'int32', - }] - }, - { - paths: ['weightfile1'], - weights: [{ - name: 'barWeight', - shape: [2], - dtype: 'bool', - }], - } - ]; - const floatData1 = new Int32Array([1, 3, 3]); - const floatData2 = new Uint8Array([7, 4]); - setupFakeWeightFiles( - { - 'path1/model.pb': - {data: modelData, contentType: 'application/octet-stream'}, - 'path2/weights_manifest.json': { - data: JSON.stringify(weightsManifest), - contentType: 'application/json' - }, - 'path3/weightfile0': - {data: floatData1, contentType: 'application/json'}, - 'path3/weightfile1': - {data: floatData2, contentType: 'application/octet-stream'}, - }, - requestInits); - const handler = tf.io.browserHTTPRequest( - ['path1/model.pb', 'path2/weights_manifest.json'], {}, 'path3/'); - handler.load() - .then(modelTopology1 => { - done.fail( - 'Loading with wrong content-type succeeded unexpectedly.'); - }) - .catch(err => { - expect(err.message) - .toEqual( - 'Request to path3/weightfile0 for weight file failed. ' + - 'Expected content type application/octet-stream but ' + - 'got application/json.'); - done(); - }); - }); - it('with fetch rejection leads to error', async (done) => { setupFakeWeightFiles( { diff --git a/src/io/weights_loader.ts b/src/io/weights_loader.ts index c5a2d5b79e..b10fbe0bd7 100644 --- a/src/io/weights_loader.ts +++ b/src/io/weights_loader.ts @@ -21,12 +21,6 @@ import * as util from '../util'; import {decodeWeights} from './io_utils'; import {DTYPE_VALUE_SIZE_MAP, WeightsManifestConfig, WeightsManifestEntry} from './types'; -type RequestHeader = { - [key: string]: string -}; - -const OCTET_STREAM_TYPE = 'application/octet-stream'; -const CONTENT_TYPE = 'Content-type'; /** * Reads binary weights data from a number of URLs. * @@ -45,12 +39,6 @@ export async function loadWeightsAsArrayBuffer( fetchFunc = fetch; } - // Add accept header - requestOptions = requestOptions || {}; - const headers = (requestOptions.headers || {}) as RequestHeader; - headers['Accept'] = OCTET_STREAM_TYPE; - requestOptions.headers = headers; - // Create the requests for all of the weights in parallel. const requests = fetchURLs.map(fetchURL => fetchFunc(fetchURL, requestOptions)); @@ -64,19 +52,6 @@ export async function loadWeightsAsArrayBuffer( } const responses = await Promise.all(requests); - const badContentType = responses.filter(response => { - const contentType = response.headers.get(CONTENT_TYPE); - return !contentType || contentType.indexOf(OCTET_STREAM_TYPE) === -1; - }); - if (badContentType.length > 0) { - throw new Error( - badContentType - .map( - resp => `Request to ${resp.url} for weight file failed.` + - ` Expected content type ${OCTET_STREAM_TYPE} but got ${ - resp.headers.get(CONTENT_TYPE)}.`) - .join('\n')); - } const bufferPromises = responses.map(response => response.arrayBuffer()); const bufferStartFraction = 0.5; diff --git a/src/io/weights_loader_test.ts b/src/io/weights_loader_test.ts index 659378ed21..f0c10dbc5a 100644 --- a/src/io/weights_loader_test.ts +++ b/src/io/weights_loader_test.ts @@ -471,8 +471,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { .then(weights => { expect((window.fetch as jasmine.Spy).calls.count()).toBe(1); expect(window.fetch).toHaveBeenCalledWith('./weightfile0', { - credentials: 'include', - headers: {'Accept': 'application/octet-stream'} + credentials: 'include' }); }) .then(done)