Skip to content

Commit

Permalink
Save load/save pickle protocol (#31044)
Browse files Browse the repository at this point in the history
* add default argument  for paddle.save/static.save

* edit documentation of

* Add comments for special processing for protocol=2 and protocol=3.

* Update python/paddle/fluid/io.py

Co-authored-by: lanxianghit <47554610+lanxianghit@users.noreply.github.com>

Co-authored-by: lanxianghit <47554610+lanxianghit@users.noreply.github.com>
  • Loading branch information
hbwx24 and lanxianghit committed Feb 23, 2021
1 parent 29543da commit 6d7ca4c
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 39 deletions.
86 changes: 52 additions & 34 deletions python/paddle/fluid/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1711,27 +1711,31 @@ def _exist(var):
load_vars(executor=executor, dirname=dirname, vars=var_list)


def _unpack_saved_dict(saved_obj):
def _unpack_saved_dict(saved_obj, protocol):
temp_saved_obj = {}
unpack_infor = {}
for key, value in saved_obj.items():
if isinstance(value, np.ndarray):
MAX_NUMBER_OF_ELEMENT = int((2**30 - 1) / value.dtype.itemsize)
num_element = np.prod(value.shape)
if num_element > MAX_NUMBER_OF_ELEMENT:
unpack_infor[key] = {}
unpack_infor[key]["OriginShape"] = value.shape
unpack_infor[key]["slices"] = []
value = value.flatten()
for i in range(
int(
math.ceil(num_element * 1.0 /
MAX_NUMBER_OF_ELEMENT))):
part_name = key + "@@." + str(i)
unpack_infor[key]["slices"].append(part_name)
temp_saved_obj[part_name] = value[
i * MAX_NUMBER_OF_ELEMENT:MAX_NUMBER_OF_ELEMENT * (i + 1
)]
# When pickle protocol=2 or protocol=3 the serialized object cannot be larger than 4G.
if 1 < protocol < 4:
if isinstance(saved_obj, dict):
for key, value in saved_obj.items():
if isinstance(value, np.ndarray):
MAX_NUMBER_OF_ELEMENT = int(
(2**30 - 1) / value.dtype.itemsize)
num_element = np.prod(value.shape)
if num_element > MAX_NUMBER_OF_ELEMENT:
unpack_infor[key] = {}
unpack_infor[key]["OriginShape"] = value.shape
unpack_infor[key]["slices"] = []
value = value.flatten()
for i in range(
int(
math.ceil(num_element * 1.0 /
MAX_NUMBER_OF_ELEMENT))):
part_name = key + "@@." + str(i)
unpack_infor[key]["slices"].append(part_name)
temp_saved_obj[part_name] = value[
i * MAX_NUMBER_OF_ELEMENT:MAX_NUMBER_OF_ELEMENT
* (i + 1)]

if unpack_infor:
for key, value in unpack_infor.items():
Expand All @@ -1744,21 +1748,24 @@ def _unpack_saved_dict(saved_obj):


def _pack_loaded_dict(load_obj):
unpack_info = 'UnpackBigParamInfor@@'
if unpack_info in load_obj:
removes = []
for key, value in load_obj[unpack_info].items():
slices = [load_obj[part] for part in value["slices"]]
load_obj[key] = np.concatenate(slices).reshape(value["OriginShape"])
removes += value["slices"]
for key in removes:
load_obj.pop(key)
load_obj.pop(unpack_info)
if isinstance(load_obj, dict):
unpack_info = 'UnpackBigParamInfor@@'
if unpack_info in load_obj:
removes = []
for key, value in load_obj[unpack_info].items():
slices = [load_obj[part] for part in value["slices"]]
load_obj[key] = np.concatenate(slices).reshape(value[
"OriginShape"])
removes += value["slices"]
for key in removes:
load_obj.pop(key)
load_obj.pop(unpack_info)

return load_obj


@static_only
def save(program, model_path):
def save(program, model_path, pickle_protocol=2):
"""
:api_attr: Static Graph
Expand All @@ -1771,6 +1778,8 @@ def save(program, model_path):
Args:
program(Program) : The program to saved.
model_path(str): the file prefix to save the program. The format is "dirname/file_prefix". If file_prefix is empty str. A exception will be raised
pickle_protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
Default: 2
Returns:
None
Expand Down Expand Up @@ -1799,6 +1808,14 @@ def save(program, model_path):
assert base_name != "", \
"The input model_path MUST be format of dirname/filename [dirname\\filename in Windows system], but received model_path is empty string."

if not isinstance(pickle_protocol, int):
raise ValueError("The 'protocol' MUST be `int`, but received {}".format(
type(pickle_protocol)))

if pickle_protocol < 2 or pickle_protocol > 4:
raise ValueError("Expected 1<'protocol'<5, but received protocol={}".
format(pickle_protocol))

dir_name = os.path.dirname(model_path)
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name)
Expand All @@ -1809,26 +1826,27 @@ def get_tensor(var):

parameter_list = list(filter(is_parameter, program.list_vars()))
param_dict = {p.name: get_tensor(p) for p in parameter_list}
param_dict = _unpack_saved_dict(param_dict)

param_dict = _unpack_saved_dict(param_dict, pickle_protocol)

# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6'
if sys.platform == 'darwin' and sys.version_info.major == 3 and (
sys.version_info.minor == 5 or sys.version_info.minor == 6):
pickle_bytes = pickle.dumps(param_dict, protocol=2)
pickle_bytes = pickle.dumps(param_dict, protocol=pickle_protocol)
with open(model_path + ".pdparams", 'wb') as f:
max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes])
else:
with open(model_path + ".pdparams", 'wb') as f:
pickle.dump(param_dict, f, protocol=2)
pickle.dump(param_dict, f, protocol=pickle_protocol)

optimizer_var_list = list(
filter(is_belong_to_optimizer, program.list_vars()))

opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list}
with open(model_path + ".pdopt", 'wb') as f:
pickle.dump(opt_dict, f, protocol=2)
pickle.dump(opt_dict, f, protocol=pickle_protocol)

main_program = program.clone()
program.desc.flush()
Expand Down
31 changes: 31 additions & 0 deletions python/paddle/fluid/tests/unittests/test_paddle_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import unittest
import numpy as np
import os
import sys

import paddle
import paddle.nn as nn
import paddle.optimizer as opt
Expand Down Expand Up @@ -100,6 +102,35 @@ def test_large_parameters_paddle_save(self):
self.assertTrue(np.array_equal(dict_load[key], value.numpy()))


class TestSaveLoadPickle(unittest.TestCase):
def test_pickle_protocol(self):
# create network
layer = LinearNet()
save_dict = layer.state_dict()

path = os.path.join("test_paddle_save_load_pickle_protocol",
"layer.pdparams")

with self.assertRaises(ValueError):
paddle.save(save_dict, path, 2.0)

with self.assertRaises(ValueError):
paddle.save(save_dict, path, 1)

with self.assertRaises(ValueError):
paddle.save(save_dict, path, 5)

protocols = [2, ]
if sys.version_info.major >= 3 and sys.version_info.minor >= 4:
protocols += [3, 4]
for protocol in protocols:
paddle.save(save_dict, path, protocol)
dict_load = paddle.load(path)
# compare results before and after saving
for key, value in save_dict.items():
self.assertTrue(np.array_equal(dict_load[key], value.numpy()))


class TestSaveLoad(unittest.TestCase):
def setUp(self):
# enable dygraph mode
Expand Down
65 changes: 65 additions & 0 deletions python/paddle/fluid/tests/unittests/test_static_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from __future__ import print_function
import sys

import unittest
import paddle
Expand Down Expand Up @@ -1444,6 +1445,70 @@ def test_ptb_rnn_cpu_float32(self):
])


class TestStaticSaveLoadPickle(unittest.TestCase):
def test_pickle_protocol(self):
# enable static mode
paddle.enable_static()

with new_program_scope():
# create network
x = paddle.static.data(
name="static_save_load_large_x",
shape=[None, 10],
dtype='float32')
z = paddle.static.nn.fc(x, 10, bias_attr=False)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
prog = paddle.static.default_main_program()

base_map = {}
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
# make sure all the paramerter or optimizer var have been update
self.assertTrue(np.sum(np.abs(t)) != 0)
base_map[var.name] = t

path = os.path.join("test_static_save_load_pickle",
"pickle_protocol")

with self.assertRaises(ValueError):
paddle.fluid.save(prog, path, 2.0)

with self.assertRaises(ValueError):
paddle.fluid.save(prog, path, 1)

with self.assertRaises(ValueError):
paddle.fluid.save(prog, path, 5)

protocols = [2, ]
if sys.version_info.major >= 3 and sys.version_info.minor >= 4:
protocols += [3, 4]
for protocol in protocols:
paddle.fluid.save(prog, path, protocol)
# set var to zero
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
ten = fluid.global_scope().find_var(
var.name).get_tensor()
ten.set(np.zeros_like(np.array(ten)), place)

new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
self.assertTrue(np.sum(np.abs(new_t)) == 0)

paddle.fluid.load(prog, path)

for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))


if __name__ == '__main__':
paddle.enable_static()
unittest.main()
23 changes: 18 additions & 5 deletions python/paddle/framework/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import six
import warnings
import sys
import numpy as np

import paddle

Expand Down Expand Up @@ -198,7 +199,7 @@ def _parse_load_config(configs):
return inner_config


def save(obj, path):
def save(obj, path, pickle_protocol=2):
'''
Save an object to the specified path.
Expand All @@ -218,6 +219,8 @@ def save(obj, path):
obj(Object) : The object to be saved.
path(str) : The path of the object to be saved.
If saved in the current directory, the input path string will be used as the file name.
pickle_protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
Default: 2
Returns:
None
Expand Down Expand Up @@ -254,26 +257,36 @@ def save(obj, path):
"[dirname\\filename in Windows system], but received "
"filename is empty string.")

if not isinstance(pickle_protocol, int):
raise ValueError("The 'protocol' MUST be `int`, but received {}".format(
type(pickle_protocol)))

if pickle_protocol < 2 or pickle_protocol > 4:
raise ValueError("Expected 1<'protocol'<5, but received protocol={}".
format(pickle_protocol))

# 2. save object
dirname = os.path.dirname(path)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname)

# TODO(chenweihang): supports save other object
saved_obj = _build_saved_state_dict(obj)
saved_obj = _unpack_saved_dict(saved_obj)
if isinstance(obj, dict):
saved_obj = _build_saved_state_dict(obj)

saved_obj = _unpack_saved_dict(saved_obj, pickle_protocol)

# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6'
if sys.platform == 'darwin' and sys.version_info.major == 3 and (
sys.version_info.minor == 5 or sys.version_info.minor == 6):
pickle_bytes = pickle.dumps(saved_obj, protocol=2)
pickle_bytes = pickle.dumps(saved_obj, protocol=pickle_protocol)
with open(path, 'wb') as f:
max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes])
else:
with open(path, 'wb') as f:
pickle.dump(saved_obj, f, protocol=2)
pickle.dump(saved_obj, f, protocol=pickle_protocol)


def load(path, **configs):
Expand Down

1 comment on commit 6d7ca4c

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.