Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add checkpoint util class and implement #10532

Merged
merged 59 commits into from
May 23, 2018
Merged
Show file tree
Hide file tree
Changes from 54 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
568a329
add checkpoint util class and implement
seiriosPlus May 9, 2018
1fabbba
modify const to const &
seiriosPlus May 10, 2018
77c6b71
add ckpt to sync loop
seiriosPlus May 10, 2018
b81671e
add ckpt attr to pserver python config
seiriosPlus May 10, 2018
e21a72d
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into checkpoint
seiriosPlus May 11, 2018
2a05b3d
delete checkpoint function
seiriosPlus May 11, 2018
87a0856
add checkpoint save op
seiriosPlus May 11, 2018
dc534fc
add checkpoint save op test
seiriosPlus May 11, 2018
802d10c
rename cpkt_save_op
seiriosPlus May 11, 2018
d1bd3fd
add build and test make
seiriosPlus May 11, 2018
5e74db3
add build and test make
seiriosPlus May 11, 2018
a1419f1
test add op declare
seiriosPlus May 11, 2018
461d2fc
rename ckpt -> checkpoint
seiriosPlus May 14, 2018
2f4c039
rename, modify ckpt structure
seiriosPlus May 14, 2018
38596cf
move file_path to dir
seiriosPlus May 14, 2018
ce1bcc9
add op to framework.py
seiriosPlus May 14, 2018
3c82006
remove overwrite judge to test load
seiriosPlus May 14, 2018
f04b23a
add checkpoint_load, update checkpoint save
seiriosPlus May 15, 2018
c80125f
add checkpoint_load to python framework
seiriosPlus May 15, 2018
2e25e73
write checkpoint_load code simply
seiriosPlus May 15, 2018
30b50dc
fix Serial output type
seiriosPlus May 15, 2018
0334d49
fix bug
seiriosPlus May 15, 2018
d081256
add api in distribute transpiler
seiriosPlus May 16, 2018
886897c
load implement
seiriosPlus May 16, 2018
9cf47af
modify get trainer param
seiriosPlus May 16, 2018
c6f042f
modify load op
seiriosPlus May 16, 2018
b677d82
bug fix
seiriosPlus May 16, 2018
744e95d
add ckpt load
seiriosPlus May 16, 2018
955c793
add X to test
seiriosPlus May 16, 2018
3dd2746
modify Get -> GetMutable
seiriosPlus May 16, 2018
4220b31
update pserver startup
seiriosPlus May 16, 2018
6d53dce
optimized checkpoint serial number and folder
seiriosPlus May 17, 2018
8430c8d
remove boost filesystem
seiriosPlus May 17, 2018
7b6c0ab
modify variable point
seiriosPlus May 17, 2018
f9d4b9d
fix auto serial_num has no initializer
seiriosPlus May 17, 2018
a4fd375
bug fix
seiriosPlus May 18, 2018
f688652
bug fix
seiriosPlus May 18, 2018
821acdb
update op to trianer and pserver
seiriosPlus May 18, 2018
eff92d0
merge develop
seiriosPlus May 18, 2018
cd98f2b
bug fix
seiriosPlus May 18, 2018
dbd0237
fix serial number
seiriosPlus May 18, 2018
22df4c2
fix serial number
seiriosPlus May 18, 2018
d98480c
fix serial number
seiriosPlus May 18, 2018
ee91e48
fix serial number
seiriosPlus May 18, 2018
b6ee59a
optimize python checkpint dir config
seiriosPlus May 18, 2018
e130bf3
optimize python checkpint dir config
seiriosPlus May 18, 2018
5451c78
add checkpoint in io
seiriosPlus May 21, 2018
01975ec
add checkpoint in io
seiriosPlus May 21, 2018
ed2129c
revert distribute_transpiler.py
seiriosPlus May 21, 2018
be05056
delete old checkpoint code
seiriosPlus May 21, 2018
06aa23b
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into checkpoint
seiriosPlus May 21, 2018
2412dee
code optimized
seiriosPlus May 21, 2018
e901de6
update var name
seiriosPlus May 22, 2018
27b7175
update python annotation
seiriosPlus May 22, 2018
9d98534
update annotation grammar
seiriosPlus May 23, 2018
d96b442
rename checkpoint folder to checkpoint_serial
seiriosPlus May 23, 2018
192f9a5
bug fix
seiriosPlus May 23, 2018
cf3fb24
add clean checkpoint
seiriosPlus May 23, 2018
2c47e06
add clean checkpoint
seiriosPlus May 23, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/operators/listen_and_serv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
for (auto &var : sparse_vars) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change may not be necessary.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am sorry about it, I will revert it later.

rpc_service_->SetCond(1);
// FIXME(typhoonzero): use another condition to sync wait clients get.
rpc_service_->WaitClientGet(fan_in);
Expand Down
180 changes: 171 additions & 9 deletions python/paddle/fluid/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,17 @@
# limitations under the License.

import os
import time
import shutil

from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable
from . import core

__all__ = [
'save_vars',
'save_params',
'save_persistables',
'load_vars',
'load_params',
'load_persistables',
'save_inference_model',
'load_inference_model',
'get_inference_program',
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables', 'save_inference_model', 'load_inference_model',
'get_inference_program', 'save_checkpoint', 'load_checkpoint'
]


Expand Down Expand Up @@ -195,6 +191,8 @@ def load_vars(executor,
load_var_map = {}
for each_var in vars:
assert isinstance(each_var, Variable)
if each_var.type == core.VarDesc.VarType.RAW:
continue
new_var = _clone_var_in_block_(load_block, each_var)
if filename is None:
load_block.append_op(
Expand Down Expand Up @@ -454,3 +452,167 @@ def get_parameter_value_by_name(name, executor, program=None):
program = default_main_program()
var = program.global_block().var(name)
return get_parameter_value(var, executor)


# Name of the marker file: _write_success creates an empty file with this name
# inside a checkpoint directory, and _get_lastest_checkpoint_dir only accepts
# directories that contain it.
SUCCESS_MARK_FILENAME = "_SUCCESS"


def save_checkpoint(executor,
dirname=None,
max_num_checkpoints=3,
save_interval_secs=600,
main_program=None):
"""
Save the persistable LoDTensor variables of main_program into a checkpoint directory.
Each checkpoint is written to its own sub-directory named by an increasing serial number starting from 0.
At most max_num_checkpoints checkpoint directories are kept; once that limit is exceeded, the ones with the smallest serial numbers are deleted.
The interval between two saved checkpoints must be greater than or equal to save_interval_secs; otherwise the call returns without saving.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The interval between two saved checkpoints must be greater than save_interval_secs.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


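:param executor: executor used to run the save operations
:param dirname: root directory that holds the checkpoint sub-directories; defaults to the current working directory
:param max_num_checkpoints: maximum number of checkpoint directories to keep
:param save_interval_secs: minimum number of seconds between two saved checkpoints
:param main_program: program whose variables are saved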
"""
if dirname is None:
dirname = os.getcwd()

if not os.path.isdir(dirname):
os.makedirs(dirname)

serial = _get_lastest_checkpoint_dir(dirname)
if serial >= 0 and not _interval_secs_exceed(
os.path.join(dirname, str(serial)), save_interval_secs):
return

serial = serial + 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can use +=

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

cur_dir = os.path.join(dirname, str(serial))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The checkpoint directories will be saved in "1", "2", etc. which may not make sense to users, I think it's better to be like "checkpoint_1", "checkpoint_2", and the "_SUCCESS" file can save a timestamp when the checkpoint is saved.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


save_vars(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why call save_vars instead of save_persistables?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

save_persistables cannot filter out gradient variables.

executor,
dirname=cur_dir,
main_program=main_program,
vars=None,
predicate=_is_checkpoint_var,
filename=None)
_write_success(cur_dir)
_lru_delete(dirname, max_num_checkpoints)


def load_checkpoint(executor, dirname=None, main_program=None):
"""
Load a checkpoint from one directory using the executor;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

directory => one directory

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

it finds the most recently saved checkpoint and loads it automatically.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

latest => the most recent saved checkpoint


:param executor
:param dirname
:param main_program
"""

if dirname is None:
dirname = os.getcwd()

serial = _get_lastest_checkpoint_dir(dirname)

if serial < 0:
return
cur_dir = os.path.join(dirname, str(serial))

load_vars(
executor,
dirname=cur_dir,
main_program=main_program,
predicate=_is_checkpoint_var,
filename=None)


def _is_checkpoint_var(var):
    """
    Decide whether a variable takes part in checkpoint save/load.

    Variables whose type is FEED_MINIBATCH, FETCH_LIST or RAW, and variables
    whose name ends with "@GRAD", are always rejected. Every other variable
    is accepted only when it is persistable.

    :param var: the variable to examine
    :return: False for rejected variables, otherwise var.persistable
    """
    var_type = var.desc.type()
    skipped_types = (core.VarDesc.VarType.FEED_MINIBATCH,
                     core.VarDesc.VarType.FETCH_LIST,
                     core.VarDesc.VarType.RAW)
    type_is_skipped = any(var_type == skipped for skipped in skipped_types)

    # Gradient variables are recognised purely by their name suffix.
    if type_is_skipped or var.name.endswith("@GRAD"):
        return False
    return var.persistable


def _interval_secs_exceed(dirname, save_interval_secs):
    """
    Tell whether enough time has passed since dirname was last modified.

    :param dirname: path whose modification time is inspected
    :param save_interval_secs: required minimum age in seconds
    :return: True when the time elapsed since the modification time of
        dirname is at least save_interval_secs, False otherwise
    """
    last_modified = os.path.getmtime(dirname)
    elapsed = time.time() - last_modified
    return not (save_interval_secs > elapsed)


def _lru_delete(dirname, max_num_checkpoints=3):
    """
    Delete the oldest checkpoint directories under dirname so that at most
    max_num_checkpoints of them remain.

    An entry of dirname is treated as a checkpoint only when it is a
    directory and its name is the canonical decimal form of a non-negative
    integer ("0", "12", ...). Regular files and names that merely parse as
    an integer ("01", "+1", "-3") are neither counted nor deleted.
    Checkpoints with the largest serial numbers are the ones kept.

    :param dirname: root directory that holds the checkpoint sub-directories
    :param max_num_checkpoints: number of checkpoint directories to keep
    """
    serials = []
    for name in os.listdir(dirname):
        try:
            serial = int(name)
        except ValueError:
            continue

        # The directory is later addressed as str(serial); skip names for
        # which that round trip would point at a different (or no) entry.
        if serial < 0 or str(serial) != name:
            continue

        # A plain file with a numeric name is not a checkpoint: counting it
        # would evict a real checkpoint and rmtree would fail on it.
        if not os.path.isdir(os.path.join(dirname, name)):
            continue

        serials.append(serial)

    if len(serials) <= max_num_checkpoints:
        return

    serials.sort(reverse=True)
    for serial in serials[max_num_checkpoints:]:
        shutil.rmtree(os.path.join(dirname, str(serial)))


def _write_success(dirname):
    """
    Mark a checkpoint directory as successfully written.

    Opens the _SUCCESS marker file inside dirname in append mode and closes
    it again, so the file is created when missing and left unchanged when
    it already exists.

    :param dirname: the checkpoint directory to mark
    """
    marker_path = os.path.join(dirname, SUCCESS_MARK_FILENAME)
    marker = open(marker_path, 'a')
    marker.close()


def _get_lastest_checkpoint_dir(checkpoint_dir):
    """
    Find the latest complete checkpoint inside checkpoint_dir.

    A sub-directory counts as a complete checkpoint when its name is the
    canonical decimal form of an integer and it contains the _SUCCESS
    marker file.

    :param checkpoint_dir: root directory that holds the checkpoint
        sub-directories
    :return: the largest serial number among the complete checkpoints, or
        -1 when checkpoint_dir is blank, is not a directory, or holds no
        complete checkpoint
    """
    if not checkpoint_dir.strip():
        return -1

    def has_success(checkpoint_dir, cur_dir):
        """
        Return the serial number of cur_dir when it is a checkpoint
        directory containing _SUCCESS, otherwise -1.
        """
        if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
            return -1

        try:
            serial = int(cur_dir)
        except ValueError:
            return -1

        # Callers rebuild the path from str(serial); reject names such as
        # "01" or "+1" for which that path would not be this directory.
        if str(serial) != cur_dir:
            return -1

        success_path = os.path.join(checkpoint_dir, cur_dir,
                                    SUCCESS_MARK_FILENAME)
        if os.path.isfile(success_path):
            return serial

        # Always return an int: an implicit None here would be compared
        # with an int by the caller, which raises TypeError on Python 3.
        return -1

    if not os.path.isdir(checkpoint_dir):
        return -1

    current_dir = -1
    for cur_dir in os.listdir(checkpoint_dir):
        success_num = has_success(checkpoint_dir, cur_dir)
        if success_num > current_dir:
            current_dir = success_num
    return current_dir