Skip to content

Commit 9577126

Browse files
[Safetensors]Add safetensors to paddle save/load (#74609)
* add safetensors to paddle save/load * add dependency * fix dcu bug * fix requirements * fix
1 parent 7cd2789 commit 9577126

File tree

3 files changed

+76
-4
lines changed

3 files changed

+76
-4
lines changed

python/paddle/framework/io.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ def _parse_load_config(configs):
369369
'params_filename',
370370
'keep_name_table',
371371
'return_numpy',
372+
'safetensors',
372373
]
373374

374375
# input check
@@ -388,12 +389,13 @@ def _parse_load_config(configs):
388389
inner_config.params_filename = configs.get('params_filename', None)
389390
inner_config.keep_name_table = configs.get('keep_name_table', None)
390391
inner_config.return_numpy = configs.get('return_numpy', False)
392+
inner_config.safetensors = configs.get('safetensors', False)
391393

392394
return inner_config
393395

394396

395397
def _parse_save_config(configs):
396-
supported_configs = ['use_binary_format', 'pickle_protocol']
398+
supported_configs = ['use_binary_format', 'pickle_protocol', 'safetensors']
397399

398400
# input check
399401
for key in configs:
@@ -410,6 +412,7 @@ def _parse_save_config(configs):
410412
inner_config = _SaveLoadConfig()
411413
inner_config.use_binary_format = configs.get('use_binary_format', False)
412414
inner_config.pickle_protocol = configs.get('pickle_protocol', None)
415+
inner_config.safetensors = configs.get('safetensors', False)
413416

414417
return inner_config
415418

@@ -956,14 +959,45 @@ def save(
956959

957960
elif _is_state_dict(obj):
958961
if in_dygraph_mode():
959-
_legacy_save(obj, path, protocol)
962+
if config.safetensors:
963+
_safe_save(obj, path)
964+
else:
965+
_legacy_save(obj, path, protocol)
960966
else:
961967
_legacy_static_save(obj, path, protocol)
962968
else:
963969
with _open_file_buffer(path, 'wb') as f:
964970
_pickle_save(obj, f, protocol)
965971

966972

973+
def _safe_save(obj, path):
974+
if not isinstance(obj, dict):
975+
raise NotImplementedError(
976+
"Now only supports save state_dict of Layer or Optimizer, "
977+
f"expect dict, but received {type(obj)}."
978+
)
979+
980+
if len(obj) == 0:
981+
warnings.warn("The input state dict is empty, no need to save.")
982+
983+
if _is_file_path(path):
984+
filename = os.path.basename(path)
985+
if filename == "":
986+
raise ValueError(
987+
"The input path MUST be format of dirname/filename "
988+
"[dirname\\filename in Windows system], but received "
989+
"filename is empty string."
990+
)
991+
# 2. save object
992+
dirname = os.path.dirname(path)
993+
if dirname and not os.path.exists(dirname):
994+
os.makedirs(dirname, exist_ok=True)
995+
996+
from safetensors.paddle import save_file
997+
998+
save_file(obj, path)
999+
1000+
9671001
def _legacy_save(obj, path, protocol=2):
9681002
# 1. input check
9691003
if not isinstance(obj, dict):
@@ -1190,6 +1224,11 @@ def load(path: str | BytesIO, **configs: Unpack[_LoadOptions]) -> Any:
11901224
config = _parse_load_config(configs)
11911225
exception_type = pickle.UnpicklingError
11921226
try:
1227+
if config.safetensors:
1228+
from safetensors.paddle import load_file
1229+
1230+
load_result = load_file(path)
1231+
return load_result
11931232
with _open_file_buffer(path, 'rb') as f:
11941233
# When the value of a dict is larger than 4GB, there is a bug on 'MAC python3'
11951234
if (
@@ -1310,8 +1349,13 @@ def _legacy_load(path, **configs):
13101349

13111350
if os.path.isfile(path) or _is_memory_buffer(path):
13121351
# we think path is file means this file is created by paddle.save
1313-
with _open_file_buffer(path, 'rb') as f:
1314-
load_result = pickle.load(f, encoding='latin1')
1352+
if config.safetensors:
1353+
from safetensors.paddle import load_file
1354+
1355+
load_result = load_file(path)
1356+
else:
1357+
with _open_file_buffer(path, 'rb') as f:
1358+
load_result = pickle.load(f, encoding='latin1')
13151359
load_result = _pack_loaded_dict(load_result)
13161360
if (
13171361
not config.keep_name_table

python/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ Pillow
55
opt_einsum==3.3.0
66
networkx
77
typing_extensions
8+
safetensors>=0.6.0

test/legacy_test/test_paddle_save_load.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,33 @@ def test_pickle_protocol(self):
161161
)
162162

163163

164+
# class TestSaveLoadSafetensors(unittest.TestCase):
165+
# def setUp(self):
166+
# self.temp_dir = tempfile.TemporaryDirectory()
167+
168+
# def tearDown(self):
169+
# self.temp_dir.cleanup()
170+
171+
# def test_safetensors(self):
172+
# # enable dygraph mode
173+
# paddle.disable_static()
174+
# # create network
175+
# layer = LinearNet()
176+
# save_dict = layer.state_dict()
177+
178+
# path = os.path.join(
179+
# self.temp_dir.name,
180+
# "test_paddle_save_load_safetensors",
181+
# "layer.safetensors",
182+
# )
183+
184+
# paddle.save(save_dict, path, safetensors=True)
185+
# dict_load = paddle.load(path, safetensors=True)
186+
# # compare results before and after saving
187+
# for key, value in save_dict.items():
188+
# np.testing.assert_array_equal(dict_load[key].numpy(), value.numpy())
189+
190+
164191
class TestSaveLoadAny(unittest.TestCase):
165192
def setUp(self):
166193
self.temp_dir = tempfile.TemporaryDirectory()

0 commit comments

Comments
 (0)