Skip to content

Commit

Permalink
Fix bug where a model with empty weights fails to load (tensorflow#7868)
Browse files Browse the repository at this point in the history
* Fix bug where a model with empty weights fails to load

* Address linting complaints

---------

Co-authored-by: fengwuyao <131706622+fengwuyao@users.noreply.github.com>
Co-authored-by: Ping Yu <4018+pyu10055@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 21, 2023
1 parent 0cd53ba commit 20ceb6f
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 6 deletions.
20 changes: 14 additions & 6 deletions tfjs-layers/src/engine/container.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ export interface ContainerArgs {
name?: string;
}

// get weights key from tensor map in order to check if it is from keras v3.
// e.g. dense/0
const isKerasSavedModelFormat = (weights: NamedTensorMap): boolean => {
const keys = Object.keys(weights);
if (keys.length === 0) {
return false;
}
const key = keys[0].split('/');
return !isNaN(parseInt(key[key.length - 1], 10));
};

/**
* A Container is a directed acyclic graph of layers.
*
Expand Down Expand Up @@ -594,19 +605,16 @@ export abstract class Container extends Layer {
loadWeights(weights: NamedTensorMap, strict = true) {
const nameToWeight: {[name: string]: LayerVariable} = {};
let totalWeightsCount = 0;
// get weights key from tensor map in order to check if it is from keras v3.
// e.g. dense/0
const key = Object.keys(weights)[0].split('/');
const isKerasSavedModelFormat = !isNaN(parseInt(key[key.length - 1], 10));
if (isKerasSavedModelFormat) {
const modelIsKerasSavedModelFormat = isKerasSavedModelFormat(weights);
if (modelIsKerasSavedModelFormat) {
this.parseWeights(weights);
}
// Check if weights from keras v3.
for (const layer of this.layers) {
for (const [index, weight] of layer.weights.entries()) {
// Parse the name to layerName/index.
// e.g. dense/0, dense/1, dense_1/0, dense_1/1
const parsedName = isKerasSavedModelFormat ?
const parsedName = modelIsKerasSavedModelFormat ?
`${weight.name.split('/').slice(0, -1).join('/') + '/'}${index}` :
weight.originalName;
if (nameToWeight[parsedName] != null) {
Expand Down
32 changes: 32 additions & 0 deletions tfjs-layers/src/model_save_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,38 @@ describeMathCPUAndWebGL2('Save-load round trips', () => {
}
});

it('loadLayersModel: save and load a model with empty weights', async () => {
// https://github.com/tensorflow/tfjs/issues/7865
// Models without weights should still be valid models
const model1 = tfl.sequential();
model1.add(
tfl.layers.upSampling2d({
size: [2, 2],
dataFormat: 'channelsLast',
inputShape: [null, null, 3],
})
);

// Use a randomly generated model path to prevent collision.
const path = `testModel${new Date().getTime()}_${Math.random()}`;

// First save the model to local storage.
const modelURL = `localstorage://${path}`;
await model1.save(modelURL);
// Once the saving succeeds, load the model back.
const model2 = await tfl.loadLayersModel(modelURL);
// Verify that the topology of the model is correct.
expect(model2.toJSON(null, false)).toEqual(model1.toJSON(null, false));

// Check the equality of the two models' weights.
const weights1 = model1.getWeights();
const weights2 = model2.getWeights();
expect(weights2.length).toEqual(weights1.length);
for (let i = 0; i < weights1.length; ++i) {
expectTensorsClose(weights1[i], weights2[i]);
}
});

it('Functional model, IndexedDB', async () => {
const input = tfl.input({shape: [2, 2]});
const layer1 = tfl.layers.flatten().apply(input);
Expand Down

0 comments on commit 20ceb6f

Please sign in to comment.