@@ -369,6 +369,7 @@ def _parse_load_config(configs):
369369 'params_filename' ,
370370 'keep_name_table' ,
371371 'return_numpy' ,
372+ 'safetensors' ,
372373 ]
373374
374375 # input check
@@ -388,12 +389,13 @@ def _parse_load_config(configs):
388389 inner_config .params_filename = configs .get ('params_filename' , None )
389390 inner_config .keep_name_table = configs .get ('keep_name_table' , None )
390391 inner_config .return_numpy = configs .get ('return_numpy' , False )
392+ inner_config .safetensors = configs .get ('safetensors' , False )
391393
392394 return inner_config
393395
394396
395397def _parse_save_config (configs ):
396- supported_configs = ['use_binary_format' , 'pickle_protocol' ]
398+ supported_configs = ['use_binary_format' , 'pickle_protocol' , 'safetensors' ]
397399
398400 # input check
399401 for key in configs :
@@ -410,6 +412,7 @@ def _parse_save_config(configs):
410412 inner_config = _SaveLoadConfig ()
411413 inner_config .use_binary_format = configs .get ('use_binary_format' , False )
412414 inner_config .pickle_protocol = configs .get ('pickle_protocol' , None )
415+ inner_config .safetensors = configs .get ('safetensors' , False )
413416
414417 return inner_config
415418
@@ -956,14 +959,45 @@ def save(
956959
957960 elif _is_state_dict (obj ):
958961 if in_dygraph_mode ():
959- _legacy_save (obj , path , protocol )
962+ if config .safetensors :
963+ _safe_save (obj , path )
964+ else :
965+ _legacy_save (obj , path , protocol )
960966 else :
961967 _legacy_static_save (obj , path , protocol )
962968 else :
963969 with _open_file_buffer (path , 'wb' ) as f :
964970 _pickle_save (obj , f , protocol )
965971
966972
973+ def _safe_save (obj , path ):
974+ if not isinstance (obj , dict ):
975+ raise NotImplementedError (
976+ "Now only supports save state_dict of Layer or Optimizer, "
977+ f"expect dict, but received { type (obj )} ."
978+ )
979+
980+ if len (obj ) == 0 :
981+ warnings .warn ("The input state dict is empty, no need to save." )
982+
983+ if _is_file_path (path ):
984+ filename = os .path .basename (path )
985+ if filename == "" :
986+ raise ValueError (
987+ "The input path MUST be format of dirname/filename "
988+ "[dirname\\ filename in Windows system], but received "
989+ "filename is empty string."
990+ )
991+ # 2. save object
992+ dirname = os .path .dirname (path )
993+ if dirname and not os .path .exists (dirname ):
994+ os .makedirs (dirname , exist_ok = True )
995+
996+ from safetensors .paddle import save_file
997+
998+ save_file (obj , path )
999+
1000+
9671001def _legacy_save (obj , path , protocol = 2 ):
9681002 # 1. input check
9691003 if not isinstance (obj , dict ):
@@ -1190,6 +1224,11 @@ def load(path: str | BytesIO, **configs: Unpack[_LoadOptions]) -> Any:
11901224 config = _parse_load_config (configs )
11911225 exception_type = pickle .UnpicklingError
11921226 try :
1227+ if config .safetensors :
1228+ from safetensors .paddle import load_file
1229+
1230+ load_result = load_file (path )
1231+ return load_result
11931232 with _open_file_buffer (path , 'rb' ) as f :
11941233 # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
11951234 if (
@@ -1310,8 +1349,13 @@ def _legacy_load(path, **configs):
13101349
13111350 if os .path .isfile (path ) or _is_memory_buffer (path ):
13121351 # we think path is file means this file is created by paddle.save
1313- with _open_file_buffer (path , 'rb' ) as f :
1314- load_result = pickle .load (f , encoding = 'latin1' )
1352+ if config .safetensors :
1353+ from safetensors .paddle import load_file
1354+
1355+ load_result = load_file (path )
1356+ else :
1357+ with _open_file_buffer (path , 'rb' ) as f :
1358+ load_result = pickle .load (f , encoding = 'latin1' )
13151359 load_result = _pack_loaded_dict (load_result )
13161360 if (
13171361 not config .keep_name_table
0 commit comments