From 20ceb6f1f710271f428078b56ebddfcdfc1e1428 Mon Sep 17 00:00:00 2001 From: Kevin Scott <151596+thekevinscott@users.noreply.github.com> Date: Mon, 21 Aug 2023 18:36:27 -0400 Subject: [PATCH] Fix bug where a model with empty weights fails to load (#7868) * Fix bug where a model with empty weights fails to load * Address linting complaints --------- Co-authored-by: fengwuyao <131706622+fengwuyao@users.noreply.github.com> Co-authored-by: Ping Yu <4018+pyu10055@users.noreply.github.com> --- tfjs-layers/src/engine/container.ts | 20 ++++++++++++------ tfjs-layers/src/model_save_test.ts | 32 +++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/tfjs-layers/src/engine/container.ts b/tfjs-layers/src/engine/container.ts index c78be732e4b..a910258349f 100644 --- a/tfjs-layers/src/engine/container.ts +++ b/tfjs-layers/src/engine/container.ts @@ -36,6 +36,17 @@ export interface ContainerArgs { name?: string; } +// get weights key from tensor map in order to check if it is from keras v3. +// e.g. dense/0 +const isKerasSavedModelFormat = (weights: NamedTensorMap): boolean => { + const keys = Object.keys(weights); + if (keys.length === 0) { + return false; + } + const key = keys[0].split('/'); + return !isNaN(parseInt(key[key.length - 1], 10)); +}; + /** * A Container is a directed acyclic graph of layers. * @@ -594,11 +605,8 @@ export abstract class Container extends Layer { loadWeights(weights: NamedTensorMap, strict = true) { const nameToWeight: {[name: string]: LayerVariable} = {}; let totalWeightsCount = 0; - // get weights key from tensor map in order to check if it is from keras v3. - // e.g. dense/0 - const key = Object.keys(weights)[0].split('/'); - const isKerasSavedModelFormat = !isNaN(parseInt(key[key.length - 1], 10)); - if (isKerasSavedModelFormat) { + const modelIsKerasSavedModelFormat = isKerasSavedModelFormat(weights); + if (modelIsKerasSavedModelFormat) { this.parseWeights(weights); } // Check if weights from keras v3. @@ -606,7 +614,7 @@ export abstract class Container extends Layer { for (const [index, weight] of layer.weights.entries()) { // Parse the name to layerName/index. // e.g. dense/0, dense/1, dense_1/0, dense_1/1 - const parsedName = isKerasSavedModelFormat ? + const parsedName = modelIsKerasSavedModelFormat ? `${weight.name.split('/').slice(0, -1).join('/') + '/'}${index}` : weight.originalName; if (nameToWeight[parsedName] != null) { diff --git a/tfjs-layers/src/model_save_test.ts b/tfjs-layers/src/model_save_test.ts index dba6af3fa6e..026c0e88cc1 100644 --- a/tfjs-layers/src/model_save_test.ts +++ b/tfjs-layers/src/model_save_test.ts @@ -140,6 +140,38 @@ describeMathCPUAndWebGL2('Save-load round trips', () => { } }); + it('loadLayersModel: save and load a model with empty weights', async () => { + // https://github.com/tensorflow/tfjs/issues/7865 + // Models without weights should still be valid models + const model1 = tfl.sequential(); + model1.add( + tfl.layers.upSampling2d({ + size: [2, 2], + dataFormat: 'channelsLast', + inputShape: [null, null, 3], + }) + ); + + // Use a randomly generated model path to prevent collision. + const path = `testModel${new Date().getTime()}_${Math.random()}`; + + // First save the model to local storage. + const modelURL = `localstorage://${path}`; + await model1.save(modelURL); + // Once the saving succeeds, load the model back. + const model2 = await tfl.loadLayersModel(modelURL); + // Verify that the topology of the model is correct. + expect(model2.toJSON(null, false)).toEqual(model1.toJSON(null, false)); + + // Check the equality of the two models' weights. + const weights1 = model1.getWeights(); + const weights2 = model2.getWeights(); + expect(weights2.length).toEqual(weights1.length); + for (let i = 0; i < weights1.length; ++i) { + expectTensorsClose(weights1[i], weights2[i]); + } + }); + it('Functional model, IndexedDB', async () => { const input = tfl.input({shape: [2, 2]}); const layer1 = tfl.layers.flatten().apply(input);