Make Keras models pickle-able (refactored) #11030

Merged · 14 commits · Sep 11, 2018

Conversation

@farizrahman4u (Contributor)

No description provided.

@fchollet (Collaborator) left a comment:

I think it looks better with this common logic. The `H5Dict` class seems quite complex, however.



# FLAGS
_is_boxed = '_is_dict_boxed'
Collaborator:

What is this for? Non-obvious from the name...

@@ -0,0 +1,197 @@
"""HDF5 related utilities."""
Collaborator:

Couldn't we just put this in io_utils.py?

            self._is_file = False
            self.data[_is_boxed] = True
        else:
            raise Exception('Required Group, str or dict.'
Collaborator:

Never raise a generic Exception, always use a specific exception (TypeError in this case).
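As a sketch of the point (the `check_path` helper is hypothetical, not the PR's code), a specific exception type lets callers catch exactly this failure without swallowing unrelated errors:

```python
def check_path(path):
    # Raise a specific exception type so callers can catch it precisely.
    if not isinstance(path, (str, dict)):
        raise TypeError('Required Group, str or dict. Got: %r' % type(path))


try:
    check_path(42)
except TypeError as e:
    print('caught:', e)
```

A bare `except Exception` would also catch bugs like `AttributeError`, which is why the generic form is discouraged.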


    def __setitem__(self, attr, val):
        if self.read_only:
            raise Exception('Can not set item in read only mode.')
Collaborator:

Never raise a generic Exception

        if self.read_only:
            raise Exception('Can not set item in read only mode.')
        is_np = type(val).__module__ == np.__name__
        if type(self.data) is dict:
Collaborator:

Do not use `type(...) is`, use `isinstance` (this applies to multiple lines in this PR).
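A quick sketch of why this matters: `type(...) is` fails for subclasses, while `isinstance` respects inheritance (the `Boxed` class is hypothetical, for illustration only):

```python
class Boxed(dict):
    """Hypothetical dict subclass, standing in for wrapped user data."""
    pass


d = Boxed()
print(type(d) is dict)      # exact-type check misses the subclass
print(isinstance(d, dict))  # isinstance respects inheritance
```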

attr = attr.decode('utf-8')
if is_np:
self.data[attr] = pickle.dumps(val)
# We have to remeber to unpickle in __getitem__
Collaborator:

Typo

self.data[attr] = val
return
if attr in self:
raise Exception('Can not set attribute.'
Collaborator:

Never raise a generic Exception

used for model definition or training.
This method is used for both writing to HDF5 file/group,
as well as pickling. This is achieved via a
keras.utils.hdf5_utls.H5Dict object, which can wrap HDF5
Collaborator:

Use backquotes around code keywords



class H5Dict(object):
""" A dict-like wrapper around h5py groups (or dicts).
Collaborator:

Remove leading space

        if isinstance(path, h5py.Group):
            self.data = path
            self._is_file = False
        elif type(path) is str:
Collaborator:

Do not use `type(...) is`, use `isinstance`. To test for string types, use `six`.
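For reference, `six.string_types` boils down to roughly the following; this sketch uses only the stdlib so it stands alone (the else branch is never executed on Python 3):

```python
import sys

# Rough equivalent of six.string_types: on Python 3 only str counts;
# on Python 2 both str and unicode would, via basestring.
if sys.version_info[0] >= 3:
    string_types = (str,)
else:
    string_types = (basestring,)  # noqa: F821 (Python 2 only)

print(isinstance('weights.h5', string_types))
print(isinstance(42, string_types))
```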

@fchollet (Collaborator):

Please also add a quick standalone unit test for the H5Dict class.

def load_model(filepath, custom_objects=None, compile=True):
"""Loads a model saved via `save_model`.
def _deserialize_model(f, custom_objects=None, compile=True):
"""De-serializes a model a serialized via _serialize_model
Contributor Author:

(Reminder to self)
Typo

weight_names.append(name.encode('utf8'))
layer_group['weight_names'] = weight_names
for name, val in zip(weight_names, weight_values):
layer_group[name] = val
Contributor:

The for loop at lines 101-102 can be combined with the for loop at lines 94-99.

name = str(w.name) + '_' + str(i)
else:
name = str(w.name)
name = 'param_' + str(i)
Contributor:

Since we don't put constraints on naming, is it possible to have a duplicate name here? (Not sure if weight names need to be unique.)

Contributor:

I do share the same concern.

raise ValueError('You are trying to load a weight file '
'containing ' + str(len(layer_names)) +
' layers into a model with ' +
str(len(filtered_layers)) + ' layers.')
Contributor:

Suggest using %-style formatting:

        raise ValueError('You are trying to load a weight file '
                         'containing %d layers into a model with '
                         '%d layers.' % (len(layer_names), len(filtered_layers)))

Contributor:

We don't have a unified way, but I think `str.format` is more common?
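Both styles produce the same message for this error; which to use is purely a style call. A standalone sketch with hypothetical layer counts:

```python
n_file, n_model = 3, 5  # hypothetical layer counts

msg_percent = ('You are trying to load a weight file containing '
               '%d layers into a model with %d layers.'
               % (n_file, n_model))
msg_format = ('You are trying to load a weight file containing '
              '{} layers into a model with {} layers.'
              .format(n_file, n_model))

print(msg_percent == msg_format)
```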

if weight_names:
filtered_layer_names.append(name)

layer_names = filtered_layer_names
Contributor:

I think keeping `filtered_layer_names` makes the code more readable.


# Expecting this to never be true.
if len(bad_attributes) > 0:
raise RuntimeError('The following attributes cannot be saved to HDF5 '
Contributor:

Note that some of your lines exceed 85 characters.

raise RuntimeError('The following attributes cannot be saved to HDF5 '
'file because they are larger than %d bytes: %s'
% (HDF5_OBJECT_HEADER_LIMIT,
', '.join([x for x in bad_attributes])))
Contributor:

Redundant list comprehension
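`str.join` accepts any iterable directly, so the comprehension adds nothing. A standalone sketch (the sample list is hypothetical):

```python
bad_attributes = ['layer_names', 'weight_names']  # hypothetical sample

# [x for x in xs] just copies the list; join takes the list as-is.
with_comprehension = ', '.join([x for x in bad_attributes])
without_comprehension = ', '.join(bad_attributes)

print(without_comprehension)
```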

data_npy = np.asarray(val)

num_chunks = 1
chunked_data = np.array_split(data_npy, num_chunks)
Contributor:

Can we have a better initial `num_chunks`?
For example, `math.ceil(data_npy.nbytes / HDF5_OBJECT_HEADER_LIMIT)`.
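A sketch of the suggested starting point, using plain ints so it stands alone (the byte count is hypothetical; 64512 is the header limit used in Keras' saving code at the time):

```python
import math

HDF5_OBJECT_HEADER_LIMIT = 64512  # limit used in keras' saving code
nbytes = 800000                   # hypothetical size of the attribute data

# Start near the number of chunks actually needed instead of 1,
# so the split-and-retry loop converges much faster.
num_chunks = max(1, int(math.ceil(nbytes / float(HDF5_OBJECT_HEADER_LIMIT))))
print(num_chunks)
```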

@farizrahman4u (Contributor Author):

@fchollet Done.

ValueError: In case of an invalid savefile.
"""
if h5py is None:
raise ImportError('`load_model` requires h5py.')
@gabrieldemarmiesse (Contributor) commented Sep 3, 2018:

This can be considered dead code since `h5py` is in the Keras dependencies in `setup.py`.

Collaborator:

It's good to have anyway. Keras could have been installed with `--no-deps`, etc.

@fchollet (Collaborator) left a comment:

Looks good to me, thank you. If someone else can take a quick look too, that would be great, as this is a large PR.

@farizrahman4u (Contributor Author):

@Dref360 Quick review please?

@Dref360 (Contributor) left a comment:

Minor comments concerning documentation.
A minor concern that may bite us later on: now that a `Model` is picklable, users will try to share it between processes, which is not handled by TF.


elif isinstance(path, dict):
self.data = path
self._is_file = False
self.data['_is_group'] = True
Contributor:

Not sure I get what this flag is; maybe add some quick doc?

Contributor Author:

Whether a dict is user data (`f[attr] = {'x': 'y'}`) or a group (`new_group = f['new_group']`).
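A minimal sketch of why the flag is needed (this is a toy, not the actual `H5Dict` implementation): when the backing store is a plain dict, a dict stored as user data would otherwise be indistinguishable from a dict used as a sub-group.

```python
_is_boxed = '_is_dict_boxed'


class TinyH5Dict(object):
    """Toy stand-in for H5Dict backed by a plain dict."""

    def __init__(self, data=None):
        self.data = {} if data is None else data

    def __setitem__(self, attr, val):
        if isinstance(val, dict):
            val = dict(val)
            val[_is_boxed] = True  # mark as user data, not a group
        self.data[attr] = val

    def __getitem__(self, attr):
        if attr not in self.data:
            self.data[attr] = {}   # create a sub-group on first access
        val = self.data[attr]
        if isinstance(val, dict) and not val.get(_is_boxed):
            return TinyH5Dict(val)  # unboxed dict -> treat as a group
        return val


f = TinyH5Dict()
f['config'] = {'x': 'y'}   # boxed: comes back as plain user data
group = f['weights']       # unboxed: comes back as a sub-group wrapper
print(f['config']['x'], type(group).__name__)
```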

This allows us to have a single serialization logic
for both pickling and saving to disk.

Note: This is not intended to be a generic wrapper.
Contributor:

Having those documented would be nice, maybe a Wiki page?

        return iter(list(self.data))

    def iter(self):
        return list(self.data)
Contributor:

`self.data` could get quite big; shouldn't we just do `self.data.items()` to be more efficient?


@farizrahman4u (Contributor Author):

A minor concern that may bite us later on, is that now that a Model is pickable, users will try to share it between processes which is not handled by TF.

Can you expand on this?

@Dref360 (Contributor) commented Sep 4, 2018:

If an object is picklable, it's possible to send it to another process.
The issue here is that a TF Session cannot be shared between processes. This is not handled currently, and it shouldn't be. It's just something we should be aware of.

But, the user can try doing something like:

  • Create a Model in Process 1
  • Send the model to Process 2
  • model.predict won't work in Process 2 because the tf.Session is bound to Process 1.

@fchollet (Collaborator) commented Sep 5, 2018:

@Dref360 I believe such a use case would result in an exception, right? As long as there is a clear exception (and not undetermined behavior), we are fine from a UX standpoint.

@Dref360 (Contributor) commented Sep 6, 2018:

It seems to work nowadays? Maybe it's new in TensorFlow 1.10.

from keras import Model
from keras.layers import Input, Dense
import numpy as np
import multiprocessing as mp

inp = Input([10])
d = Dense(1)(inp)
mod = Model(inp, d)
mod.compile('sgd', 'mse')


def send():
    return mod.predict(np.zeros([10, 10]))


pool = mp.Pool(5)
promises = [pool.apply_async(send) for _ in range(5)]
print([k.get().shape for k in promises])

@farizrahman4u (Contributor Author):

@fchollet @Dref360

@fchollet (Collaborator) left a comment:

LGTM, thanks for the changes!

@fchollet fchollet merged commit 9a4c5d8 into keras-team:master Sep 11, 2018
@farizrahman4u farizrahman4u deleted the pickle2 branch September 12, 2018 01:37
@datumbox (Contributor):

I believe this commit introduces a regression on model.save().

I have a small-sized CNN with the following stats:

Float OPS: 15,635,220
Parameters: 3,131,224
Layers: 76
Conv Depth: 21

If I install Keras from the commit before this PR (5a6af4b), the save method works fine. Nevertheless, after installing this PR I get the following error:

>>> model.save('model.h5')
Traceback (most recent call last):
  File "test.py", line 149, in <module>
    model.save('deleteme.h5')
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/network.py", line 1087, in save
    save_model(self, filepath, overwrite, include_optimizer)
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/saving.py", line 381, in save_model
    _serialize_model(model, f, include_optimizer)
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/saving.py", line 78, in _serialize_model
    f['keras_version'] = str(keras_version).encode('utf8')
  File "/usr/local/lib/python2.7/dist-packages/keras/utils/io_utils.py", line 214, in __setitem__
    ' Group with name {} exists.'.format(attr))
KeyError: 'Can not set attribute. Group with name keras_version exists.'

Same problem appears on latest master. It would be good to fix it before we release a new version.

@farizrahman4u let me know if you want me to investigate further.

cc @fchollet @Dref360

@farizrahman4u (Contributor Author):

@datumbox Thanks for the report. Fixed in #11289.

@datumbox (Contributor) commented Oct 3, 2018:

@farizrahman4u awesome, thanks for the quick fix. I'll try it out on Friday.

@marawanokasha:

Hey, quick question: if I am using a custom loss function (or custom layer), is there a way to pass those custom objects when unpickling the model, or do I have to fall back to `keras.models.load_model`?

@farizrahman4u (Contributor Author):

As long as the custom classes are available in the scope in which you are unpickling, you are good.
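This follows from how `pickle` works in general: it stores a reference to the class (its module and name), not the class body, so the class must be defined or importable wherever you unpickle. A stdlib-only sketch (the `CustomLoss` class is hypothetical, standing in for a custom Keras object):

```python
import pickle


class CustomLoss(object):
    """Hypothetical stand-in for a custom loss/layer class."""

    def __init__(self, factor):
        self.factor = factor


blob = pickle.dumps(CustomLoss(2.0))
# Works here because CustomLoss is defined in the current scope;
# in a fresh interpreter without this class, loads() would fail.
restored = pickle.loads(blob)
print(restored.factor)
```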

@hartikainen:

Are there plans to have this merged into tf.keras at some point? If so, when would it happen?

@ryanjulian:

I'm also curious if this will make it into the `tf.keras` tree.
