Make Keras models Pickle-able #10483

Closed · wants to merge 15 commits
keras/engine/network.py: 7 changes (7 additions, 0 deletions)

@@ -1180,6 +1180,13 @@ def load_weights(self, filepath, by_name=False,
            saving.load_weights_from_hdf5_group(
                f, self.layers, reshape=reshape)

    def __getstate__(self):
        return saving.get_model_state(self)

    def __setstate__(self, state):
        model = saving.load_model_from_state(state)
        self.__dict__ = model.__dict__

    def _updated_config(self):
        """Util shared between different serialization methods.

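With these two methods in place, a Keras model goes through the standard pickle protocol. For context, a minimal sketch of the round trip this hunk enables (the model architecture and training data below are illustrative, not part of the PR):

import pickle

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

# Build and compile a small illustrative model.
model = Sequential([Dense(4, input_shape=(8,), activation='relu'),
                    Dense(1)])
model.compile(optimizer='adam', loss='mse')
model.fit(np.random.rand(16, 8), np.random.rand(16, 1),
          epochs=1, verbose=0)

# __getstate__ captures the state; __setstate__ rebuilds the model.
restored = pickle.loads(pickle.dumps(model))

x = np.ones((1, 8))
assert np.allclose(model.predict(x), restored.predict(x))

Since copy.deepcopy also routes through __getstate__/__setstate__, the same change should make copy.deepcopy(model) work as a side effect.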
keras/engine/saving.py: 273 changes (235 additions, 38 deletions)

@@ -6,6 +6,7 @@

import numpy as np
import os
import sys
import json
import yaml
import warnings
@@ -16,13 +17,245 @@
from ..utils.io_utils import ask_to_proceed_with_overwrite
from ..utils import conv_utils


if sys.version_info[0] == 3:
    import pickle
else:
    import cPickle as pickle

try:
    import h5py
    HDF5_OBJECT_HEADER_LIMIT = 64512
except ImportError:
    h5py = None


def _get_json_type(obj):
    """Serialize any object to a JSON-serializable structure.

    # Arguments
        obj: the object to serialize

    # Returns
        JSON-serializable structure representing `obj`.

    # Raises
        TypeError: if `obj` cannot be serialized.
    """
    # if obj is a serializable Keras class instance
    # e.g. optimizer, layer
    if hasattr(obj, 'get_config'):
        return {'class_name': obj.__class__.__name__,
                'config': obj.get_config()}

    # if obj is any numpy type
    if type(obj).__module__ == np.__name__:
        if isinstance(obj, np.ndarray):
            return {'type': type(obj),
                    'value': obj.tolist()}
        else:
            return obj.item()

    # misc functions (e.g. loss function)
    if callable(obj):
        return obj.__name__

    # if obj is a python 'type'
    if type(obj).__name__ == type.__name__:
        return obj.__name__

    raise TypeError('Not JSON Serializable:', obj)

Collaborator (commenting on the `'type': type(obj)` line): We've changed this behavior recently, this will need to be updated.
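For orientation, a quick illustration of the hook in action (the payload is made up; this assumes the module context above, where json and np are already imported):

example = {'lr': np.float32(0.001),      # numpy scalar -> obj.item() -> 0.001
           'shape': np.array([32, 10])}  # ndarray -> {'type': ..., 'value': [32, 10]}
encoded = json.dumps(example, default=_get_json_type)
# The nested type object is serialized through the callable branch
# (classes are callable), ending up as the string 'ndarray'. As the
# reviewer notes above, this ndarray encoding has since changed upstream.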


Review thread on get_model_state:

Collaborator: This function would exist exclusively for pickling, so I think it should have "pickle" or some such in the name (e.g. get_model_state_for_pickling).

Additionally, the code seems highly redundant with save_model / load_model. How can we refactor to minimize code duplication?

Contributor Author: @fchollet What about - we have a common get_model_state()/load_model_from_state() which will be called by pickle_model()/unpickle_model() and load_model()/save_model()? (A sketch of that split appears after load_model_from_state below.)

Collaborator: That sounds good!

Contributor Author: After further thought, that would make the h5py write sub-optimal. We would be creating a full copy of the model weights in memory and writing it all at once to disk. So probably leave save_model as is?

Contributor Author: @fchollet ping ping


def get_model_state(model):
    from .. import __version__ as keras_version

    state = {}
    state['keras_version'] = str(keras_version).encode('utf8')
    state['backend'] = K.backend().encode('utf8')

    # save model config
    state['model_config'] = json.dumps({
        'class_name': model.__class__.__name__,
        'config': model.get_config()
    }, default=_get_json_type).encode('utf8')

    # save model weights
    layers = model.layers
    model_weights = {}
    state['model_weights'] = model_weights
    model_weights['layer_names'] = [layer.name.encode('utf8')
                                    for layer in layers]
    for layer in layers:
        layer_weights = {}
        model_weights[layer.name.encode('utf8')] = layer_weights
        symbolic_weights = layer.weights
        weight_values = K.batch_get_value(symbolic_weights)
        weight_names = []
        for i, (w, val) in enumerate(zip(symbolic_weights, weight_values)):
            if hasattr(w, 'name') and w.name:
                name = str(w.name)
            else:
                name = 'param_' + str(i)
            weight_names.append(name.encode('utf8'))
        layer_weights['weight_names'] = weight_names[:]
        for name, val in zip(weight_names, weight_values):
            layer_weights[name] = pickle.dumps(val)

    if model.optimizer:
        if isinstance(model.optimizer, optimizers.TFOptimizer):
            warnings.warn(
                'TensorFlow optimizers do not '
                'make it possible to access '
                'optimizer attributes or optimizer state '
                'after instantiation. '
                'As a result, we cannot include the optimizer '
                'as part of the model state. '
                'You will have to compile your model again '
                'after loading it. '
                'Prefer using a Keras optimizer instead '
                '(see keras.io/optimizers).')
        else:
            # save optimizer config
            state['training_config'] = json.dumps({
                'optimizer_config': {
                    'class_name': model.optimizer.__class__.__name__,
                    'config': model.optimizer.get_config()
                },
                'loss': model.loss,
                'metrics': model.metrics,
                'sample_weight_mode': model.sample_weight_mode,
                'loss_weights': model.loss_weights,
            }, default=_get_json_type).encode('utf-8')

            # save optimizer weights
            symbolic_weights = getattr(model.optimizer, 'weights')
            if symbolic_weights:
                optimizer_weights = {}
                state['optimizer_weights'] = optimizer_weights
                weight_values = K.batch_get_value(symbolic_weights)
                weight_names = []
                for i, (w, val) in enumerate(zip(symbolic_weights,
                                                 weight_values)):
                    # The default name of a symbolic weight is '/variable'
                    # for Theano and CNTK.
                    if K.backend() == 'theano' or K.backend() == 'cntk':
                        if hasattr(w, 'name'):
                            if w.name.split('/')[-1] == 'variable':
                                name = str(w.name) + '_' + str(i)
                            else:
                                name = str(w.name)
                        else:
                            name = 'param_' + str(i)
                    else:
                        if hasattr(w, 'name') and w.name:
                            name = str(w.name)
                        else:
                            name = 'param_' + str(i)
                    weight_names.append(name.encode('utf8'))
                optimizer_weights['weight_names'] = weight_names[:]
                for name, val in zip(weight_names, weight_values):
                    optimizer_weights[name] = pickle.dumps(val)
    return state
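To make the returned structure concrete, the state dict has roughly the following shape (the layer names, weight names, and version string below are illustrative, not produced by running this code):

state = {
    'keras_version': b'2.2.0',      # illustrative version string
    'backend': b'tensorflow',
    'model_config': b'{"class_name": "Sequential", "config": [...]}',
    'model_weights': {
        'layer_names': [b'dense_1'],
        b'dense_1': {
            'weight_names': [b'dense_1/kernel:0', b'dense_1/bias:0'],
            b'dense_1/kernel:0': b'...pickled ndarray...',
            b'dense_1/bias:0': b'...pickled ndarray...',
        },
    },
    # Only present when the model was compiled with a Keras optimizer:
    'training_config': b'{"optimizer_config": {...}}',
    'optimizer_weights': {
        'weight_names': [b'Adam/iterations:0'],
        b'Adam/iterations:0': b'...pickled ndarray...',
    },
}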


def load_model_from_state(state):
    model_config = json.loads(state['model_config'].decode('utf-8'))
    model = model_from_config(model_config)

    # set weights
    if 'keras_version' in state:
        original_keras_version = state['keras_version'].decode('utf8')
    else:
        original_keras_version = '1'
    if 'backend' in state:
        original_keras_backend = state['backend'].decode('utf8')
    else:
        original_keras_backend = None

    layers = model.layers
    filtered_layers = []
    for layer in layers:
        weights = layer.weights
        if weights:
            filtered_layers.append(layer)

    model_weights = state['model_weights']
    layer_names = model_weights['layer_names']
    filtered_layer_names = []
    for name in layer_names:
        weight_names = model_weights[name]['weight_names']
        if weight_names:
            filtered_layer_names.append(name)
    layer_names = filtered_layer_names

    if len(layer_names) != len(filtered_layers):
        raise ValueError('You are trying to load a weight file '
                         'containing ' + str(len(layer_names)) +
                         ' layers into a model with ' +
                         str(len(filtered_layers)) + ' layers.')

    weight_value_tuples = []
    for k, name in enumerate(layer_names):
        layer_weights = model_weights[name]
        weight_names = layer_weights['weight_names']
        weight_values = [pickle.loads(layer_weights[weight_name])
                         for weight_name in weight_names]
        layer = filtered_layers[k]
        symbolic_weights = layer.weights
        weight_values = preprocess_weights_for_loading(layer,
                                                       weight_values,
                                                       original_keras_version,
                                                       original_keras_backend,
                                                       reshape=False)
        if len(weight_values) != len(symbolic_weights):
            raise ValueError('Layer #' + str(k) +
                             ' (named "' + layer.name +
                             '" in the current model) was found to '
                             'correspond to layer ' + str(name) +
                             ' in the provided state. '
                             'However the new layer ' + layer.name +
                             ' expects ' + str(len(symbolic_weights)) +
                             ' weights, but the saved weights have ' +
                             str(len(weight_values)) +
                             ' elements.')
        weight_value_tuples += zip(symbolic_weights, weight_values)
    K.batch_set_value(weight_value_tuples)

    # load optimizer
    training_config = state.get('training_config', None)
    if training_config:
        training_config = json.loads(training_config.decode('utf-8'))
        optimizer_config = training_config['optimizer_config']
        optimizer = optimizers.deserialize(optimizer_config)
        loss = training_config['loss']
        metrics = training_config['metrics']
        sample_weight_mode = training_config['sample_weight_mode']
        loss_weights = training_config['loss_weights']
        # compile model
        if loss:
            model.compile(optimizer=optimizer,
                          loss=loss,
                          metrics=metrics,
                          loss_weights=loss_weights,
                          sample_weight_mode=sample_weight_mode)

        # load optimizer weights
        optimizer_weights = state.get('optimizer_weights')
        if optimizer_weights:
            # Build train function (to get weight updates).
            model._make_train_function()
            optimizer_weight_values = [
                pickle.loads(optimizer_weights[n])
                for n in optimizer_weights['weight_names']]
            try:
                model.optimizer.set_weights(optimizer_weight_values)
            except ValueError:
                warnings.warn('Error in loading the saved optimizer '
                              'state. As a result, your model is '
                              'starting with a freshly initialized '
                              'optimizer.')
    return model
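Following the split proposed in the review thread above, pickling support reduces to thin wrappers over these two shared functions. A sketch of what that could look like (pickle_model/unpickle_model are the names floated in the thread, not functions present in this diff):

def pickle_model(model):
    # Serialize via the shared state extraction used by __getstate__.
    return pickle.dumps(get_model_state(model))


def unpickle_model(data):
    # Rebuild via the shared loader used by __setstate__.
    return load_model_from_state(pickle.loads(data))

Under that split, save_model/load_model stay on the streaming h5py path, which is the author's argument above for leaving save_model as is.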


def save_model(model, filepath, overwrite=True, include_optimizer=True):
    """Save a model to an HDF5 file.

@@ -58,42 +291,6 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True):
    if h5py is None:
        raise ImportError('`save_model` requires h5py.')

    def get_json_type(obj):
        """Serialize any object to a JSON-serializable structure.

        # Arguments
            obj: the object to serialize

        # Returns
            JSON-serializable structure representing `obj`.

        # Raises
            TypeError: if `obj` cannot be serialized.
        """
        # if obj is a serializable Keras class instance
        # e.g. optimizer, layer
        if hasattr(obj, 'get_config'):
            return {'class_name': obj.__class__.__name__,
                    'config': obj.get_config()}

        # if obj is any numpy type
        if type(obj).__module__ == np.__name__:
            if isinstance(obj, np.ndarray):
                return {'type': type(obj),
                        'value': obj.tolist()}
            else:
                return obj.item()

        # misc functions (e.g. loss function)
        if callable(obj):
            return obj.__name__

        # if obj is a python 'type'
        if type(obj).__name__ == type.__name__:
            return obj.__name__

        raise TypeError('Not JSON Serializable:', obj)

    from .. import __version__ as keras_version

    if not isinstance(filepath, h5py.File):
@@ -115,7 +312,7 @@ def get_json_type(obj):
        f.attrs['model_config'] = json.dumps({
            'class_name': model.__class__.__name__,
            'config': model.get_config()
        }, default=get_json_type).encode('utf8')
        }, default=_get_json_type).encode('utf8')

        model_weights_group = f.create_group('model_weights')
        model_layers = model.layers
@@ -144,7 +341,7 @@ def get_json_type(obj):
                'metrics': model.metrics,
                'sample_weight_mode': model.sample_weight_mode,
                'loss_weights': model.loss_weights,
            }, default=get_json_type).encode('utf8')
            }, default=_get_json_type).encode('utf8')

            # Save optimizer weights.
            symbolic_weights = getattr(model.optimizer, 'weights')