Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add checkpoint util class and implement #10532

Merged
merged 59 commits into from
May 23, 2018
Merged
Show file tree
Hide file tree
Changes from 51 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
568a329
add checkpoint util class and implement
seiriosPlus May 9, 2018
1fabbba
modify const to const &
seiriosPlus May 10, 2018
77c6b71
add ckpt to sync loop
seiriosPlus May 10, 2018
b81671e
add ckpt attr to pserver python config
seiriosPlus May 10, 2018
e21a72d
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into checkpoint
seiriosPlus May 11, 2018
2a05b3d
delete checkpoint function
seiriosPlus May 11, 2018
87a0856
add checkpoint save op
seiriosPlus May 11, 2018
dc534fc
add checkpoint save op test
seiriosPlus May 11, 2018
802d10c
rename cpkt_save_op
seiriosPlus May 11, 2018
d1bd3fd
add build and test make
seiriosPlus May 11, 2018
5e74db3
add build and test make
seiriosPlus May 11, 2018
a1419f1
test add op declare
seiriosPlus May 11, 2018
461d2fc
rename ckpt -> checkpoint
seiriosPlus May 14, 2018
2f4c039
rename, modify ckpt structure
seiriosPlus May 14, 2018
38596cf
move file_path to dir
seiriosPlus May 14, 2018
ce1bcc9
add op to framework.py
seiriosPlus May 14, 2018
3c82006
remove overwrite judge to test load
seiriosPlus May 14, 2018
f04b23a
add checkpoint_load, update checkpoint save
seiriosPlus May 15, 2018
c80125f
add checkpoint_load to python framework
seiriosPlus May 15, 2018
2e25e73
write checkpoint_load code simply
seiriosPlus May 15, 2018
30b50dc
fix Serial output type
seiriosPlus May 15, 2018
0334d49
fix bug
seiriosPlus May 15, 2018
d081256
add api in distribute transpiler
seiriosPlus May 16, 2018
886897c
load implement
seiriosPlus May 16, 2018
9cf47af
modify get trainer param
seiriosPlus May 16, 2018
c6f042f
modify load op
seiriosPlus May 16, 2018
b677d82
bug fix
seiriosPlus May 16, 2018
744e95d
add ckpt load
seiriosPlus May 16, 2018
955c793
add X to test
seiriosPlus May 16, 2018
3dd2746
modify Get -> GetMutable
seiriosPlus May 16, 2018
4220b31
update pserver startup
seiriosPlus May 16, 2018
6d53dce
optimized checkpoint serial number and folder
seiriosPlus May 17, 2018
8430c8d
remove boost filesystem
seiriosPlus May 17, 2018
7b6c0ab
modify variable point
seiriosPlus May 17, 2018
f9d4b9d
fix auto serial_num has no initializer
seiriosPlus May 17, 2018
a4fd375
bug fix
seiriosPlus May 18, 2018
f688652
bug fix
seiriosPlus May 18, 2018
821acdb
update op to trianer and pserver
seiriosPlus May 18, 2018
eff92d0
merge develop
seiriosPlus May 18, 2018
cd98f2b
bug fix
seiriosPlus May 18, 2018
dbd0237
fix serial number
seiriosPlus May 18, 2018
22df4c2
fix serial number
seiriosPlus May 18, 2018
d98480c
fix serial number
seiriosPlus May 18, 2018
ee91e48
fix serial number
seiriosPlus May 18, 2018
b6ee59a
optimize python checkpint dir config
seiriosPlus May 18, 2018
e130bf3
optimize python checkpint dir config
seiriosPlus May 18, 2018
5451c78
add checkpoint in io
seiriosPlus May 21, 2018
01975ec
add checkpoint in io
seiriosPlus May 21, 2018
ed2129c
revert distribute_transpiler.py
seiriosPlus May 21, 2018
be05056
delete old checkpoint code
seiriosPlus May 21, 2018
06aa23b
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into checkpoint
seiriosPlus May 21, 2018
2412dee
code optimized
seiriosPlus May 21, 2018
e901de6
update var name
seiriosPlus May 22, 2018
27b7175
update python annotation
seiriosPlus May 22, 2018
9d98534
update annotation grammar
seiriosPlus May 23, 2018
d96b442
rename checkpoint folder to checkpoint_serial
seiriosPlus May 23, 2018
192f9a5
bug fix
seiriosPlus May 23, 2018
cf3fb24
add clean checkpoint
seiriosPlus May 23, 2018
2c47e06
add clean checkpoint
seiriosPlus May 23, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/operators/listen_and_serv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
for (auto &var : sparse_vars) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change may not be necessary.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am sorry about it, I will revert it later.

rpc_service_->SetCond(1);
// FIXME(typhoonzero): use another condition to sync wait clients get.
rpc_service_->WaitClientGet(fan_in);
Expand Down
162 changes: 153 additions & 9 deletions python/paddle/fluid/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,17 @@
# limitations under the License.

import os
import time
import shutil

from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable
from . import core

__all__ = [
'save_vars',
'save_params',
'save_persistables',
'load_vars',
'load_params',
'load_persistables',
'save_inference_model',
'load_inference_model',
'get_inference_program',
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables', 'save_inference_model', 'load_inference_model',
'get_inference_program', 'save_checkpoint', 'restore_checkpoint'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it would be better to name it load_checkpoint or restore_from_checkpoint?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will use load_checkpoint; restore_from_checkpoint is too long, I think.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

]


Expand Down Expand Up @@ -195,6 +191,8 @@ def load_vars(executor,
load_var_map = {}
for each_var in vars:
assert isinstance(each_var, Variable)
if each_var.type == core.VarDesc.VarType.RAW:
continue
new_var = _clone_var_in_block_(load_block, each_var)
if filename is None:
load_block.append_op(
Expand Down Expand Up @@ -454,3 +452,149 @@ def get_parameter_value_by_name(name, executor, program=None):
program = default_main_program()
var = program.global_block().var(name)
return get_parameter_value(var, executor)


SUCCESS = "_SUCCESS"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SUCCESS = SUCCESS_MARK_FILENAME

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

BEGIN_SECS = None


def save_checkpoint(executor,
dirname,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use the current working directory as the default?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

keep_max=3,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

keep_max => max_num_checkpoints

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

save_secs=600,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

save_secs => interval

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

main_program=None):
"""
Save Variables to Checkpint Dir
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checkpint => Checkpoint
Dir => Directory

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


:param dirname
:param keep_max
:param save_secs
:param main_program
"""
if dirname is None:
raise Exception("save checkpoint dir can not be none")

if not os.path.isdir(dirname):
os.makedirs(dirname)

global BEGIN_SECS
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please try not to use global.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

if BEGIN_SECS is not None:
if time.time() - BEGIN_SECS < save_secs:
return
BEGIN_SECS = time.time()

serial = _get_lastest_checkpoint_dir(dirname) + 1
cur_dir = os.path.join(dirname, str(serial))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The checkpoint directories will be saved as "1", "2", etc., which may not make sense to users. I think it would be better to name them "checkpoint_1", "checkpoint_2", and so on, and the "_SUCCESS" file could store a timestamp of when the checkpoint was saved.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

# save_persistables(executor, cur_dir, main_program)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No commented-out code, please.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

save_vars(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why call save_vars instead of save_persistables?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

save_persistables cannot filter out gradient variables.

executor,
dirname=cur_dir,
main_program=main_program,
vars=None,
predicate=is_checkpoint_var,
filename=None)
_write_success(cur_dir)
_lru_delete(dirname, keep_max)


def restore_checkpoint(dirname, executor, main_program=None):
    """
    Load variables from the latest complete checkpoint under dirname.

    The serial number returned by _get_lastest_checkpoint_dir selects the
    sub-directory to read. Every variable accepted by is_checkpoint_var is
    then loaded from that sub-directory through load_vars. When the returned
    serial is negative nothing is loaded and the function returns None.

    :param dirname: root directory that holds the numbered checkpoint
        sub-directories; must not be None
    :param executor: executor forwarded to load_vars
    :param main_program: program forwarded to load_vars
    :raises Exception: if dirname is None
    """
    # The previous guard was `dirname is None and os.path.isdir(dirname)`.
    # That condition can never be true, and for None it failed inside
    # os.path.isdir with a TypeError instead of raising the error below.
    if dirname is None:
        raise Exception("restore checkpoint can not load variables from %s" %
                        dirname)

    serial = _get_lastest_checkpoint_dir(dirname)

    # A negative serial means there is no checkpoint to restore from.
    if serial < 0:
        return

    cur_dir = os.path.join(dirname, str(serial))
    load_vars(
        executor,
        dirname=cur_dir,
        main_program=main_program,
        predicate=is_checkpoint_var,
        filename=None)


def is_checkpoint_var(var):
    """
    Predicate used to pick the variables that belong in a checkpoint.

    Returns False for variables whose desc type is FEED_MINIBATCH,
    FETCH_LIST or RAW, and for variables whose name ends with "@GRAD".
    For every other variable the result is its `persistable` attribute.

    :param var: the variable to inspect
    """
    excluded_types = (
        core.VarDesc.VarType.FEED_MINIBATCH,
        core.VarDesc.VarType.FETCH_LIST,
        core.VarDesc.VarType.RAW,
    )
    if any(var.desc.type() == excluded for excluded in excluded_types):
        return False

    # Names ending in "@GRAD" are skipped regardless of persistability.
    is_gradient = var.name.endswith("@GRAD")
    return False if is_gradient else var.persistable


def _lru_delete(dirname, keep_max=3):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

keep_max => max_num_checkpoints

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""
retain checkpoint nums with keep_max
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

keep_max => max_num_checkpoints

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""
dirs = os.listdir(dirname)
serials = []
for serial in dirs:
try:
serials.append(int(serial))
except ValueError:
continue

if len(serials) <= keep_max:
return

serials.sort(reverse=True)
serials = serials[keep_max:]
for serial in serials:
cur_dir = os.path.join(dirname, str(serial))
shutil.rmtree(cur_dir)


def _write_success(dirname):
"""
write _SUCCESS to checkpoint dir
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update _SUCCESS

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_SUCCESS is the file name, not the var name.

"""
success_file = os.path.join(dirname, SUCCESS)
with open(success_file, 'a'):
pass


def _get_lastest_checkpoint_dir(checkpoint_dir):
"""
get the biggest number in checkpoint_dir, which has _SUCCESS
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update _SUCCESS

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_SUCCESS is the file name, not the var name

"""
if not checkpoint_dir.strip():
return -1

def has_success(checkpoint_dir, cur_dir):
    """
    Return the serial number of cur_dir if it is a finished checkpoint.

    cur_dir counts as finished when it is a directory inside
    checkpoint_dir, its name parses as an integer, and it contains the
    file named by SUCCESS.

    :param checkpoint_dir: root directory that holds the checkpoints
    :param cur_dir: name of one entry inside checkpoint_dir
    :return: int(cur_dir) for a finished checkpoint, otherwise -1
    """
    if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
        return -1

    try:
        serial = int(cur_dir)
    except ValueError:
        return -1

    success_path = os.path.join(checkpoint_dir, cur_dir, SUCCESS)
    if os.path.isfile(success_path):
        return serial

    # Previously the function fell off the end here and returned None.
    # The caller compares the result with an integer using `>`, which
    # raises TypeError for None on Python 3, so return -1 explicitly.
    return -1

if not os.path.isdir(checkpoint_dir):
return -1

current_dir = -1
dirs = os.listdir(checkpoint_dir)
for cur_dir in dirs:
success_num = has_success(checkpoint_dir, cur_dir)
if success_num > current_dir:
current_dir = success_num
return current_dir