-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add checkpoint util class and implement #10532
Changes from 51 commits
568a329
1fabbba
77c6b71
b81671e
e21a72d
2a05b3d
87a0856
dc534fc
802d10c
d1bd3fd
5e74db3
a1419f1
461d2fc
2f4c039
38596cf
ce1bcc9
3c82006
f04b23a
c80125f
2e25e73
30b50dc
0334d49
d081256
886897c
9cf47af
c6f042f
b677d82
744e95d
955c793
3dd2746
4220b31
6d53dce
8430c8d
7b6c0ab
f9d4b9d
a4fd375
f688652
821acdb
eff92d0
cd98f2b
dbd0237
22df4c2
d98480c
ee91e48
b6ee59a
e130bf3
5451c78
01975ec
ed2129c
be05056
06aa23b
2412dee
e901de6
27b7175
9d98534
d96b442
192f9a5
cf3fb24
2c47e06
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,21 +13,17 @@ | |
# limitations under the License. | ||
|
||
import os | ||
import time | ||
import shutil | ||
|
||
from paddle.fluid.evaluator import Evaluator | ||
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable | ||
from . import core | ||
|
||
__all__ = [ | ||
'save_vars', | ||
'save_params', | ||
'save_persistables', | ||
'load_vars', | ||
'load_params', | ||
'load_persistables', | ||
'save_inference_model', | ||
'load_inference_model', | ||
'get_inference_program', | ||
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params', | ||
'load_persistables', 'save_inference_model', 'load_inference_model', | ||
'get_inference_program', 'save_checkpoint', 'restore_checkpoint' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe it's better to be named as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
] | ||
|
||
|
||
|
@@ -195,6 +191,8 @@ def load_vars(executor, | |
load_var_map = {} | ||
for each_var in vars: | ||
assert isinstance(each_var, Variable) | ||
if each_var.type == core.VarDesc.VarType.RAW: | ||
continue | ||
new_var = _clone_var_in_block_(load_block, each_var) | ||
if filename is None: | ||
load_block.append_op( | ||
|
@@ -454,3 +452,149 @@ def get_parameter_value_by_name(name, executor, program=None): | |
program = default_main_program() | ||
var = program.global_block().var(name) | ||
return get_parameter_value(var, executor) | ||
|
||
|
||
SUCCESS = "_SUCCESS" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. SUCCESS = SUCCESS_MARK_FILENAME There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
BEGIN_SECS = None | ||
|
||
|
||
def save_checkpoint(executor, | ||
dirname, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can use current working directory as default? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
keep_max=3, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. keep_max => max_num_checkpoints There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
save_secs=600, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. save_secs => interval There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
main_program=None): | ||
""" | ||
Save Variables to Checkpint Dir | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Checkpint => Checkpoint There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
:param dirname | ||
:param keep_max | ||
:param save_secs | ||
:param main_program | ||
""" | ||
if dirname is None: | ||
raise Exception("save checkpoint dir can not be none") | ||
|
||
if not os.path.isdir(dirname): | ||
os.makedirs(dirname) | ||
|
||
global BEGIN_SECS | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. try not to use global please There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
if BEGIN_SECS is not None: | ||
if time.time() - BEGIN_SECS < save_secs: | ||
return | ||
BEGIN_SECS = time.time() | ||
|
||
serial = _get_lastest_checkpoint_dir(dirname) + 1 | ||
cur_dir = os.path.join(dirname, str(serial)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The checkpoint directories will be saved as "1", "2", etc., which may not make sense to users. I think it would be better to name them like "checkpoint_1", "checkpoint_2", and the "_SUCCESS" file could store a timestamp of when the checkpoint was saved. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
# save_persistables(executor, cur_dir, main_program) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No commented-out code, please. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
save_vars( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. save_persistables can not filter out gradient variables. |
||
executor, | ||
dirname=cur_dir, | ||
main_program=main_program, | ||
vars=None, | ||
predicate=is_checkpoint_var, | ||
filename=None) | ||
_write_success(cur_dir) | ||
_lru_delete(dirname, keep_max) | ||
|
||
|
||
def restore_checkpoint(dirname, executor, main_program=None):
    """
    Load Variables from Checkpoint Dir

    Asks `_get_lastest_checkpoint_dir` for the latest checkpoint serial
    under `dirname` and, when one is found, loads every variable accepted
    by `is_checkpoint_var` from that serial's sub-directory via `load_vars`.
    Returns without loading anything when the serial is negative.

    :param dirname: root directory that holds the per-serial checkpoint dirs
    :param executor: executor forwarded to `load_vars`
    :param main_program: program forwarded unchanged to `load_vars`
    :raises Exception: if `dirname` is None or is not an existing directory
    """
    # The guard has to fire when dirname is missing OR is not a directory.
    # Written as `is None and isdir(...)` it could never raise this error:
    # None was handed to os.path.isdir(), and any real path made it False.
    if dirname is None or not os.path.isdir(dirname):
        raise Exception("restore checkpoint can not load variables from %s" %
                        dirname)

    serial = _get_lastest_checkpoint_dir(dirname)
    if serial < 0:
        # No usable checkpoint was found; nothing to restore.
        return

    cur_dir = os.path.join(dirname, str(serial))
    load_vars(
        executor,
        dirname=cur_dir,
        main_program=main_program,
        predicate=is_checkpoint_var,
        filename=None)
|
||
|
||
def is_checkpoint_var(var):
    """
    Predicate deciding whether `var` is handled by checkpoint save/load.

    Rejects variables whose desc type is FEED_MINIBATCH, FETCH_LIST or RAW,
    and variables whose name ends with "@GRAD"; for every other variable
    the answer is its `persistable` attribute.

    :param var: variable exposing `desc.type()`, `name` and `persistable`
    """
    var_types = core.VarDesc.VarType
    current_type = var.desc.type()

    # Variable kinds that are never written to / read from a checkpoint.
    for excluded in (var_types.FEED_MINIBATCH, var_types.FETCH_LIST,
                     var_types.RAW):
        if current_type == excluded:
            return False

    # Names carrying the gradient suffix are skipped as well.
    if var.name.endswith("@GRAD"):
        return False

    return var.persistable
|
||
|
||
def _lru_delete(dirname, keep_max=3): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. keep_max => max_num_checkpoints There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
""" | ||
retain checkpoint nums with keep_max | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. keep_max => max_num_checkpoints There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
""" | ||
dirs = os.listdir(dirname) | ||
serials = [] | ||
for serial in dirs: | ||
try: | ||
serials.append(int(serial)) | ||
except ValueError: | ||
continue | ||
|
||
if len(serials) <= keep_max: | ||
return | ||
|
||
serials.sort(reverse=True) | ||
serials = serials[keep_max:] | ||
for serial in serials: | ||
cur_dir = os.path.join(dirname, str(serial)) | ||
shutil.rmtree(cur_dir) | ||
|
||
|
||
def _write_success(dirname): | ||
""" | ||
write _SUCCESS to checkpoint dir | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. update There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
""" | ||
success_file = os.path.join(dirname, SUCCESS) | ||
with open(success_file, 'a'): | ||
pass | ||
|
||
|
||
def _get_lastest_checkpoint_dir(checkpoint_dir): | ||
""" | ||
get the biggest number in checkpoint_dir, which has _SUCCESS | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. update _SUCCESS There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
""" | ||
if not checkpoint_dir.strip(): | ||
return -1 | ||
|
||
def has_success(checkpoint_dir, cur_dir): | ||
""" | ||
is _SUCCESS in this dir | ||
""" | ||
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)): | ||
return -1 | ||
|
||
try: | ||
int(cur_dir) | ||
except ValueError: | ||
return -1 | ||
|
||
success_path = os.path.join(checkpoint_dir, cur_dir, SUCCESS) | ||
if os.path.isfile(success_path): | ||
return int(cur_dir) | ||
|
||
if not os.path.isdir(checkpoint_dir): | ||
return -1 | ||
|
||
current_dir = -1 | ||
dirs = os.listdir(checkpoint_dir) | ||
for cur_dir in dirs: | ||
success_num = has_success(checkpoint_dir, cur_dir) | ||
if success_num > current_dir: | ||
current_dir = success_num | ||
return current_dir |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change may not be necessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am sorry about it, I will revert it later.