
Make Keras models Pickle-able #10483

Closed · wants to merge 15 commits

Conversation

farizrahman4u
Contributor

@farizrahman4u commented Jun 20, 2018

@maxpumperla
Collaborator

@farizrahman4u cool, I was going to do the same thing. any chance we can do this without introducing about 95% code duplication for get_json_type etc.?

@farizrahman4u
Contributor Author

@fchollet unrelated test seems to be failing?

@Dref360
Contributor

Dref360 commented Jun 20, 2018

The CNTK test failure is a fluke, but the pep8 one is real:

/home/travis/build/keras-team/keras/keras/engine/saving.py:80:13: E126 continuation line over-indented for hanging indent
            'class_name': model.__class__.__name__,
            ^
/home/travis/build/keras-team/keras/keras/engine/saving.py:82:9: E121 continuation line under-indented for hanging indent
        }, default=_get_json_type).encode('utf8')

raise TypeError('Not JSON Serializable:', obj)


def get_model_state(model):
Collaborator

This function would exist exclusively for pickling, so I think it should have "pickle" or some such in the name (e.g. get_model_state_for_pickling).

Additionally, the code seems highly redundant with save_model / load_model. How can we refactor to minimize code duplication?

Contributor Author

@fchollet What about this: we have a common get_model_state()/load_model_from_state() pair, which would be called by pickle_model()/unpickle_model() and by load_model()/save_model()?
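The proposed split might look roughly like this. The function names come from the comment above; the bodies are illustrative assumptions, not actual Keras code:

```python
import pickle

# Sketch only: the bodies of get_model_state/pickle_model are
# assumptions about the proposal, not the real implementation.
def get_model_state(model):
    # Collect everything needed to rebuild the model into plain
    # Python objects (class name, config, weight arrays).
    return {
        'class_name': model.__class__.__name__,
        'config': model.get_config(),
        'weights': model.get_weights(),
    }

def pickle_model(model):
    # __getstate__/__reduce__ would delegate here; save_model could
    # reuse get_model_state too, at the cost of an in-memory copy.
    return pickle.dumps(get_model_state(model))
```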

Collaborator

That sounds good!

Contributor Author

After further thought, that would make the h5py write sub-optimal: we would be creating a full copy of the model weights in memory and writing it all at once to disk. So probably leave save_model as is?

Contributor Author

@fchollet ping ping

@csbrown

csbrown commented Jun 26, 2018

I think it may be possible to use the HDF5 serialization and pickle at the same time. The trick is to serialize the model to HDF5 in memory (which can be done; see the h5py.File options driver and backing_store). Note that the save_model method can accept either a string filename or an h5py.File object. Having initialized the appropriate h5py.File object, pickle is able to serialize this file object. This can be combined with __getstate__ and __setstate__ methods on a base class to perform the appropriate conversions in preparation for pickling.

As a side-benefit of this process, there need not be another materially different serialization method.

Anybody see any pitfalls to this design?

@Dref360 mentioned this pull request Jul 3, 2018
@fchollet
Collaborator

Hi, I think this is a useful feature and we should merge it. What's the status on this PR? Thanks.

@farizrahman4u
Contributor Author

farizrahman4u commented Aug 28, 2018

What's the status on this PR?

It is functional.

You had asked me to refactor it such that we have a single serialization method, which would be called by both model.save() and model.__getstate__(). But the h5py file is written on the fly, and having a common serialization method would make it sub-optimal (a copy of the model would first be created in memory and then written to disk). So I have left it as it is.


@fchollet
Collaborator

# if obj is any numpy type
if type(obj).__module__ == np.__name__:
    if isinstance(obj, np.ndarray):
        return {'type': type(obj),
Collaborator

We've changed this behavior recently; this will need to be updated.

@fchollet
Collaborator

any chance we can do this without introducing about 95% code duplication for get_json_type etc.?

Is there any possibility that using a shared function with a callback structure would work? In one case the callback would write to an HDF5 file; in the other case it would just build a dict.
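One way the callback structure could look. This is a sketch under assumed attribute and method names (weight_names, get_weights, get_config), not actual Keras internals:

```python
# Sketch of the callback idea: a single traversal of model state, with
# the destination abstracted behind a write(key, value) callback.
def serialize_model(model, write):
    write('model_config', model.get_config())
    for layer in model.layers:
        for name, value in zip(layer.weight_names, layer.get_weights()):
            write('weights/' + name, value)

def model_to_dict(model):
    # Pickling path: collect state into a plain dict.
    state = {}
    serialize_model(model, state.__setitem__)
    return state

def model_to_hdf5(model, f):
    # Saving path: stream each array straight into an open h5py.File,
    # avoiding a full in-memory copy of the weights.
    serialize_model(model, lambda key, value: f.create_dataset(key, data=value))
```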

@fchollet
Collaborator

Or we create a dict-like HDF5 file class.
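The dict-like wrapper might be sketched as follows; the class and method shapes here are illustrative assumptions, not the eventual implementation in #11030:

```python
class DictLikeH5:
    """Sketch: one dict-style interface over two backends, so the same
    save/load code can target either a plain dict (for pickling) or an
    h5py File/Group (for writing to disk)."""

    def __init__(self, backend):
        # backend is either a plain dict or an h5py File/Group
        self._backend = backend
        self._is_dict = isinstance(backend, dict)

    def __setitem__(self, key, value):
        if self._is_dict:
            self._backend[key] = value
        else:
            self._backend.create_dataset(key, data=value)

    def __getitem__(self, key):
        value = self._backend[key]
        # h5py datasets need [...] to materialize as in-memory arrays
        return value if self._is_dict else value[...]
```

With this in place, a single save routine can be written against the dict interface and reused by both model.save() and model.__getstate__().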

@farizrahman4u
Contributor Author

@fchollet See #11030 with the dict-like HDF5 file class.

@fchollet
Collaborator

Let's close this PR and move this to #11030.

@namish800

namish800 commented Oct 9, 2018

I am having the same problem. When I try to run model.to_json() (my model contains custom layers), I get this error:

TypeError: can't pickle _thread.RLock objects

Is there any workaround for this problem?

6 participants