Skip to content

Commit

Permalink
add default argument for paddle.save/static.save
Browse files Browse the repository at this point in the history
  • Loading branch information
hbwx24 committed Feb 19, 2021
1 parent c137578 commit 5bdf1db
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 39 deletions.
86 changes: 52 additions & 34 deletions python/paddle/fluid/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1711,27 +1711,31 @@ def _exist(var):
load_vars(executor=executor, dirname=dirname, vars=var_list)


def _unpack_saved_dict(saved_obj):
def _unpack_saved_dict(saved_obj, protocol):
temp_saved_obj = {}
unpack_infor = {}
for key, value in saved_obj.items():
if isinstance(value, np.ndarray):
MAX_NUMBER_OF_ELEMENT = int((2**30 - 1) / value.dtype.itemsize)
num_element = np.prod(value.shape)
if num_element > MAX_NUMBER_OF_ELEMENT:
unpack_infor[key] = {}
unpack_infor[key]["OriginShape"] = value.shape
unpack_infor[key]["slices"] = []
value = value.flatten()
for i in range(
int(
math.ceil(num_element * 1.0 /
MAX_NUMBER_OF_ELEMENT))):
part_name = key + "@@." + str(i)
unpack_infor[key]["slices"].append(part_name)
temp_saved_obj[part_name] = value[
i * MAX_NUMBER_OF_ELEMENT:MAX_NUMBER_OF_ELEMENT * (i + 1
)]

if 1 < protocol < 4:
if isinstance(saved_obj, dict):
for key, value in saved_obj.items():
if isinstance(value, np.ndarray):
MAX_NUMBER_OF_ELEMENT = int(
(2**30 - 1) / value.dtype.itemsize)
num_element = np.prod(value.shape)
if num_element > MAX_NUMBER_OF_ELEMENT:
unpack_infor[key] = {}
unpack_infor[key]["OriginShape"] = value.shape
unpack_infor[key]["slices"] = []
value = value.flatten()
for i in range(
int(
math.ceil(num_element * 1.0 /
MAX_NUMBER_OF_ELEMENT))):
part_name = key + "@@." + str(i)
unpack_infor[key]["slices"].append(part_name)
temp_saved_obj[part_name] = value[
i * MAX_NUMBER_OF_ELEMENT:MAX_NUMBER_OF_ELEMENT
* (i + 1)]

if unpack_infor:
for key, value in unpack_infor.items():
Expand All @@ -1744,21 +1748,24 @@ def _unpack_saved_dict(saved_obj):


def _pack_loaded_dict(load_obj):
unpack_info = 'UnpackBigParamInfor@@'
if unpack_info in load_obj:
removes = []
for key, value in load_obj[unpack_info].items():
slices = [load_obj[part] for part in value["slices"]]
load_obj[key] = np.concatenate(slices).reshape(value["OriginShape"])
removes += value["slices"]
for key in removes:
load_obj.pop(key)
load_obj.pop(unpack_info)
if isinstance(load_obj, dict):
unpack_info = 'UnpackBigParamInfor@@'
if unpack_info in load_obj:
removes = []
for key, value in load_obj[unpack_info].items():
slices = [load_obj[part] for part in value["slices"]]
load_obj[key] = np.concatenate(slices).reshape(value[
"OriginShape"])
removes += value["slices"]
for key in removes:
load_obj.pop(key)
load_obj.pop(unpack_info)

return load_obj


@static_only
def save(program, model_path):
def save(program, model_path, pickle_protocol=2):
"""
:api_attr: Static Graph
Expand All @@ -1771,6 +1778,8 @@ def save(program, model_path):
Args:
program(Program) : The program to saved.
model_path(str): the file prefix to save the program. The format is "dirname/file_prefix". If file_prefix is empty str. A exception will be raised
pickle_protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
Default: None
Returns:
None
Expand Down Expand Up @@ -1799,6 +1808,14 @@ def save(program, model_path):
assert base_name != "", \
"The input model_path MUST be format of dirname/filename [dirname\\filename in Windows system], but received model_path is empty string."

if not isinstance(pickle_protocol, int):
raise ValueError("The 'protocol' MUST be `int`, but received {}".format(
type(pickle_protocol)))

if pickle_protocol < 2 or pickle_protocol > 4:
raise ValueError("Expected 1<'protocol'<5, but received protocol={}".
format(pickle_protocol))

dir_name = os.path.dirname(model_path)
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name)
Expand All @@ -1809,26 +1826,27 @@ def get_tensor(var):

parameter_list = list(filter(is_parameter, program.list_vars()))
param_dict = {p.name: get_tensor(p) for p in parameter_list}
param_dict = _unpack_saved_dict(param_dict)

param_dict = _unpack_saved_dict(param_dict, pickle_protocol)

# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6'
if sys.platform == 'darwin' and sys.version_info.major == 3 and (
sys.version_info.minor == 5 or sys.version_info.minor == 6):
pickle_bytes = pickle.dumps(param_dict, protocol=2)
pickle_bytes = pickle.dumps(param_dict, protocol=pickle_protocol)
with open(model_path + ".pdparams", 'wb') as f:
max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes])
else:
with open(model_path + ".pdparams", 'wb') as f:
pickle.dump(param_dict, f, protocol=2)
pickle.dump(param_dict, f, protocol=pickle_protocol)

optimizer_var_list = list(
filter(is_belong_to_optimizer, program.list_vars()))

opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list}
with open(model_path + ".pdopt", 'wb') as f:
pickle.dump(opt_dict, f, protocol=2)
pickle.dump(opt_dict, f, protocol=pickle_protocol)

main_program = program.clone()
program.desc.flush()
Expand Down
31 changes: 31 additions & 0 deletions python/paddle/fluid/tests/unittests/test_paddle_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import unittest
import numpy as np
import os
import sys

import paddle
import paddle.nn as nn
import paddle.optimizer as opt
Expand Down Expand Up @@ -100,6 +102,35 @@ def test_large_parameters_paddle_save(self):
self.assertTrue(np.array_equal(dict_load[key], value.numpy()))


class TestSaveLoadPickle(unittest.TestCase):
def test_pickle_protocol(self):
# create network
layer = LinearNet()
save_dict = layer.state_dict()

path = os.path.join("test_paddle_save_load_pickle_protocol",
"layer.pdparams")

with self.assertRaises(ValueError):
paddle.save(save_dict, path, 2.0)

with self.assertRaises(ValueError):
paddle.save(save_dict, path, 1)

with self.assertRaises(ValueError):
paddle.save(save_dict, path, 5)

protocols = [2, ]
if sys.version_info.major >= 3 and sys.version_info.minor >= 4:
protocols += [3, 4]
for protocol in protocols:
paddle.save(save_dict, path, protocol)
dict_load = paddle.load(path)
# compare results before and after saving
for key, value in save_dict.items():
self.assertTrue(np.array_equal(dict_load[key], value.numpy()))


class TestSaveLoad(unittest.TestCase):
def setUp(self):
# enable dygraph mode
Expand Down
65 changes: 65 additions & 0 deletions python/paddle/fluid/tests/unittests/test_static_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from __future__ import print_function
import sys

import unittest
import paddle
Expand Down Expand Up @@ -1452,6 +1453,70 @@ def test_ptb_rnn_cpu_float32(self):
])


class TestStaticSaveLoadPickle(unittest.TestCase):
def test_pickle_protocol(self):
# enable static mode
paddle.enable_static()

with new_program_scope():
# create network
x = paddle.static.data(
name="static_save_load_large_x",
shape=[None, 10],
dtype='float32')
z = paddle.static.nn.fc(x, 10, bias_attr=False)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
prog = paddle.static.default_main_program()

base_map = {}
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
# make sure all the paramerter or optimizer var have been update
self.assertTrue(np.sum(np.abs(t)) != 0)
base_map[var.name] = t

path = os.path.join("test_static_save_load_pickle",
"pickle_protocol")

with self.assertRaises(ValueError):
paddle.fluid.save(prog, path, 2.0)

with self.assertRaises(ValueError):
paddle.fluid.save(prog, path, 1)

with self.assertRaises(ValueError):
paddle.fluid.save(prog, path, 5)

protocols = [2, ]
if sys.version_info.major >= 3 and sys.version_info.minor >= 4:
protocols += [3, 4]
for protocol in protocols:
paddle.fluid.save(prog, path, protocol)
# set var to zero
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
ten = fluid.global_scope().find_var(
var.name).get_tensor()
ten.set(np.zeros_like(np.array(ten)), place)

new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
self.assertTrue(np.sum(np.abs(new_t)) == 0)

paddle.fluid.load(prog, path)

for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))


if __name__ == '__main__':
paddle.enable_static()
unittest.main()
23 changes: 18 additions & 5 deletions python/paddle/framework/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import six
import warnings
import sys
import numpy as np

import paddle

Expand Down Expand Up @@ -198,7 +199,7 @@ def _parse_load_config(configs):
return inner_config


def save(obj, path):
def save(obj, path, pickle_protocol=2):
'''
Save an object to the specified path.
Expand All @@ -218,6 +219,8 @@ def save(obj, path):
obj(Object) : The object to be saved.
path(str) : The path of the object to be saved.
If saved in the current directory, the input path string will be used as the file name.
pickle_protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
Default: None
Returns:
None
Expand Down Expand Up @@ -254,26 +257,36 @@ def save(obj, path):
"[dirname\\filename in Windows system], but received "
"filename is empty string.")

if not isinstance(pickle_protocol, int):
raise ValueError("The 'protocol' MUST be `int`, but received {}".format(
type(pickle_protocol)))

if pickle_protocol < 2 or pickle_protocol > 4:
raise ValueError("Expected 1<'protocol'<5, but received protocol={}".
format(pickle_protocol))

# 2. save object
dirname = os.path.dirname(path)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname)

# TODO(chenweihang): supports save other object
saved_obj = _build_saved_state_dict(obj)
saved_obj = _unpack_saved_dict(saved_obj)
if isinstance(obj, dict):
saved_obj = _build_saved_state_dict(obj)

saved_obj = _unpack_saved_dict(saved_obj, pickle_protocol)

# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6'
if sys.platform == 'darwin' and sys.version_info.major == 3 and (
sys.version_info.minor == 5 or sys.version_info.minor == 6):
pickle_bytes = pickle.dumps(saved_obj, protocol=2)
pickle_bytes = pickle.dumps(saved_obj, protocol=pickle_protocol)
with open(path, 'wb') as f:
max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes])
else:
with open(path, 'wb') as f:
pickle.dump(saved_obj, f, protocol=2)
pickle.dump(saved_obj, f, protocol=pickle_protocol)


def load(path, **configs):
Expand Down

0 comments on commit 5bdf1db

Please sign in to comment.