Failed to load GRU layer weights into a GRU cell: Layer 'gru_cell' expected 3 variables, but received 0 variables during loading #20407

Closed
victorVoice opened this issue Oct 25, 2024 · 7 comments

@victorVoice
Using TensorFlow 2.16.1 with Keras 3.5.0, I failed to load pretrained GRU layer weights into a GRU cell.
The two layers are defined as below.

For gru layers:
t_rnn_1 = keras.layers.GRU(units=64, return_sequences=True)(t_in_1)
t_rnn_2 = keras.layers.GRU(units=64, return_sequences=True)(t_rnn_1)
t_dense_c = keras.layers.Dense(80)(t_rnn_2)
t_dense_c = tf.keras.layers.ReLU(max_value=6.)(t_dense_c)

For gru cells:
t_rnn_1, cell_out1 = keras.layers.GRUCell(units=64)(t_in_1, states=cell_in1)
t_rnn_2, cell_out2 = keras.layers.GRUCell(units=64)(t_rnn_1, states=cell_in2)
t_dense_2 = keras.layers.Dense(80)(t_rnn_2)
t_dense_2 = tf.keras.layers.ReLU(max_value=6.)(t_dense_2)

When loading the weights, I got the following error message:

Traceback (most recent call last):
File "/home/victoryu/project/se_tf/subband_model_streaming.py", line 319, in
tf_model.create_tf_lite_model(weights_file=args.ckpt, target_name='./crn_cplx')
File "/home/victoryu/project/se_tf/subband_model_streaming.py", line 128, in create_tf_lite_model
self.model.load_weights(weights_file)
File "/home/victoryu/miniconda3/envs/tf2.16/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/home/victoryu/miniconda3/envs/tf2.16/lib/python3.11/site-packages/keras/src/saving/saving_lib.py", line 593, in _raise_loading_failure
raise ValueError(msg)
ValueError: A total of 2 objects could not be loaded. Example error message for object :

Layer 'gru_cell' expected 3 variables, but received 0 variables during loading. Expected: ['kernel', 'recurrent_kernel', 'bias']

It works fine when using tensorflow 2.13.0 + keras 2.13.1.

When I visualize the weight.h5 file, the difference between the two layers is as shown below:

[image]

[image]

I wonder whether this is the cause of the problem, and how to fix it.
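
One possible workaround is to build both models in the same process and copy the weights layer by layer instead of calling `load_weights` on the cell model. This rests on an assumption drawn from the error message, not on anything confirmed in the thread: that the Keras 3 weights file is matched by each layer's location in the file, so variables saved by a `GRU` layer are not found where a standalone `GRUCell` looks for them. The helper name `copy_weights_pairwise` is hypothetical, and it relies only on the standard `get_weights()` / `set_weights()` layer methods. A minimal sketch:

```python
def copy_weights_pairwise(src_layers, dst_layers):
    """Copy weights from each source layer to the destination layer at the same position.

    Hypothetical helper: works on any objects exposing get_weights() and
    set_weights(), which Keras layers do. Both layers must already be built.
    """
    src_layers = list(src_layers)
    dst_layers = list(dst_layers)
    if len(src_layers) != len(dst_layers):
        raise ValueError(
            f"Got {len(src_layers)} source layers but {len(dst_layers)} destination layers"
        )
    for index, (src, dst) in enumerate(zip(src_layers, dst_layers)):
        src_weights = src.get_weights()
        dst_weights = dst.get_weights()
        src_shapes = [tuple(w.shape) for w in src_weights]
        dst_shapes = [tuple(w.shape) for w in dst_weights]
        # Refuse to copy when the arrays do not line up one-to-one.
        if src_shapes != dst_shapes:
            raise ValueError(
                f"Layer pair {index}: source shapes {src_shapes} "
                f"do not match destination shapes {dst_shapes}"
            )
        dst.set_weights(src_weights)
```

With Keras, the source list would be the `GRU` layers of the trained model (after loading its weights normally) and the destination list the `GRUCell` layers of the streaming model, in the same order; the `Dense` layers can be paired the same way. The shape check is there so that a mismatch, such as an unbuilt cell with no variables, fails loudly instead of silently copying nothing.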

@mehtamansi29
Collaborator

Hi @victorVoice -

Thanks for reporting the issue. Can you tell me how t_in_1 is defined here, or share full sample code for both the GRU and GRU cell layers?

@victorVoice
Author

victorVoice commented Oct 25, 2024

@mehtamansi29 Sure.
For the GRU layer, t_in_1 is a tensor with shape [batch_size, time_steps, feature_dims]; in the actual model it is a tensor with shape [32, 63, 80].
For the GRU cell layer, since it processes one time step at a time, t_in_1 is [batch_size, feature_dims], e.g. [32, 80].

Here is some sample code for the GRU cell layers:

import keras
import tensorflow as tf

inp = keras.Input(batch_shape=(1, 5, 16))
cell_in1 = keras.Input(batch_shape=(1, 64))
cell_in2 = keras.Input(batch_shape=(1, 64))

t_in_1 = keras.layers.Reshape([5 * 16])(inp)
t_rnn_1, cell_out1 = keras.layers.GRUCell(units=64)(t_in_1, states=cell_in1)
t_rnn_2, cell_out2 = keras.layers.GRUCell(units=64)(t_rnn_1, states=cell_in2)
t_dense_2 = keras.layers.Dense(80)(t_rnn_2)
t_dense_2 = tf.keras.layers.ReLU(max_value=6.)(t_dense_2)
s2_out = keras.layers.Reshape([1, 5, 16])(t_dense_2)
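
The error message lists the three variables a GRU cell expects: `kernel`, `recurrent_kernel` and `bias`. To check a saved file or a manual weight transfer against them, it helps to know the shapes these variables have. The sketch below computes them from the input feature size and the number of units, following how Keras builds a GRU cell (three gates stacked along the last axis, and a two-row bias when `reset_after=True`, which is the Keras default). The function name `gru_cell_variable_shapes` is hypothetical and not part of Keras.

```python
def gru_cell_variable_shapes(input_dim, units, reset_after=True, use_bias=True):
    """Return the expected shape of each GRU cell variable, keyed by name.

    Hypothetical helper for sanity checks; it does not call Keras.
    """
    shapes = {
        # Input-to-hidden weights for the update, reset and candidate gates.
        "kernel": (input_dim, 3 * units),
        # Hidden-to-hidden weights for the same three gates.
        "recurrent_kernel": (units, 3 * units),
    }
    if use_bias:
        # reset_after=True keeps separate input and recurrent biases (two rows).
        shapes["bias"] = (2, 3 * units) if reset_after else (3 * units,)
    return shapes


# First cell in the snippet above: 80 input features, 64 units.
first_cell = gru_cell_variable_shapes(input_dim=80, units=64)
# Second cell: fed by the first cell's 64-dimensional output.
second_cell = gru_cell_variable_shapes(input_dim=64, units=64)
```

For the first cell this gives a `kernel` of shape (80, 192), a `recurrent_kernel` of shape (64, 192) and a `bias` of shape (2, 192); a `GRU` layer with the same input size and units holds arrays of the same shapes, which is what makes a one-to-one transfer plausible.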

Here is the code for the GRU layers:

import keras
import tensorflow as tf

inp = keras.Input(batch_shape=(32, 63, 5, 16))

# 63 time steps, matching the input shape above
t_in_1 = keras.layers.Reshape([63, 5 * 16])(inp)
t_rnn_1 = keras.layers.GRU(units=64, return_sequences=True)(t_in_1)
t_rnn_2 = keras.layers.GRU(units=64, return_sequences=True)(t_rnn_1)
t_dense_2 = keras.layers.Dense(80)(t_rnn_2)
t_dense_2 = tf.keras.layers.ReLU(max_value=6.)(t_dense_2)
s2_out = keras.layers.Reshape([1, 63, 5, 16])(t_dense_2)
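
Keras `Reshape` takes a target shape that excludes the batch dimension, and the target must hold the same number of elements per sample as the layer's input. A quick standard-library check of the target shapes used in snippets like the one above can catch a mismatch before the model is built. The helper name `reshape_is_valid` is hypothetical, and the sketch does not handle the `-1` wildcard that `Reshape` also accepts.

```python
import math


def reshape_is_valid(input_shape, target_shape):
    """Check that two per-sample shapes (no batch dimension) hold the same element count.

    Hypothetical helper; does not support the -1 wildcard dimension.
    """
    return math.prod(input_shape) == math.prod(target_shape)


# Per-sample input (63, 5, 16) flattened to 63 time steps of 80 features.
ok = reshape_is_valid((63, 5, 16), (63, 5 * 16))
# A 64-step target does not match a 63-step input.
bad = reshape_is_valid((63, 5, 16), (64, 5 * 16))
```

Here `ok` is `True` (5040 elements on both sides) and `bad` is `False` (5040 versus 5120).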

@mehtamansi29
Collaborator

Hi @victorVoice -

Thanks for the sample code. I replicated it with both the GRU layer and the GRU cell in the latest Keras (3.6.0), and it works fine for me.
A gist is attached for reference.

@victorVoice
Author

@mehtamansi29 Thanks, I will try Keras 3.6.0 first. Thanks for the help.

github-actions bot commented Dec 4, 2024

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions github-actions bot added the stale label Dec 4, 2024
This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.
