Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

paddle.save/paddle.static.save 升级pickle的版本 #31044

Merged
merged 4 commits into from
Feb 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 52 additions & 34 deletions python/paddle/fluid/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1711,27 +1711,31 @@ def _exist(var):
load_vars(executor=executor, dirname=dirname, vars=var_list)


def _unpack_saved_dict(saved_obj):
def _unpack_saved_dict(saved_obj, protocol):
temp_saved_obj = {}
unpack_infor = {}
for key, value in saved_obj.items():
if isinstance(value, np.ndarray):
MAX_NUMBER_OF_ELEMENT = int((2**30 - 1) / value.dtype.itemsize)
num_element = np.prod(value.shape)
if num_element > MAX_NUMBER_OF_ELEMENT:
unpack_infor[key] = {}
unpack_infor[key]["OriginShape"] = value.shape
unpack_infor[key]["slices"] = []
value = value.flatten()
for i in range(
int(
math.ceil(num_element * 1.0 /
MAX_NUMBER_OF_ELEMENT))):
part_name = key + "@@." + str(i)
unpack_infor[key]["slices"].append(part_name)
temp_saved_obj[part_name] = value[
i * MAX_NUMBER_OF_ELEMENT:MAX_NUMBER_OF_ELEMENT * (i + 1
)]
# When pickle protocol=2 or protocol=3 the serialized object cannot be larger than 4G.
if 1 < protocol < 4:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

对特定版本的特殊处理逻辑最好注释解释一下,方便其他人理解以及后续维护,可以在后续pr再补充

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thx.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

if isinstance(saved_obj, dict):
for key, value in saved_obj.items():
if isinstance(value, np.ndarray):
MAX_NUMBER_OF_ELEMENT = int(
(2**30 - 1) / value.dtype.itemsize)
num_element = np.prod(value.shape)
if num_element > MAX_NUMBER_OF_ELEMENT:
unpack_infor[key] = {}
unpack_infor[key]["OriginShape"] = value.shape
unpack_infor[key]["slices"] = []
value = value.flatten()
for i in range(
int(
math.ceil(num_element * 1.0 /
MAX_NUMBER_OF_ELEMENT))):
part_name = key + "@@." + str(i)
unpack_infor[key]["slices"].append(part_name)
temp_saved_obj[part_name] = value[
i * MAX_NUMBER_OF_ELEMENT:MAX_NUMBER_OF_ELEMENT
* (i + 1)]

if unpack_infor:
for key, value in unpack_infor.items():
Expand All @@ -1744,21 +1748,24 @@ def _unpack_saved_dict(saved_obj):


def _pack_loaded_dict(load_obj):
unpack_info = 'UnpackBigParamInfor@@'
if unpack_info in load_obj:
removes = []
for key, value in load_obj[unpack_info].items():
slices = [load_obj[part] for part in value["slices"]]
load_obj[key] = np.concatenate(slices).reshape(value["OriginShape"])
removes += value["slices"]
for key in removes:
load_obj.pop(key)
load_obj.pop(unpack_info)
if isinstance(load_obj, dict):
unpack_info = 'UnpackBigParamInfor@@'
if unpack_info in load_obj:
removes = []
for key, value in load_obj[unpack_info].items():
slices = [load_obj[part] for part in value["slices"]]
load_obj[key] = np.concatenate(slices).reshape(value[
"OriginShape"])
removes += value["slices"]
for key in removes:
load_obj.pop(key)
load_obj.pop(unpack_info)

return load_obj


@static_only
def save(program, model_path):
def save(program, model_path, pickle_protocol=2):
"""
:api_attr: Static Graph
Expand All @@ -1771,6 +1778,8 @@ def save(program, model_path):
Args:
program(Program) : The program to be saved.
model_path(str): The file prefix used to save the program. The format is "dirname/file_prefix". If file_prefix is an empty string, an exception will be raised.
pickle_protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
Default: 2
Returns:
None
Expand Down Expand Up @@ -1799,6 +1808,14 @@ def save(program, model_path):
assert base_name != "", \
"The input model_path MUST be format of dirname/filename [dirname\\filename in Windows system], but received model_path is empty string."

if not isinstance(pickle_protocol, int):
raise ValueError("The 'protocol' MUST be `int`, but received {}".format(
type(pickle_protocol)))

if pickle_protocol < 2 or pickle_protocol > 4:
raise ValueError("Expected 1<'protocol'<5, but received protocol={}".
format(pickle_protocol))

dir_name = os.path.dirname(model_path)
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name)
Expand All @@ -1809,26 +1826,27 @@ def get_tensor(var):

parameter_list = list(filter(is_parameter, program.list_vars()))
param_dict = {p.name: get_tensor(p) for p in parameter_list}
param_dict = _unpack_saved_dict(param_dict)

param_dict = _unpack_saved_dict(param_dict, pickle_protocol)

# When a value of the dict is larger than 4GB, there is a bug on 'MAC python3.5/6'
if sys.platform == 'darwin' and sys.version_info.major == 3 and (
sys.version_info.minor == 5 or sys.version_info.minor == 6):
pickle_bytes = pickle.dumps(param_dict, protocol=2)
pickle_bytes = pickle.dumps(param_dict, protocol=pickle_protocol)
with open(model_path + ".pdparams", 'wb') as f:
max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes])
else:
with open(model_path + ".pdparams", 'wb') as f:
pickle.dump(param_dict, f, protocol=2)
pickle.dump(param_dict, f, protocol=pickle_protocol)

optimizer_var_list = list(
filter(is_belong_to_optimizer, program.list_vars()))

opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list}
with open(model_path + ".pdopt", 'wb') as f:
pickle.dump(opt_dict, f, protocol=2)
pickle.dump(opt_dict, f, protocol=pickle_protocol)

main_program = program.clone()
program.desc.flush()
Expand Down
31 changes: 31 additions & 0 deletions python/paddle/fluid/tests/unittests/test_paddle_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import unittest
import numpy as np
import os
import sys

import paddle
import paddle.nn as nn
import paddle.optimizer as opt
Expand Down Expand Up @@ -100,6 +102,35 @@ def test_large_parameters_paddle_save(self):
self.assertTrue(np.array_equal(dict_load[key], value.numpy()))


class TestSaveLoadPickle(unittest.TestCase):
    """Exercise the ``pickle_protocol`` argument of ``paddle.save``."""

    def test_pickle_protocol(self):
        """Invalid protocols raise ValueError; valid ones round-trip."""
        # create network
        layer = LinearNet()
        save_dict = layer.state_dict()

        path = os.path.join("test_paddle_save_load_pickle_protocol",
                            "layer.pdparams")

        # A non-int protocol must be rejected.
        with self.assertRaises(ValueError):
            paddle.save(save_dict, path, 2.0)

        # Protocols below 2 must be rejected.
        with self.assertRaises(ValueError):
            paddle.save(save_dict, path, 1)

        # Protocols above 4 must be rejected.
        with self.assertRaises(ValueError):
            paddle.save(save_dict, path, 5)

        protocols = [2]
        # Compare the version as a tuple: checking major and minor
        # independently would wrongly exclude any future X.0 - X.3 release.
        if sys.version_info >= (3, 4):
            protocols += [3, 4]

        for protocol in protocols:
            paddle.save(save_dict, path, protocol)
            dict_load = paddle.load(path)
            # compare results before and after saving
            for key, value in save_dict.items():
                self.assertTrue(
                    np.array_equal(dict_load[key], value.numpy()))


class TestSaveLoad(unittest.TestCase):
def setUp(self):
# enable dygraph mode
Expand Down
65 changes: 65 additions & 0 deletions python/paddle/fluid/tests/unittests/test_static_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from __future__ import print_function
import sys

import unittest
import paddle
Expand Down Expand Up @@ -1452,6 +1453,70 @@ def test_ptb_rnn_cpu_float32(self):
])


class TestStaticSaveLoadPickle(unittest.TestCase):
    """Exercise the ``pickle_protocol`` argument of ``paddle.fluid.save``."""

    def _persistable_vars(self, prog):
        """Return the variables of ``prog`` that the test saves and restores."""
        return [
            var for var in prog.list_vars()
            if isinstance(var, framework.Parameter) or var.persistable
        ]

    def _read_tensor(self, name):
        """Return the tensor called ``name`` in the global scope as an ndarray."""
        return np.array(fluid.global_scope().find_var(name).get_tensor())

    def test_pickle_protocol(self):
        """Invalid protocols raise ValueError; valid ones round-trip."""
        # enable static mode
        paddle.enable_static()

        with new_program_scope():
            # create network; the fc layer is added to the default program,
            # its output variable is not needed by the test
            x = paddle.static.data(
                name="static_save_load_large_x",
                shape=[None, 10],
                dtype='float32')
            paddle.static.nn.fc(x, 10, bias_attr=False)
            place = paddle.CPUPlace()
            exe = paddle.static.Executor(place)
            exe.run(paddle.static.default_startup_program())
            prog = paddle.static.default_main_program()

            base_map = {}
            for var in self._persistable_vars(prog):
                t = self._read_tensor(var.name)
                # make sure all the parameter or optimizer vars have been
                # updated
                self.assertTrue(np.sum(np.abs(t)) != 0)
                base_map[var.name] = t

            path = os.path.join("test_static_save_load_pickle",
                                "pickle_protocol")

            # A non-int protocol must be rejected.
            with self.assertRaises(ValueError):
                paddle.fluid.save(prog, path, 2.0)

            # Protocols below 2 must be rejected.
            with self.assertRaises(ValueError):
                paddle.fluid.save(prog, path, 1)

            # Protocols above 4 must be rejected.
            with self.assertRaises(ValueError):
                paddle.fluid.save(prog, path, 5)

            protocols = [2]
            # Compare the version as a tuple: checking major and minor
            # independently would wrongly exclude any future X.0 - X.3
            # release.
            if sys.version_info >= (3, 4):
                protocols += [3, 4]

            for protocol in protocols:
                paddle.fluid.save(prog, path, protocol)

                # set var to zero
                for var in self._persistable_vars(prog):
                    ten = fluid.global_scope().find_var(
                        var.name).get_tensor()
                    ten.set(np.zeros_like(np.array(ten)), place)

                    new_t = self._read_tensor(var.name)
                    self.assertTrue(np.sum(np.abs(new_t)) == 0)

                paddle.fluid.load(prog, path)

                # every value must match what it was before saving
                for var in self._persistable_vars(prog):
                    new_t = self._read_tensor(var.name)
                    base_t = base_map[var.name]
                    self.assertTrue(np.array_equal(new_t, base_t))


if __name__ == '__main__':
paddle.enable_static()
unittest.main()
23 changes: 18 additions & 5 deletions python/paddle/framework/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import six
import warnings
import sys
import numpy as np

import paddle

Expand Down Expand Up @@ -198,7 +199,7 @@ def _parse_load_config(configs):
return inner_config


def save(obj, path):
def save(obj, path, pickle_protocol=2):
'''
Save an object to the specified path.
Expand All @@ -218,6 +219,8 @@ def save(obj, path):
obj(Object) : The object to be saved.
path(str) : The path of the object to be saved.
If saved in the current directory, the input path string will be used as the file name.
pickle_protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
Default: 2
Returns:
None
Expand Down Expand Up @@ -254,26 +257,36 @@ def save(obj, path):
"[dirname\\filename in Windows system], but received "
"filename is empty string.")

if not isinstance(pickle_protocol, int):
raise ValueError("The 'protocol' MUST be `int`, but received {}".format(
type(pickle_protocol)))

if pickle_protocol < 2 or pickle_protocol > 4:
raise ValueError("Expected 1<'protocol'<5, but received protocol={}".
format(pickle_protocol))

# 2. save object
dirname = os.path.dirname(path)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname)

# TODO(chenweihang): supports save other object
saved_obj = _build_saved_state_dict(obj)
saved_obj = _unpack_saved_dict(saved_obj)
if isinstance(obj, dict):
saved_obj = _build_saved_state_dict(obj)

saved_obj = _unpack_saved_dict(saved_obj, pickle_protocol)

# When a value of the dict is larger than 4GB, there is a bug on 'MAC python3.5/6'
if sys.platform == 'darwin' and sys.version_info.major == 3 and (
sys.version_info.minor == 5 or sys.version_info.minor == 6):
pickle_bytes = pickle.dumps(saved_obj, protocol=2)
pickle_bytes = pickle.dumps(saved_obj, protocol=pickle_protocol)
with open(path, 'wb') as f:
max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes])
else:
with open(path, 'wb') as f:
pickle.dump(saved_obj, f, protocol=2)
pickle.dump(saved_obj, f, protocol=pickle_protocol)


def load(path, **configs):
Expand Down