-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make Keras models pickle-able (refactored) #11030
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it looks better with this common logic. The H5Dict class seems quite complex, however.
keras/utils/hdf5_utils.py
Outdated
|
||
|
||
# FLAGS | ||
_is_boxed = '_is_dict_boxed' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is this for? Non-obvious from the name...
keras/utils/hdf5_utils.py
Outdated
@@ -0,0 +1,197 @@ | |||
"""HDF5 related utilities.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Couldn't we just put this in io_utils.py
?
keras/utils/hdf5_utils.py
Outdated
self._is_file = False | ||
self.data[_is_boxed] = True | ||
else: | ||
raise Exception('Required Group, str or dict.' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Never raise a generic Exception, always use a specific exception (TypeError in this case).
keras/utils/hdf5_utils.py
Outdated
|
||
def __setitem__(self, attr, val): | ||
if self.read_only: | ||
raise Exception('Can not set item in read only mode.') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Never raise a generic Exception
keras/utils/hdf5_utils.py
Outdated
if self.read_only: | ||
raise Exception('Can not set item in read only mode.') | ||
is_np = type(val).__module__ == np.__name__ | ||
if type(self.data) is dict: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do not use type(...) is
, use isinstance
(this applies to multiple lines in this PR)
keras/utils/hdf5_utils.py
Outdated
attr = attr.decode('utf-8') | ||
if is_np: | ||
self.data[attr] = pickle.dumps(val) | ||
# We have to remeber to unpickle in __getitem__ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo
keras/utils/hdf5_utils.py
Outdated
self.data[attr] = val | ||
return | ||
if attr in self: | ||
raise Exception('Can not set attribute.' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Never raise a generic Exception
keras/engine/saving.py
Outdated
used for model definition or training. | ||
This method is used for both writing to HDF5 file/group, | ||
as well as pickling. This is achieved via a | ||
keras.utils.hdf5_utls.H5Dict object, which can wrap HDF5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use backquotes around code keywords
keras/utils/hdf5_utils.py
Outdated
|
||
|
||
class H5Dict(object): | ||
""" A dict-like wrapper around h5py groups (or dicts). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove leading space
keras/utils/hdf5_utils.py
Outdated
if isinstance(path, h5py.Group): | ||
self.data = path | ||
self._is_file = False | ||
elif type(path) is str: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do not use type(...) is
, use isinstance
. To test for string types, use six
Please also add a quick standalone unit test for the H5Dict class. |
keras/engine/saving.py
Outdated
def load_model(filepath, custom_objects=None, compile=True): | ||
"""Loads a model saved via `save_model`. | ||
def _deserialize_model(f, custom_objects=None, compile=True): | ||
"""De-serializes a model a serialized via _serialize_model |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Reminder to self)
Typo
keras/engine/saving.py
Outdated
weight_names.append(name.encode('utf8')) | ||
layer_group['weight_names'] = weight_names | ||
for name, val in zip(weight_names, weight_values): | ||
layer_group[name] = val |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
forloop 101-102 can combine with forloop 94-99
name = str(w.name) + '_' + str(i) | ||
else: | ||
name = str(w.name) | ||
name = 'param_' + str(i) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we don't make constraint on naming, is it possible to have a duplicate name here? (Not sure if weight name need to be unique)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do share the same concern.
keras/engine/saving.py
Outdated
raise ValueError('You are trying to load a weight file ' | ||
'containing ' + str(len(layer_names)) + | ||
' layers into a model with ' + | ||
str(len(filtered_layers)) + ' layers.') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggest to use %
style
raise ValueError('You are trying to load a weight file '
'containing %d layers into a model with '
'%d layers.' % (len(layer_names), len(filtered_layers)))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't have a unified way, but I think str.format is more common?
if weight_names: | ||
filtered_layer_names.append(name) | ||
|
||
layer_names = filtered_layer_names |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think keep using filtered_layer_names
makes code more readable.
keras/utils/hdf5_utils.py
Outdated
|
||
# Expecting this to never be true. | ||
if len(bad_attributes) > 0: | ||
raise RuntimeError('The following attributes cannot be saved to HDF5 ' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that some of your codes exceed 85 characters
keras/utils/hdf5_utils.py
Outdated
raise RuntimeError('The following attributes cannot be saved to HDF5 ' | ||
'file because they are larger than %d bytes: %s' | ||
% (HDF5_OBJECT_HEADER_LIMIT, | ||
', '.join([x for x in bad_attributes]))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Redundant list comprehension
keras/utils/hdf5_utils.py
Outdated
data_npy = np.asarray(val) | ||
|
||
num_chunks = 1 | ||
chunked_data = np.array_split(data_npy, num_chunks) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we have a better initial num_chunks
?
For example, math.ceil(data_npy.nbytes / HDF5_OBJECT_HEADER_LIMIT)
@fchollet Done. |
ValueError: In case of an invalid savefile. | ||
""" | ||
if h5py is None: | ||
raise ImportError('`load_model` requires h5py.') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be considered dead code since h5py
is in the keras dependencies in the setup.py
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's good to have anyway. Keras could have been installed with--no-deps
, etc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, thank you. If someone else can take a quick look too, that would be great, as this is a large PR.
@Dref360 Quick review please? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minors comments concerning documentation.
A minor concern that may bite us later on, is that now that a Model is pickable, users will try to share it between processes which is not handled by TF.
keras/engine/saving.py
Outdated
raise ValueError('You are trying to load a weight file ' | ||
'containing ' + str(len(layer_names)) + | ||
' layers into a model with ' + | ||
str(len(filtered_layers)) + ' layers.') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't have a unified way, but I think str.format is more common?
keras/utils/io_utils.py
Outdated
elif isinstance(path, dict): | ||
self.data = path | ||
self._is_file = False | ||
self.data['_is_group'] = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure that I get what this flag is, maybe add some quick doc?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Whether a dict is user data (f[attr] = {'x' : 'y'}
) or a group (new_group = f['new_group']
).
This allows us to have a single serialization logic | ||
for both pickling and saving to disk. | ||
|
||
Note: This is not intended to be a generic wrapper. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Having those documented would be nice, maybe a Wiki page?
keras/utils/io_utils.py
Outdated
return iter(list(self.data)) | ||
|
||
def iter(self): | ||
return list(self.data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.data could get quite big, shouldn't we just do 'self.data.items()' to be more efficient?
name = str(w.name) + '_' + str(i) | ||
else: | ||
name = str(w.name) | ||
name = 'param_' + str(i) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do share the same concern.
Can you expand on this? |
If an object is pickable, it's possible to send it to another Process. But, the user can try doing something like:
|
@Dref360 I believe such a use case would result in an exception, right? As long as there is a clear exception (and not undetermined behavior), we are fine from a UX standpoint. |
It seems to work nowadays? Maybe it's new from tensorflow 1.10. from keras import Model
from keras.layers import Input, Dense
import numpy as np
import multiprocessing as mp
inp = Input([10])
d = Dense(1)(inp)
mod = Model(inp, d)
mod.compile('sgd', 'mse')
def send():
return mod.predict(np.zeros([10,10]))
pool = mp.Pool(5)
promises = [pool.apply_async(send, ) for _ in range(5)]
print([k.get().shape for k in promises]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for the changes!
I believe this commit introduces a regression on I have a small-sized CNN with the following stats:
If I install keras from the commit before this PR (5a6af4b), the save method works fine. Nevertheless by installing this PR I get the following error:
Same problem appears on latest master. It would be good to fix it before we release a new version. @farizrahman4u let me know if you want me to investigate further. |
@farizrahman4u awesome, thanks for the quick fix. I'll try it out on Friday. |
Hey, quick question, If I am using a custom loss function (or custom layer), is there a way to pass those custom objects when unpickling the model, or do I have to fallback to |
As long as the custom classes are available in the scope in which you are unpicking, you are good. |
Are there plans to have this merged into tf.keras at some point? If so, when would it happen? |
I'm also curious if this will make it into the tf.keras tree |
No description provided.