-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make Keras models Pickle-able #10483
Changes from 12 commits
c1b2a6d
b7763ff
3afda31
159d627
d531af1
7ae2d9a
49c45fe
53dbad1
f73ef85
10862ad
e68ecf3
5e0466b
886bc0d
999a898
628ffe9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
|
||
import numpy as np | ||
import os | ||
import sys | ||
import json | ||
import yaml | ||
import warnings | ||
|
@@ -16,13 +17,245 @@ | |
from ..utils.io_utils import ask_to_proceed_with_overwrite | ||
from ..utils import conv_utils | ||
|
||
|
||
if sys.version_info[0] == 3: | ||
import pickle | ||
else: | ||
import cPickle as pickle | ||
|
||
try: | ||
import h5py | ||
HDF5_OBJECT_HEADER_LIMIT = 64512 | ||
except ImportError: | ||
h5py = None | ||
|
||
|
||
def _get_json_type(obj): | ||
"""Serialize any object to a JSON-serializable structure. | ||
|
||
# Arguments | ||
obj: the object to serialize | ||
|
||
# Returns | ||
JSON-serializable structure representing `obj`. | ||
|
||
# Raises | ||
TypeError: if `obj` cannot be serialized. | ||
""" | ||
# if obj is a serializable Keras class instance | ||
# e.g. optimizer, layer | ||
if hasattr(obj, 'get_config'): | ||
return {'class_name': obj.__class__.__name__, | ||
'config': obj.get_config()} | ||
|
||
# if obj is any numpy type | ||
if type(obj).__module__ == np.__name__: | ||
if isinstance(obj, np.ndarray): | ||
return {'type': type(obj), | ||
'value': obj.tolist()} | ||
else: | ||
return obj.item() | ||
|
||
# misc functions (e.g. loss function) | ||
if callable(obj): | ||
return obj.__name__ | ||
|
||
# if obj is a python 'type' | ||
if type(obj).__name__ == type.__name__: | ||
return obj.__name__ | ||
|
||
raise TypeError('Not JSON Serializable:', obj) | ||
|
||
|
||
def get_model_state(model): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function would exist exclusively for pickling, so I think it should have "pickle" or some such in the name (e.g. Additionally, the code seems highly redundant with save_model / load_model. How can we refactor to minimize code duplication? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @fchollet What about - we have a common There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That sounds good! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After further thought, that would make h5py write sub-optimal. We would be creating a fully copy of the model weights in memory and writing it all at once to disk. So probably leave There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @fchollet ping ping |
||
|
||
from .. import __version__ as keras_version | ||
|
||
state = {} | ||
state['keras_version'] = str(keras_version).encode('utf8') | ||
state['backend'] = K.backend().encode('utf8') | ||
|
||
# save model config | ||
state['model_config'] = json.dumps({ | ||
'class_name': model.__class__.__name__, | ||
'config': model.get_config() | ||
}, default=_get_json_type).encode('utf8') | ||
|
||
# save model weights | ||
layers = model.layers | ||
model_weights = {} | ||
state['model_weights'] = model_weights | ||
model_weights['layer_names'] = [layer.name.encode('utf8') for layer in layers] | ||
for layer in layers: | ||
layer_weights = {} | ||
model_weights[layer.name.encode('utf8')] = layer_weights | ||
symbolic_weights = layer.weights | ||
weight_values = K.batch_get_value(symbolic_weights) | ||
weight_names = [] | ||
for i, (w, val) in enumerate(zip(symbolic_weights, weight_values)): | ||
if hasattr(w, 'name') and w.name: | ||
name = str(w.name) | ||
else: | ||
name = 'param_' + str(i) | ||
weight_names.append(name.encode('utf8')) | ||
layer_weights['weight_names'] = weight_names[:] | ||
for name, val in zip(weight_names, weight_values): | ||
layer_weights[name] = pickle.dumps(val) | ||
|
||
if model.optimizer: | ||
if isinstance(model.optimizer, optimizers.TFOptimizer): | ||
warnings.warn( | ||
'TensorFlow optimizers do not ' | ||
'make it possible to access ' | ||
'optimizer attributes or optimizer state ' | ||
'after instantiation. ' | ||
'As a result, we cannot include the optimizer ' | ||
'as part of the model state.' | ||
'You will have to compile your model again ' | ||
'after loading it. ' | ||
'Prefer using a Keras optimizer instead ' | ||
'(see keras.io/optimizers).') | ||
else: | ||
# save optimizer config | ||
state['training_config'] = json.dumps({ | ||
'optimizer_config': { | ||
'class_name': model.optimizer.__class__.__name__, | ||
'config': model.optimizer.get_config() | ||
}, | ||
'loss': model.loss, | ||
'metrics': model.metrics, | ||
'sample_weight_mode': model.sample_weight_mode, | ||
'loss_weights': model.loss_weights, | ||
}, default=_get_json_type).encode('utf-8') | ||
|
||
# save optimizer weights | ||
symbolic_weights = getattr(model.optimizer, 'weights') | ||
if symbolic_weights: | ||
optimizer_weights = {} | ||
state['optimizer_weights'] = optimizer_weights | ||
weight_values = K.batch_get_value(symbolic_weights) | ||
weight_names = [] | ||
for i, (w, val) in enumerate(zip(symbolic_weights, | ||
weight_values)): | ||
# Default values of symbolic_weights is /variable | ||
# for Theano and CNTK | ||
if K.backend() == 'theano' or K.backend() == 'cntk': | ||
if hasattr(w, 'name'): | ||
if w.name.split('/')[-1] == 'variable': | ||
name = str(w.name) + '_' + str(i) | ||
else: | ||
name = str(w.name) | ||
else: | ||
name = 'param_' + str(i) | ||
else: | ||
if hasattr(w, 'name') and w.name: | ||
name = str(w.name) | ||
else: | ||
name = 'param_' + str(i) | ||
weight_names.append(name.encode('utf8')) | ||
optimizer_weights['weight_names'] = weight_names[:] | ||
for name, val in zip(weight_names, weight_values): | ||
optimizer_weights[name] = pickle.dumps(val) | ||
return state | ||
|
||
|
||
def load_model_from_state(state): | ||
model_config = json.loads(state['model_config'].decode('utf-8')) | ||
model = model_from_config(model_config) | ||
|
||
# set weights | ||
if 'keras_version' in state: | ||
original_keras_version = state['keras_version'].decode('utf8') | ||
else: | ||
original_keras_version = '1' | ||
if 'backend' in state: | ||
original_keras_backend = state['backend'] | ||
else: | ||
original_keras_backend = None | ||
|
||
layers = model.layers | ||
filtered_layers = [] | ||
for layer in layers: | ||
weights = layer.weights | ||
if weights: | ||
filtered_layers.append(layer) | ||
|
||
model_weights = state['model_weights'] | ||
layer_names = model_weights['layer_names'] | ||
filtered_layer_names = [] | ||
for name in layer_names: | ||
weight_names = model_weights[name]['weight_names'] | ||
if weight_names: | ||
filtered_layer_names.append(name) | ||
layer_names = filtered_layer_names | ||
|
||
if len(layer_names) != len(filtered_layers): | ||
raise ValueError('You are trying to load a weight file ' | ||
'containing ' + str(len(layer_names)) + | ||
' layers into a model with ' + | ||
str(len(filtered_layers)) + ' layers.') | ||
|
||
weight_value_tuples = [] | ||
for k, name in enumerate(layer_names): | ||
layer_weights = model_weights[name] | ||
weight_names = layer_weights['weight_names'] | ||
weight_values = [pickle.loads(layer_weights[name]) for name in weight_names] | ||
layer = filtered_layers[k] | ||
symbolic_weights = layer.weights | ||
weight_values = preprocess_weights_for_loading(layer, | ||
weight_values, | ||
original_keras_version, | ||
original_keras_backend, | ||
reshape=False) | ||
if len(weight_values) != len(symbolic_weights): | ||
raise ValueError('Layer #' + str(k) + | ||
' (named "' + layer.name + | ||
'" in the current model) was found to ' | ||
'correspond to layer ' + name + | ||
' in the provided state. ' | ||
'However the new layer ' + layer.name + | ||
' expects ' + str(len(symbolic_weights)) + | ||
' weights, but the saved weights have ' + | ||
str(len(weight_values)) + | ||
' elements.') | ||
weight_value_tuples += zip(symbolic_weights, weight_values) | ||
K.batch_set_value(weight_value_tuples) | ||
|
||
# load optimizer | ||
training_config = state.get('training_config', None) | ||
if training_config: | ||
training_config = json.loads(training_config.decode('utf-8')) | ||
optimizer_config = training_config['optimizer_config'] | ||
optimizer = optimizers.deserialize(optimizer_config) | ||
loss = training_config['loss'] | ||
metrics = training_config['metrics'] | ||
sample_weight_mode = training_config['sample_weight_mode'] | ||
loss_weights = training_config['loss_weights'] | ||
# compile model | ||
if loss: | ||
model.compile(optimizer=optimizer, | ||
loss=loss, | ||
metrics=metrics, | ||
loss_weights=loss_weights, | ||
sample_weight_mode=sample_weight_mode) | ||
|
||
# load optimizer weights | ||
optimizer_weights = state.get('optimizer_weights') | ||
if optimizer_weights: | ||
# Build train function (to get weight updates). | ||
model._make_train_function() | ||
optimizer_weight_values = [pickle.loads(optimizer_weights[n]) for n in | ||
optimizer_weights['weight_names']] | ||
try: | ||
model.optimizer.set_weights(optimizer_weight_values) | ||
except ValueError: | ||
warnings.warn('Error in loading the saved optimizer ' | ||
'state. As a result, your model is ' | ||
'starting with a freshly initialized ' | ||
'optimizer.') | ||
return model | ||
|
||
|
||
def save_model(model, filepath, overwrite=True, include_optimizer=True): | ||
"""Save a model to a HDF5 file. | ||
|
||
|
@@ -58,42 +291,6 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True): | |
if h5py is None: | ||
raise ImportError('`save_model` requires h5py.') | ||
|
||
def get_json_type(obj): | ||
"""Serialize any object to a JSON-serializable structure. | ||
|
||
# Arguments | ||
obj: the object to serialize | ||
|
||
# Returns | ||
JSON-serializable structure representing `obj`. | ||
|
||
# Raises | ||
TypeError: if `obj` cannot be serialized. | ||
""" | ||
# if obj is a serializable Keras class instance | ||
# e.g. optimizer, layer | ||
if hasattr(obj, 'get_config'): | ||
return {'class_name': obj.__class__.__name__, | ||
'config': obj.get_config()} | ||
|
||
# if obj is any numpy type | ||
if type(obj).__module__ == np.__name__: | ||
if isinstance(obj, np.ndarray): | ||
return {'type': type(obj), | ||
'value': obj.tolist()} | ||
else: | ||
return obj.item() | ||
|
||
# misc functions (e.g. loss function) | ||
if callable(obj): | ||
return obj.__name__ | ||
|
||
# if obj is a python 'type' | ||
if type(obj).__name__ == type.__name__: | ||
return obj.__name__ | ||
|
||
raise TypeError('Not JSON Serializable:', obj) | ||
|
||
from .. import __version__ as keras_version | ||
|
||
if not isinstance(filepath, h5py.File): | ||
|
@@ -115,7 +312,7 @@ def get_json_type(obj): | |
f.attrs['model_config'] = json.dumps({ | ||
'class_name': model.__class__.__name__, | ||
'config': model.get_config() | ||
}, default=get_json_type).encode('utf8') | ||
}, default=_get_json_type).encode('utf8') | ||
|
||
model_weights_group = f.create_group('model_weights') | ||
model_layers = model.layers | ||
|
@@ -144,7 +341,7 @@ def get_json_type(obj): | |
'metrics': model.metrics, | ||
'sample_weight_mode': model.sample_weight_mode, | ||
'loss_weights': model.loss_weights, | ||
}, default=get_json_type).encode('utf8') | ||
}, default=_get_json_type).encode('utf8') | ||
|
||
# Save optimizer weights. | ||
symbolic_weights = getattr(model.optimizer, 'weights') | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We've changed this behavior recently, this will need to be updated