add checkpoint util class and implement #10532

Merged
merged 59 commits on May 23, 2018
Changes from 38 commits
Commits
59 commits
568a329
add checkpoint util class and implement
seiriosPlus May 9, 2018
1fabbba
modify const to const &
seiriosPlus May 10, 2018
77c6b71
add ckpt to sync loop
seiriosPlus May 10, 2018
b81671e
add ckpt attr to pserver python config
seiriosPlus May 10, 2018
e21a72d
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into checkpoint
seiriosPlus May 11, 2018
2a05b3d
delete checkpoint function
seiriosPlus May 11, 2018
87a0856
add checkpoint save op
seiriosPlus May 11, 2018
dc534fc
add checkpoint save op test
seiriosPlus May 11, 2018
802d10c
rename cpkt_save_op
seiriosPlus May 11, 2018
d1bd3fd
add build and test make
seiriosPlus May 11, 2018
5e74db3
add build and test make
seiriosPlus May 11, 2018
a1419f1
test add op declare
seiriosPlus May 11, 2018
461d2fc
rename ckpt -> checkpoint
seiriosPlus May 14, 2018
2f4c039
rename, modify ckpt structure
seiriosPlus May 14, 2018
38596cf
move file_path to dir
seiriosPlus May 14, 2018
ce1bcc9
add op to framework.py
seiriosPlus May 14, 2018
3c82006
remove overwrite judge to test load
seiriosPlus May 14, 2018
f04b23a
add checkpoint_load, update checkpoint save
seiriosPlus May 15, 2018
c80125f
add checkpoint_load to python framework
seiriosPlus May 15, 2018
2e25e73
write checkpoint_load code simply
seiriosPlus May 15, 2018
30b50dc
fix Serial output type
seiriosPlus May 15, 2018
0334d49
fix bug
seiriosPlus May 15, 2018
d081256
add api in distribute transpiler
seiriosPlus May 16, 2018
886897c
load implement
seiriosPlus May 16, 2018
9cf47af
modify get trainer param
seiriosPlus May 16, 2018
c6f042f
modify load op
seiriosPlus May 16, 2018
b677d82
bug fix
seiriosPlus May 16, 2018
744e95d
add ckpt load
seiriosPlus May 16, 2018
955c793
add X to test
seiriosPlus May 16, 2018
3dd2746
modify Get -> GetMutable
seiriosPlus May 16, 2018
4220b31
update pserver startup
seiriosPlus May 16, 2018
6d53dce
optimized checkpoint serial number and folder
seiriosPlus May 17, 2018
8430c8d
remove boost filesystem
seiriosPlus May 17, 2018
7b6c0ab
modify variable point
seiriosPlus May 17, 2018
f9d4b9d
fix auto serial_num has no initializer
seiriosPlus May 17, 2018
a4fd375
bug fix
seiriosPlus May 18, 2018
f688652
bug fix
seiriosPlus May 18, 2018
821acdb
update op to trainer and pserver
seiriosPlus May 18, 2018
eff92d0
merge develop
seiriosPlus May 18, 2018
cd98f2b
bug fix
seiriosPlus May 18, 2018
dbd0237
fix serial number
seiriosPlus May 18, 2018
22df4c2
fix serial number
seiriosPlus May 18, 2018
d98480c
fix serial number
seiriosPlus May 18, 2018
ee91e48
fix serial number
seiriosPlus May 18, 2018
b6ee59a
optimize python checkpoint dir config
seiriosPlus May 18, 2018
e130bf3
optimize python checkpoint dir config
seiriosPlus May 18, 2018
5451c78
add checkpoint in io
seiriosPlus May 21, 2018
01975ec
add checkpoint in io
seiriosPlus May 21, 2018
ed2129c
revert distribute_transpiler.py
seiriosPlus May 21, 2018
be05056
delete old checkpoint code
seiriosPlus May 21, 2018
06aa23b
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into checkpoint
seiriosPlus May 21, 2018
2412dee
code optimized
seiriosPlus May 21, 2018
e901de6
update var name
seiriosPlus May 22, 2018
27b7175
update python annotation
seiriosPlus May 22, 2018
9d98534
update annotation grammar
seiriosPlus May 23, 2018
d96b442
rename checkpoint folder to checkpoint_serial
seiriosPlus May 23, 2018
192f9a5
bug fix
seiriosPlus May 23, 2018
cf3fb24
add clean checkpoint
seiriosPlus May 23, 2018
2c47e06
add clean checkpoint
seiriosPlus May 23, 2018
3 changes: 3 additions & 0 deletions paddle/fluid/operators/CMakeLists.txt
@@ -242,6 +242,8 @@ op_library(save_op DEPS lod_tensor)
op_library(load_op DEPS lod_tensor)
op_library(save_combine_op DEPS lod_tensor)
op_library(load_combine_op DEPS lod_tensor)
op_library(checkpoint_save_op DEPS lod_tensor)
op_library(checkpoint_load_op DEPS lod_tensor)
op_library(concat_op DEPS concat)

# FIXME(thuan): Move CSP operators to paddle/fluid/framework/operators/concurrency
@@ -277,5 +279,6 @@ cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_sea
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory)
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
cc_test(checkpoint_op_test SRCS checkpoint_op_test.cc DEPS checkpoint_save_op checkpoint_load_op)
nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor)
214 changes: 214 additions & 0 deletions paddle/fluid/operators/checkpoint_load_op.cc
@@ -0,0 +1,214 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Contributor:
I see there are load_op and load_combine_op and the corresponding save ops; on the Python side, you can also use fluid.io.save_persistables to save all persistable variables.

To make save_persistables equivalent to saving a checkpoint, make sure that the state variables (step counters, learning rates, learning-rate moments, etc.) are all marked "persistable".

So can you reuse those ops instead of writing new ones?
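
(For reference, a minimal sketch of how fluid.io.save_persistables is typically called on the Python side; the place, executor, and checkpoint directory below are illustrative and assume the Fluid API of that era.)

import paddle.fluid as fluid

# After building and running the program, persist every variable whose
# persistable flag is set (weights, optimizer moments, counters, ...).
place = fluid.CPUPlace()
exe = fluid.Executor(place)
fluid.io.save_persistables(executor=exe, dirname="./checkpoint_dir",
                           main_program=fluid.default_main_program())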

Collaborator (author):

load_op and save_op are designed for LoDTensor variables, but a checkpoint has to save variables that are not LoDTensors, and it also takes some particular arguments of its own.
At present, the checkpoint load/save ops and the plain load/save ops have no clear-cut distinction.

Contributor:

I think it's better to reuse the current operators; checking the variable type should be enough.
So what other variable types are saved in the checkpoint? "RAW" types and the "feed"/"fetch" variables may not need to be saved.
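
(To illustrate the suggestion above: a hedged sketch of filtering a program's variables by type before saving, so the existing save ops could be reused. The should_checkpoint predicate is an assumption modeled on fluid.io's persistable checks, and all names here are illustrative.)

import paddle.fluid as fluid
from paddle.fluid import core

def should_checkpoint(var):
    # Skip feed/fetch plumbing and RAW variables; keep persistable state
    # such as weights, optimizer moments, step counters and learning rates.
    skip_types = (core.VarDesc.VarType.FEED_MINIBATCH,
                  core.VarDesc.VarType.FETCH_LIST,
                  core.VarDesc.VarType.RAW)
    if var.desc.type() in skip_types:
        return False
    return var.persistable

prog = fluid.default_main_program()
ckpt_vars = [v for v in prog.list_vars() if should_checkpoint(v)]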


Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <stdint.h>
#include <sys/stat.h>
#include <fstream>
#include <numeric>
#include <sstream>
#include <streambuf>
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace operators {

constexpr char kSEP = '/';
// write empty file named _SUCCESS
const char SUCCESS[] = "_SUCCESS";
const char SERIAL_VAR[] = "SERIAL_NUMBER";

static bool FileExists(const std::string &filepath) {
struct stat buffer;
return (stat(filepath.c_str(), &buffer) == 0);
}

static std::string GenePath(const std::string &dir, const std::string &file) {
std::string file_path;
file_path.append(dir);
file_path.append("/");
file_path.append(file);
return file_path;
}

static bool IsNumber(const std::string &s) {
std::string::const_iterator it = s.begin();
while (it != s.end() && std::isdigit(*it)) ++it;
return !s.empty() && it == s.end();
}

static void LoadInputVars(const framework::Scope &scope,
const platform::Place &place,
const std::vector<std::string> &inp_var_names,
const std::string &dir) {
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);

// TODO(tangwei): make it async
for (size_t i = 0; i < inp_var_names.size(); i++) {
auto *var = scope.FindVar(inp_var_names[i]);

PADDLE_ENFORCE(var != nullptr,
"Cannot find variable %s for save_combine_op",
inp_var_names[i]);
PADDLE_ENFORCE(var->IsType<framework::LoDTensor>(),
"LoadCombineOp only supports LoDTensor, %s has wrong type",
inp_var_names[i]);

std::string var_file = GenePath(dir, inp_var_names[i]);
auto *tensor = var->GetMutable<framework::LoDTensor>();
std::ifstream fin(var_file);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
var_file);
framework::DeserializeFromStream(fin, tensor, dev_ctx);
fin.close();
VLOG(3) << " load var: " << inp_var_names[i] << " finished";
}
}

static void LoadStringArgv(const framework::Scope &scope,
const platform::Place &place,
const std::vector<std::string> &argv,
const std::string &dir) {
for (size_t i = 0; i < argv.size(); i++) {
auto *var = scope.FindVar(argv[i]);
std::string *var_str = var->GetMutable<std::string>();
std::string var_file = GenePath(dir, argv[i]);
std::ifstream fin(var_file);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
var_file);
std::getline(fin, *var_str);
fin.close();
VLOG(3) << " load String argv: " << argv[i] << " value is: " << var_str;
}
}

class CheckpointLoadOp : public framework::OperatorBase {
public:
CheckpointLoadOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}

private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
std::string dir = Attr<std::string>("dir");
std::string serial_num_attr = Attr<std::string>("Serial");

PADDLE_ENFORCE(IsNumber(serial_num_attr),
"Checkpoint Serial must be a number");

std::string serial_var_name = std::string(SERIAL_VAR);
auto *serial_var = scope.FindVar(serial_var_name);
PADDLE_ENFORCE(serial_var != nullptr,
"Cannot find variable %s for checkpoint_load_op",
serial_var_name);

auto *serial_num = serial_var->GetMutable<std::string>();
serial_num->clear();
serial_num->append(serial_num_attr);

VLOG(1) << "CheckpointLoadOp set " << SERIAL_VAR
<< " value: " << serial_num;

std::string success = GenePath(dir, serial_num->c_str());
VLOG(3) << "Load checkpoint from dir: " << success;
success = GenePath(success, SUCCESS);
bool is_present = FileExists(success);
if (!is_present) {
VLOG(1) << "CheckpointLoadOp can not find " << SUCCESS
<< " from: " << success;
return;
}

VLOG(3) << "Ready to load vars to scope";
auto inp_var_names = Inputs("X");
PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0,
"The number of input variables should be greater than 0");
LoadInputVars(scope, place, inp_var_names, dir);

// VLOG(3) << "Ready to load string argv to scope";
// auto argv = Output("Argv");
// LoadStringArgv(scope, place, argv, dir);
}
};

class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
CheckpointLoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"(vector) Input LoDTensors that need to be saved together in a file.")
.AsDuplicable();
AddOutput(
"Argv",
"(vector) Input LoDTensors that need to be saved together in a file.");
AddComment(R"DOC(
CheckpointLoad operator

This operator reads a list of LoDTensor variables back from a checkpoint
directory on disk and restores them into the scope.
)DOC");

AddAttr<std::string>(
"Serial",
"(std::string)"
"The serial number of the checkpoint will to be load.");
AddAttr<std::string>(
"dir",
"(string)"
"The \"file_path\" where the LoDTensor variables will be saved.")
.AddCustomChecker(
[](const std::string &path) { return !path.empty(); });
}
};

class CheckpointLoadOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output("Argv").front();
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};

class CheckpointLoadOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(checkpoint_load, ops::CheckpointLoadOp,
paddle::framework::EmptyGradOpMaker,
ops::CheckpointLoadOpProtoMaker,
ops::CheckpointLoadOpVarTypeInference,
ops::CheckpointLoadOpShapeInference);

// REGISTER_OPERATOR(checkpoint_load, ops::CheckpointLoadOp,
// ops::CheckpointLoadOpProtoMaker);
82 changes: 82 additions & 0 deletions paddle/fluid/operators/checkpoint_op_test.cc
@@ -0,0 +1,82 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"

USE_NO_KERNEL_OP(checkpoint_save)
USE_NO_KERNEL_OP(checkpoint_load)

TEST(CheckpointSaveOp, CPU) {
paddle::framework::Scope scope;
paddle::platform::CPUPlace place;

auto var = scope.Var("test_var");
auto tensor = var->GetMutable<paddle::framework::LoDTensor>();
tensor->Resize({3, 10});
paddle::framework::LoD expect_lod;
expect_lod.resize(1);
expect_lod[0].push_back(0);
expect_lod[0].push_back(1);
expect_lod[0].push_back(2);
expect_lod[0].push_back(3);

tensor->set_lod(expect_lod);
float* expect = tensor->mutable_data<float>(place);
for (int64_t i = 0; i < tensor->numel(); ++i) {
expect[i] = static_cast<float>(paddle::platform::float16(i));
}

scope.Var("SERIAL_NUMBER");

paddle::framework::AttributeMap attrs;
attrs.insert({"dir", std::string("ckpt")});

auto save_op = paddle::framework::OpRegistry::CreateOp(
"checkpoint_save", {{"X", {"test_var"}}}, {}, attrs);
save_op->Run(scope, place);
}

TEST(CheckpointLoadOp, CPU) {
paddle::framework::Scope scope;
paddle::platform::CPUPlace place;

auto var = scope.Var("test_var");
auto tensor = var->GetMutable<paddle::framework::LoDTensor>();
tensor->Resize({3, 10});
paddle::framework::LoD expect_lod;
expect_lod.resize(1);
expect_lod[0].push_back(0);
expect_lod[0].push_back(1);
expect_lod[0].push_back(2);
expect_lod[0].push_back(3);

tensor->set_lod(expect_lod);
float* expect = tensor->mutable_data<float>(place);
for (int64_t i = 0; i < tensor->numel(); ++i) {
expect[i] = static_cast<float>(paddle::platform::float16(i));
}

scope.Var("SERIAL_NUMBER");
auto* serial_num = scope.FindVar("SERIAL_NUMBER")->GetMutable<std::string>();
serial_num->append("0");

paddle::framework::AttributeMap attrs;
attrs.insert({"dir", std::string("ckpt")});
attrs.insert({"Serial", std::string("SERIAL_NUMBER")});

auto load_op = paddle::framework::OpRegistry::CreateOp(
"checkpoint_load", {{"X", {"test_var"}}}, {{"Argv", {}}}, attrs);
load_op->Run(scope, place);
}