Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Cherry pick to 0.15.x: remove content-type checks #1541

Merged
merged 1 commit into from
Feb 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 10 additions & 33 deletions src/io/browser_http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ import {IORouter, IORouterRegistry} from './router_registry';
import {IOHandler, ModelArtifacts, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {loadWeightsAsArrayBuffer} from './weights_loader';

const OCTET_STREAM_MIME_TYPE = 'application/octet-stream';
const JSON_TYPE = 'application/json';

export class BrowserHTTPRequest implements IOHandler {
protected readonly path: string|string[];
protected readonly requestInit: RequestInit;
Expand Down Expand Up @@ -108,14 +111,13 @@ export class BrowserHTTPRequest implements IOHandler {
'model.json',
new Blob(
[JSON.stringify(modelTopologyAndWeightManifest)],
{type: 'application/json'}),
{type: JSON_TYPE}),
'model.json');

if (modelArtifacts.weightData != null) {
init.body.append(
'model.weights.bin',
new Blob(
[modelArtifacts.weightData], {type: 'application/octet-stream'}),
new Blob([modelArtifacts.weightData], {type: OCTET_STREAM_MIME_TYPE}),
'model.weights.bin');
}

Expand Down Expand Up @@ -150,10 +152,7 @@ export class BrowserHTTPRequest implements IOHandler {
* Loads the model topology file and build the in memory graph of the model.
*/
private async loadBinaryTopology(): Promise<ArrayBuffer> {
const response = await this.getFetchFunc()(
this.path[0], this.addAcceptHeader('application/octet-stream'));
this.verifyContentType(
response, 'model topology', 'application/octet-stream');
const response = await this.getFetchFunc()(this.path[0], this.requestInit);

if (!response.ok) {
throw new Error(`Request to ${this.path[0]} failed with error: ${
Expand All @@ -162,30 +161,10 @@ export class BrowserHTTPRequest implements IOHandler {
return await response.arrayBuffer();
}

private addAcceptHeader(mimeType: string): RequestInit {
const requestOptions = Object.assign({}, this.requestInit || {});
const headers = Object.assign({}, requestOptions.headers || {});
// tslint:disable-next-line:no-any
(headers as any)['Accept'] = mimeType;
requestOptions.headers = headers;
return requestOptions;
}

private verifyContentType(response: Response, target: string, type: string) {
const contentType = response.headers.get('content-type');
if (!contentType || contentType.indexOf(type) === -1) {
throw new Error(`Request to ${response.url} for ${
target} failed. Expected content type ${type} but got ${
contentType}.`);
}
}

protected async loadBinaryModel(): Promise<ModelArtifacts> {
const graphPromise = this.loadBinaryTopology();
const manifestPromise = await this.getFetchFunc()(
this.path[1], this.addAcceptHeader('application/json'));
this.verifyContentType(
manifestPromise, 'weights manifest', 'application/json');
const manifestPromise =
await this.getFetchFunc()(this.path[1], this.requestInit);
if (!manifestPromise.ok) {
throw new Error(`Request to ${this.path[1]} failed with error: ${
manifestPromise.statusText}`);
Expand All @@ -208,10 +187,8 @@ export class BrowserHTTPRequest implements IOHandler {
}

protected async loadJSONModel(): Promise<ModelArtifacts> {
const modelConfigRequest = await this.getFetchFunc()(
this.path as string, this.addAcceptHeader('application/json'));
this.verifyContentType(
modelConfigRequest, 'model topology', 'application/json');
const modelConfigRequest =
await this.getFetchFunc()(this.path as string, this.requestInit);

if (!modelConfigRequest.ok) {
throw new Error(`Request to ${this.path} failed with error: ${
Expand Down
209 changes: 0 additions & 209 deletions src/io/browser_http_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,6 @@ describeWithFlags('browserHTTPRequest-load fetch', NODE_ENVS, () => {
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
expect(requestInits['./model.json'].headers['Accept'])
.toEqual('application/json');
expect(requestInits['./weightfile0'].headers['Accept'])
.toEqual('application/octet-stream');
});

it('throw exception if no fetch polyfill', () => {
Expand Down Expand Up @@ -487,10 +483,6 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
expect(Object.keys(requestInits).length).toEqual(2);
expect(requestInits['./model.json'].headers['Accept'])
.toEqual('application/json');
expect(requestInits['./weightfile0'].headers['Accept'])
.toEqual('application/octet-stream');
// Assert that fetch is invoked with `window` as the context.
expect(windowFetchSpy.calls.mostRecent().object).toEqual(window);
});
Expand Down Expand Up @@ -533,10 +525,6 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
expect(Object.keys(requestInits).length).toEqual(2);
expect(requestInits['./model.json'].headers['Accept'])
.toEqual('application/json');
expect(requestInits['./weightfile0'].headers['Accept'])
.toEqual('application/octet-stream');
expect(Object.keys(requestInits).length).toEqual(2);
expect(requestInits['./model.json'].headers['header_key_1'])
.toEqual('header_value_1');
Expand Down Expand Up @@ -765,29 +753,6 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
});
});

it('with wrong content type leads to error', async (done) => {
setupFakeWeightFiles(
{
'path1/model.json':
{data: JSON.stringify({}), contentType: 'text/html'}
},
requestInits);
const handler = tf.io.browserHTTPRequest('path1/model.json');
handler.load()
.then(modelTopology1 => {
done.fail(
'Loading with wrong content-type succeeded unexpectedly.');
})
.catch(err => {
expect(err.message)
.toEqual(
'Request to path1/model.json for model topology failed. ' +
'Expected content type application/json ' +
'but got text/html.');
done();
});
});

it('with fetch rejection leads to error', async (done) => {
setupFakeWeightFiles(
{
Expand Down Expand Up @@ -856,12 +821,6 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
expect(new Float32Array(modelArtifacts.weightData))
.toEqual(floatData);
expect(Object.keys(requestInits).length).toEqual(3);
expect(requestInits['./model.pb'].headers['Accept'])
.toEqual('application/octet-stream');
expect(requestInits['./weights_manifest.json'].headers['Accept'])
.toEqual('application/json');
expect(requestInits['./weightfile0'].headers['Accept'])
.toEqual('application/octet-stream');
done();
})
.catch(err => done.fail(err.stack));
Expand Down Expand Up @@ -906,13 +865,6 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
expect(Object.keys(requestInits).length).toEqual(3);
expect(requestInits['./model.pb?tfjs-format=file'].headers['Accept'])
.toEqual('application/octet-stream');
expect(requestInits['./weights_manifest.json?tfjs-format=file']
.headers['Accept'])
.toEqual('application/json');
expect(requestInits['./weightfile0?tfjs-format=file'].headers['Accept'])
.toEqual('application/octet-stream');
});

it('1 group, 2 weights, 1 path, with requestInit', async () => {
Expand Down Expand Up @@ -954,16 +906,10 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
expect(Object.keys(requestInits).length).toEqual(3);
expect(requestInits['./model.pb'].headers['Accept'])
.toEqual('application/octet-stream');
expect(requestInits['./model.pb'].headers['header_key_1'])
.toEqual('header_value_1');
expect(requestInits['./weights_manifest.json'].headers['Accept'])
.toEqual('application/json');
expect(requestInits['./weights_manifest.json'].headers['header_key_1'])
.toEqual('header_value_1');
expect(requestInits['./weightfile0'].headers['Accept'])
.toEqual('application/octet-stream');
expect(requestInits['./weightfile0'].headers['header_key_1'])
.toEqual('header_value_1');
});
Expand Down Expand Up @@ -1159,161 +1105,6 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
expect(() => tf.io.browserHTTPRequest(['path1/model.pb'])).toThrow();
});

it('with wrong model content type leads to error', async (done) => {
const weightsManifest: tf.io.WeightsManifestConfig = [
{
paths: ['weightfile0'],
weights: [{
name: 'fooWeight',
shape: [3, 1],
dtype: 'int32',
}]
},
{
paths: ['weightfile1'],
weights: [{
name: 'barWeight',
shape: [2],
dtype: 'bool',
}],
}
];
const floatData1 = new Int32Array([1, 3, 3]);
const floatData2 = new Uint8Array([7, 4]);
setupFakeWeightFiles(
{
'path1/model.pb': {data: modelData, contentType: 'text/html'},
'path2/weights_manifest.json': {
data: JSON.stringify(weightsManifest),
contentType: 'application/json'
},
'path3/weightfile0':
{data: floatData1, contentType: 'application/octet-stream'},
'path3/weightfile1':
{data: floatData2, contentType: 'application/octet-stream'},
},
requestInits);
const handler = tf.io.browserHTTPRequest(
['path1/model.pb', 'path2/weights_manifest.json'], {}, 'path3/');
handler.load()
.then(modelTopology1 => {
done.fail(
'Loading with wrong content-type succeeded unexpectedly.');
})
.catch(err => {
expect(err.message)
.toEqual(
'Request to path1/model.pb for model topology failed. ' +
'Expected content type application/octet-stream ' +
'but got text/html.');
done();
});
});

it('with wrong manifest content type leads to error', async (done) => {
const weightsManifest: tf.io.WeightsManifestConfig = [
{
paths: ['weightfile0'],
weights: [{
name: 'fooWeight',
shape: [3, 1],
dtype: 'int32',
}]
},
{
paths: ['weightfile1'],
weights: [{
name: 'barWeight',
shape: [2],
dtype: 'bool',
}],
}
];
const floatData1 = new Int32Array([1, 3, 3]);
const floatData2 = new Uint8Array([7, 4]);
setupFakeWeightFiles(
{
'path1/model.pb':
{data: modelData, contentType: 'application/octet-stream'},
'path2/weights_manifest.json': {
data: JSON.stringify(weightsManifest),
contentType: 'application/octet-stream'
},
'path3/weightfile0':
{data: floatData1, contentType: 'application/octet-stream'},
'path3/weightfile1':
{data: floatData2, contentType: 'application/octet-stream'},
},
requestInits);
const handler = tf.io.browserHTTPRequest(
['path1/model.pb', 'path2/weights_manifest.json'], {}, 'path3/');
handler.load()
.then(modelTopology1 => {
done.fail(
'Loading with wrong content-type succeeded unexpectedly.');
})
.catch(err => {
expect(err.message)
.toEqual(
'Request to path2/weights_manifest.json for weights ' +
'manifest failed. Expected content type application/json' +
' but got application/octet-stream.');
done();
});
});

it('with wrong weight content type leads to error', async (done) => {
const weightsManifest: tf.io.WeightsManifestConfig = [
{
paths: ['weightfile0'],
weights: [{
name: 'fooWeight',
shape: [3, 1],
dtype: 'int32',
}]
},
{
paths: ['weightfile1'],
weights: [{
name: 'barWeight',
shape: [2],
dtype: 'bool',
}],
}
];
const floatData1 = new Int32Array([1, 3, 3]);
const floatData2 = new Uint8Array([7, 4]);
setupFakeWeightFiles(
{
'path1/model.pb':
{data: modelData, contentType: 'application/octet-stream'},
'path2/weights_manifest.json': {
data: JSON.stringify(weightsManifest),
contentType: 'application/json'
},
'path3/weightfile0':
{data: floatData1, contentType: 'application/json'},
'path3/weightfile1':
{data: floatData2, contentType: 'application/octet-stream'},
},
requestInits);
const handler = tf.io.browserHTTPRequest(
['path1/model.pb', 'path2/weights_manifest.json'], {}, 'path3/');
handler.load()
.then(modelTopology1 => {
done.fail(
'Loading with wrong content-type succeeded unexpectedly.');
})
.catch(err => {
expect(err.message)
.toEqual(
'Request to path3/weightfile0 for weight file failed. ' +
'Expected content type application/octet-stream but ' +
'got application/json.');
done();
});
});

it('with fetch rejection leads to error', async (done) => {
setupFakeWeightFiles(
{
Expand Down
25 changes: 0 additions & 25 deletions src/io/weights_loader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@ import * as util from '../util';
import {decodeWeights} from './io_utils';
import {DTYPE_VALUE_SIZE_MAP, WeightsManifestConfig, WeightsManifestEntry} from './types';

type RequestHeader = {
[key: string]: string
};

const OCTET_STREAM_TYPE = 'application/octet-stream';
const CONTENT_TYPE = 'Content-type';
/**
* Reads binary weights data from a number of URLs.
*
Expand All @@ -45,12 +39,6 @@ export async function loadWeightsAsArrayBuffer(
fetchFunc = fetch;
}

// Add accept header
requestOptions = requestOptions || {};
const headers = (requestOptions.headers || {}) as RequestHeader;
headers['Accept'] = OCTET_STREAM_TYPE;
requestOptions.headers = headers;

// Create the requests for all of the weights in parallel.
const requests =
fetchURLs.map(fetchURL => fetchFunc(fetchURL, requestOptions));
Expand All @@ -64,19 +52,6 @@ export async function loadWeightsAsArrayBuffer(
}

const responses = await Promise.all(requests);
const badContentType = responses.filter(response => {
const contentType = response.headers.get(CONTENT_TYPE);
return !contentType || contentType.indexOf(OCTET_STREAM_TYPE) === -1;
});
if (badContentType.length > 0) {
throw new Error(
badContentType
.map(
resp => `Request to ${resp.url} for weight file failed.` +
` Expected content type ${OCTET_STREAM_TYPE} but got ${
resp.headers.get(CONTENT_TYPE)}.`)
.join('\n'));
}
const bufferPromises = responses.map(response => response.arrayBuffer());

const bufferStartFraction = 0.5;
Expand Down
Loading