-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add checkpoint util class and implement #10532
Changes from 54 commits
568a329
1fabbba
77c6b71
b81671e
e21a72d
2a05b3d
87a0856
dc534fc
802d10c
d1bd3fd
5e74db3
a1419f1
461d2fc
2f4c039
38596cf
ce1bcc9
3c82006
f04b23a
c80125f
2e25e73
30b50dc
0334d49
d081256
886897c
9cf47af
c6f042f
b677d82
744e95d
955c793
3dd2746
4220b31
6d53dce
8430c8d
7b6c0ab
f9d4b9d
a4fd375
f688652
821acdb
eff92d0
cd98f2b
dbd0237
22df4c2
d98480c
ee91e48
b6ee59a
e130bf3
5451c78
01975ec
ed2129c
be05056
06aa23b
2412dee
e901de6
27b7175
9d98534
d96b442
192f9a5
cf3fb24
2c47e06
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,21 +13,17 @@ | |
# limitations under the License. | ||
|
||
import os | ||
import time | ||
import shutil | ||
|
||
from paddle.fluid.evaluator import Evaluator | ||
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable | ||
from . import core | ||
|
||
# Names exported by `from <module> import *`. Each public entry point is
# listed exactly once; save_checkpoint / load_checkpoint are the checkpoint
# helpers exported alongside the existing save/load functions.
__all__ = [
    'save_vars',
    'save_params',
    'save_persistables',
    'load_vars',
    'load_params',
    'load_persistables',
    'save_inference_model',
    'load_inference_model',
    'get_inference_program',
    'save_checkpoint',
    'load_checkpoint',
]
|
||
|
||
|
@@ -195,6 +191,8 @@ def load_vars(executor, | |
load_var_map = {} | ||
for each_var in vars: | ||
assert isinstance(each_var, Variable) | ||
if each_var.type == core.VarDesc.VarType.RAW: | ||
continue | ||
new_var = _clone_var_in_block_(load_block, each_var) | ||
if filename is None: | ||
load_block.append_op( | ||
|
@@ -454,3 +452,167 @@ def get_parameter_value_by_name(name, executor, program=None): | |
program = default_main_program() | ||
var = program.global_block().var(name) | ||
return get_parameter_value(var, executor) | ||
|
||
|
||
SUCCESS_MARK_FILENAME = "_SUCCESS" | ||
|
||
|
||
def save_checkpoint(executor,
                    dirname=None,
                    max_num_checkpoints=3,
                    save_interval_secs=600,
                    main_program=None):
    """
    Save the checkpoint variables of main_program into a new sub-directory
    of dirname, named by the next serial number (0, 1, 2, ...).

    Nothing is saved when the latest successful checkpoint directory was
    modified less than save_interval_secs seconds ago. After a save, an
    empty _SUCCESS marker is written into the new directory and the
    lowest-numbered checkpoint directories are removed so that no more than
    max_num_checkpoints are kept.

    :param executor: executor handed through to save_vars
    :param dirname: root checkpoint directory, created when it does not
        exist; the current working directory is used when None
    :param max_num_checkpoints: number of checkpoint directories to keep
    :param save_interval_secs: minimum number of seconds between two saved
        checkpoints
    :param main_program: program handed through to save_vars
    """
    checkpoint_root = os.getcwd() if dirname is None else dirname

    if not os.path.isdir(checkpoint_root):
        os.makedirs(checkpoint_root)

    latest_serial = _get_lastest_checkpoint_dir(checkpoint_root)
    if latest_serial >= 0:
        latest_dir = os.path.join(checkpoint_root, str(latest_serial))
        # Too soon after the previous checkpoint: skip this save entirely.
        if not _interval_secs_exceed(latest_dir, save_interval_secs):
            return

    # NOTE(review): a reviewer suggested naming the directories
    # "checkpoint_<n>" and storing a timestamp in the _SUCCESS file; this
    # revision uses the bare serial number and an empty marker.
    target_dir = os.path.join(checkpoint_root, str(latest_serial + 1))

    # save_vars with a predicate is used instead of save_persistables
    # because save_persistables cannot filter out gradient variables.
    save_vars(
        executor,
        dirname=target_dir,
        main_program=main_program,
        vars=None,
        predicate=_is_checkpoint_var,
        filename=None)

    # The marker is written only after save_vars has returned, so a
    # directory without it is not treated as a finished checkpoint.
    _write_success(target_dir)
    _lru_delete(checkpoint_root, max_num_checkpoints)
|
||
|
||
def load_checkpoint(executor, dirname=None, main_program=None):
    """
    Load the most recently saved checkpoint found under one directory.

    The sub-directory with the highest serial number that holds a _SUCCESS
    marker is selected automatically. When no such sub-directory exists the
    function returns without loading anything.

    :param executor: executor handed through to load_vars
    :param dirname: root checkpoint directory; the current working
        directory is used when None
    :param main_program: program handed through to load_vars
    """
    checkpoint_root = os.getcwd() if dirname is None else dirname

    latest_serial = _get_lastest_checkpoint_dir(checkpoint_root)

    # A negative serial means no successful checkpoint was found.
    if latest_serial >= 0:
        load_vars(
            executor,
            dirname=os.path.join(checkpoint_root, str(latest_serial)),
            main_program=main_program,
            predicate=_is_checkpoint_var,
            filename=None)
|
||
|
||
def _is_checkpoint_var(var):
    """
    Decide whether a variable is saved to / loaded from a checkpoint.

    Variables of type FEED_MINIBATCH, FETCH_LIST or RAW, and variables whose
    name ends with "@GRAD", are never part of a checkpoint. Every other
    variable is included exactly when it is persistable.

    :param var: the variable to examine
    """
    excluded_types = (
        core.VarDesc.VarType.FEED_MINIBATCH,
        core.VarDesc.VarType.FETCH_LIST,
        core.VarDesc.VarType.RAW, )
    if var.desc.type() in excluded_types:
        return False

    # Gradient variables are skipped even when they are persistable.
    return not var.name.endswith("@GRAD") and var.persistable
|
||
|
||
def _interval_secs_exceed(dirname, save_interval_secs):
    """
    Tell whether at least save_interval_secs seconds have passed since
    dirname was last modified.

    :param dirname: path whose modification time is inspected
    :param save_interval_secs: required minimum age in seconds
    """
    modified_at = os.path.getmtime(dirname)
    return not (save_interval_secs > time.time() - modified_at)
|
||
|
||
def _lru_delete(dirname, max_num_checkpoints=3):
    """
    Remove old checkpoint directories so that only the max_num_checkpoints
    highest-numbered ones remain under dirname.

    Only sub-directories whose name parses as an integer are considered;
    regular files and non-numeric names are left untouched. Each directory
    is removed through its real on-disk name, so names such as "007" are
    handled correctly.

    :param dirname: root checkpoint directory
    :param max_num_checkpoints: number of checkpoint directories to keep
    """
    numbered_dirs = []
    for entry in os.listdir(dirname):
        entry_path = os.path.join(dirname, entry)
        # A numerically named regular file is not a checkpoint, and
        # shutil.rmtree would fail on it.
        if not os.path.isdir(entry_path):
            continue
        try:
            # Keep the original path next to the serial: str(int(entry))
            # is not always the entry's name (e.g. "007" -> "7").
            numbered_dirs.append((int(entry), entry_path))
        except ValueError:
            continue

    if len(numbered_dirs) <= max_num_checkpoints:
        return

    # Newest (highest serial) first; everything past the quota is stale.
    numbered_dirs.sort(reverse=True)
    for _, stale_path in numbered_dirs[max_num_checkpoints:]:
        shutil.rmtree(stale_path)
|
||
|
||
def _write_success(dirname):
    """
    Create an empty _SUCCESS file inside a checkpoint directory to mark the
    checkpoint as correctly written.

    :param dirname: checkpoint directory that receives the marker file
    """
    marker_path = os.path.join(dirname, SUCCESS_MARK_FILENAME)
    # Append mode creates the file when missing and never truncates an
    # existing one; nothing is written to it.
    open(marker_path, 'a').close()
|
||
|
||
def _get_lastest_checkpoint_dir(checkpoint_dir):
    """
    Return the highest serial number among the sub-directories of
    checkpoint_dir that contain a _SUCCESS file.

    :param checkpoint_dir: root checkpoint directory
    :return: the serial number as an int, or -1 when checkpoint_dir is
        blank, is not a directory, or holds no successful checkpoint
    """
    if not checkpoint_dir.strip():
        return -1

    def has_success(checkpoint_dir, cur_dir):
        """
        Return the serial number of cur_dir when it is a numerically named
        directory containing a _SUCCESS file, otherwise -1.
        """
        if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
            return -1

        try:
            serial = int(cur_dir)
        except ValueError:
            return -1

        success_path = os.path.join(checkpoint_dir, cur_dir,
                                    SUCCESS_MARK_FILENAME)
        if os.path.isfile(success_path):
            return serial

        # Always return an int: an implicit None here would make the
        # comparison in the caller raise TypeError on Python 3.
        return -1

    if not os.path.isdir(checkpoint_dir):
        return -1

    current_dir = -1
    for cur_dir in os.listdir(checkpoint_dir):
        current_dir = max(current_dir, has_success(checkpoint_dir, cur_dir))
    return current_dir
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change may not be necessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am sorry about it, I will revert it later.