add checkpoint util class and implement #10532
Merged
Commits (59, all by seiriosPlus):
568a329  add checkpoint util class and implement
1fabbba  modify const to const &
77c6b71  add ckpt to sync loop
b81671e  add ckpt attr to pserver python config
e21a72d  Merge branch 'develop' of github.com:PaddlePaddle/Paddle into checkpoint
2a05b3d  delete checkpoint function
87a0856  add checkpoint save op
dc534fc  add checkpoint save op test
802d10c  rename cpkt_save_op
d1bd3fd  add build and test make
5e74db3  add build and test make
a1419f1  test add op declare
461d2fc  rename ckpt -> checkpoint
2f4c039  rename, modify ckpt structure
38596cf  move file_path to dir
ce1bcc9  add op to framework.py
3c82006  remove overwrite judge to test load
f04b23a  add checkpoint_load, update checkpoint save
c80125f  add checkpoint_load to python framework
2e25e73  write checkpoint_load code simply
30b50dc  fix Serial output type
0334d49  fix bug
d081256  add api in distribute transpiler
886897c  load implement
9cf47af  modify get trainer param
c6f042f  modify load op
b677d82  bug fix
744e95d  add ckpt load
955c793  add X to test
3dd2746  modify Get -> GetMutable
4220b31  update pserver startup
6d53dce  optimized checkpoint serial number and folder
8430c8d  remove boost filesystem
7b6c0ab  modify variable point
f9d4b9d  fix auto serial_num has no initializer
a4fd375  bug fix
f688652  bug fix
821acdb  update op to trianer and pserver
eff92d0  merge develop
cd98f2b  bug fix
dbd0237  fix serial number
22df4c2  fix serial number
d98480c  fix serial number
ee91e48  fix serial number
b6ee59a  optimize python checkpint dir config
e130bf3  optimize python checkpint dir config
5451c78  add checkpoint in io
01975ec  add checkpoint in io
ed2129c  revert distribute_transpiler.py
be05056  delete old checkpoint code
06aa23b  Merge branch 'develop' of github.com:PaddlePaddle/Paddle into checkpoint
2412dee  code optimized
e901de6  update var name
27b7175  update python annotation
9d98534  update annotation grammar
d96b442  rename checkpoint folder to checkpoint_serial
192f9a5  bug fix
cf3fb24  add clean checkpoint
2c47e06  add clean checkpoint
Files changed:
@@ -0,0 +1,214 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <stdint.h>
#include <sys/stat.h>
#include <fstream>
#include <numeric>
#include <sstream>
#include <streambuf>
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace operators {

constexpr char kSEP = '/';
// write empty file named _SUCCESS
const char SUCCESS[] = "_SUCCESS";
const char SERIAL_VAR[] = "SERIAL_NUMBER";

static bool FileExists(const std::string &filepath) {
  struct stat buffer;
  return (stat(filepath.c_str(), &buffer) == 0);
}

// Joins a directory and a file name into a single path.
static std::string GenePath(const std::string &dir, const std::string &file) {
  std::string file_path;
  file_path.append(dir);
  file_path.append("/");
  file_path.append(file);
  return file_path;
}

static bool IsNumber(const std::string &s) {
  std::string::const_iterator it = s.begin();
  while (it != s.end() && std::isdigit(*it)) ++it;
  return !s.empty() && it == s.end();
}

static void LoadInputVars(const framework::Scope &scope,
                          const platform::Place &place,
                          const std::vector<std::string> &inp_var_names,
                          const std::string &dir) {
  // get device context from pool
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &dev_ctx = *pool.Get(place);

  // TODO(tangwei): make it async
  for (size_t i = 0; i < inp_var_names.size(); i++) {
    auto *var = scope.FindVar(inp_var_names[i]);

    PADDLE_ENFORCE(var != nullptr,
                   "Cannot find variable %s for checkpoint_load_op",
                   inp_var_names[i]);
    PADDLE_ENFORCE(var->IsType<framework::LoDTensor>(),
                   "CheckpointLoadOp only supports LoDTensor, %s has wrong type",
                   inp_var_names[i]);

    std::string var_file = GenePath(dir, inp_var_names[i]);
    auto *tensor = var->GetMutable<framework::LoDTensor>();
    std::ifstream fin(var_file);
    PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
                   var_file);
    framework::DeserializeFromStream(fin, tensor, dev_ctx);
    fin.close();
    VLOG(3) << " load var: " << inp_var_names[i] << " finished";
  }
}

static void LoadStringArgv(const framework::Scope &scope,
                           const platform::Place &place,
                           const std::vector<std::string> &argv,
                           const std::string &dir) {
  for (size_t i = 0; i < argv.size(); i++) {
    auto *var = scope.FindVar(argv[i]);
    std::string *var_str = var->GetMutable<std::string>();
    std::string var_file = GenePath(dir, argv[i]);
    std::ifstream fin(var_file);
    PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
                   var_file);
    std::getline(fin, *var_str);
    fin.close();
    VLOG(3) << " load String argv: " << argv[i] << " value is: " << *var_str;
  }
}

class CheckpointLoadOp : public framework::OperatorBase {
 public:
  CheckpointLoadOp(const std::string &type,
                   const framework::VariableNameMap &inputs,
                   const framework::VariableNameMap &outputs,
                   const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    std::string dir = Attr<std::string>("dir");
    std::string serial_num_attr = Attr<std::string>("Serial");

    PADDLE_ENFORCE(!IsNumber(serial_num_attr),
                   "Checkpoint Serial must be a number");

    std::string serial_var_name = std::string(SERIAL_VAR);
    auto *serial_var = scope.FindVar(serial_var_name);
    PADDLE_ENFORCE(serial_var != nullptr,
                   "Cannot find variable %s for checkpoint_load_op",
                   serial_var_name);

    auto *serial_num = serial_var->GetMutable<std::string>();
    serial_num->clear();
    serial_num->append(serial_num_attr);

    VLOG(1) << "CheckpointLoadOp set " << SERIAL_VAR
            << " value: " << *serial_num;

    std::string success = GenePath(dir, serial_num->c_str());
    VLOG(3) << "Load checkpoint from dir: " << success;
    success = GenePath(success, SUCCESS);
    bool is_present = FileExists(success);
    if (!is_present) {
      VLOG(1) << "CheckpointLoadOp can not find " << SUCCESS
              << " from: " << success;
      return;
    }

    VLOG(3) << "Ready to load vars to scope";
    auto inp_var_names = Inputs("X");
    PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0,
                      "The number of input variables should be greater than 0");
    LoadInputVars(scope, place, inp_var_names, dir);

    // VLOG(3) << "Ready to load string argv to scope";
    // auto argv = Output("Argv");
    // LoadStringArgv(scope, place, argv, dir);
  }
};

class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  CheckpointLoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(
        "X",
        "(vector) LoDTensor variables to be loaded from the checkpoint.")
        .AsDuplicable();
    AddOutput(
        "Argv",
        "(vector) String arguments restored from the checkpoint.");
    AddComment(R"DOC(
CheckpointLoad operator

This operator reads a list of LoDTensor variables from checkpoint files
on disk and restores them into the running scope.
)DOC");

    AddAttr<std::string>(
        "Serial",
        "(std::string)"
        "The serial number of the checkpoint to be loaded.");
    AddAttr<std::string>(
        "dir",
        "(string)"
        "The directory from which the LoDTensor variables will be loaded.")
        .AddCustomChecker(
            [](const std::string &path) { return !path.empty(); });
  }
};

class CheckpointLoadOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc &op_desc,
                  framework::BlockDesc *block) const override {
    auto out_var_name = op_desc.Output("Argv").front();
    auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
    auto var_type = framework::proto::VarType::RAW;
    out_var.SetType(var_type);
  }
};

class CheckpointLoadOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {}
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(checkpoint_load, ops::CheckpointLoadOp,
                  paddle::framework::EmptyGradOpMaker,
                  ops::CheckpointLoadOpProtoMaker,
                  ops::CheckpointLoadOpVarTypeInference,
                  ops::CheckpointLoadOpShapeInference);

// REGISTER_OPERATOR(checkpoint_load, ops::CheckpointLoadOp,
//                   ops::CheckpointLoadOpProtoMaker);
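For orientation (not part of this diff), here is a rough sketch of how an operator like checkpoint_load might be appended to a program from the Python side; the program, variable list, directory, and Serial value are illustrative assumptions, and the exact fluid API of this era may differ:

```python
# Hypothetical Python-side wiring for the checkpoint_load op (illustrative only).
import paddle.fluid as fluid

main_program = fluid.default_main_program()
block = main_program.global_block()

# Restore every persistable variable from the checkpoint directory.
vars_to_load = [v for v in main_program.list_vars() if v.persistable]

# Variable that will receive the RAW "Argv" output of the op.
serial_var = block.create_var(name="SERIAL_NUMBER", persistable=True)

block.append_op(
    type="checkpoint_load",
    inputs={"X": vars_to_load},
    outputs={"Argv": [serial_var]},
    attrs={"dir": "/tmp/ckpt", "Serial": "0"})
```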
@@ -0,0 +1,82 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/place.h"

USE_NO_KERNEL_OP(checkpoint_save)
USE_NO_KERNEL_OP(checkpoint_load)

TEST(CheckpointSaveOp, CPU) {
  paddle::framework::Scope scope;
  paddle::platform::CPUPlace place;

  auto var = scope.Var("test_var");
  auto tensor = var->GetMutable<paddle::framework::LoDTensor>();
  tensor->Resize({3, 10});
  paddle::framework::LoD expect_lod;
  expect_lod.resize(1);
  expect_lod[0].push_back(0);
  expect_lod[0].push_back(1);
  expect_lod[0].push_back(2);
  expect_lod[0].push_back(3);

  tensor->set_lod(expect_lod);
  float* expect = tensor->mutable_data<float>(place);
  for (int64_t i = 0; i < tensor->numel(); ++i) {
    expect[i] = static_cast<float>(paddle::platform::float16(i));
  }

  scope.Var("SERIAL_NUMBER");

  paddle::framework::AttributeMap attrs;
  attrs.insert({"dir", std::string("ckpt")});

  auto save_op = paddle::framework::OpRegistry::CreateOp(
      "checkpoint_save", {{"X", {"test_var"}}}, {}, attrs);
  save_op->Run(scope, place);
}

TEST(CheckpointLoadOp, CPU) {
  paddle::framework::Scope scope;
  paddle::platform::CPUPlace place;

  auto var = scope.Var("test_var");
  auto tensor = var->GetMutable<paddle::framework::LoDTensor>();
  tensor->Resize({3, 10});
  paddle::framework::LoD expect_lod;
  expect_lod.resize(1);
  expect_lod[0].push_back(0);
  expect_lod[0].push_back(1);
  expect_lod[0].push_back(2);
  expect_lod[0].push_back(3);

  tensor->set_lod(expect_lod);
  float* expect = tensor->mutable_data<float>(place);
  for (int64_t i = 0; i < tensor->numel(); ++i) {
    expect[i] = static_cast<float>(paddle::platform::float16(i));
  }

  scope.Var("SERIAL_NUMBER");
  auto* serial_num = scope.FindVar("SERIAL_NUMBER")->GetMutable<std::string>();
  serial_num->append("0");

  paddle::framework::AttributeMap attrs;
  attrs.insert({"dir", std::string("ckpt")});
  attrs.insert({"Serial", std::string("SERIAL_NUMBER")});

  auto load_op = paddle::framework::OpRegistry::CreateOp(
      "checkpoint_load", {{"X", {"test_var"}}}, {{"Argv", {}}}, attrs);
  load_op->Run(scope, place);
}
Review discussion:
Reviewer: I see there are `load_op` and `load_combine_op` and the corresponding save ops; on the Python side you can also use `fluid.io.save_persistables` to save all persistable variables. To make `save_persistables` equivalent to saving a checkpoint, make sure the state variables are all "persistable", e.g. step counters, learning rates, learning-rate moments, etc. So can you reuse those ops instead of writing new ones?
Author: `load_op` and `save_op` are designed for `LoDTensor` variables, but a checkpoint needs to save more than `LoDTensor`s, and `checkpoint` has some particular arguments. At present, the `checkpoint load/save op` and `load/save op` have no clear-cut distinction.
have no clear-cut distinction.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewer: I think it's better to reuse the current operators; checking the variable type should be enough. So which other variable types are saved in the checkpoint? "RAW" types and "feed"/"fetch" variables may not need to be saved.
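A rough sketch of that kind of type filter (the function name and the exact exclusion set are assumptions; `prog` is an arbitrary Program):

```python
# Hypothetical filter: decide which variables belong in a checkpoint,
# skipping RAW, feed, and fetch variables as discussed above.
from paddle.fluid import core

def is_checkpoint_var(var):
    if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH:
        return False
    if var.desc.type() == core.VarDesc.VarType.FETCH_LIST:
        return False
    if var.desc.type() == core.VarDesc.VarType.RAW:
        return False
    return var.persistable

vars_to_save = [v for v in prog.list_vars() if is_checkpoint_var(v)]
```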