Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

add example: MLP on mnist #28

Merged
merged 4 commits into from
Jun 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,5 @@ cc_library(tape_function SRCS function.cc DEPS ${GLOB_OP_LIB} tape_variable tape
cc_test(test_tape
SRCS test_tape.cc
DEPS tape tape_variable tape_function)

add_subdirectory(example)
1 change: 0 additions & 1 deletion src/data/.gitignore

This file was deleted.

16 changes: 16 additions & 0 deletions src/example/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

add_subdirectory(mnist)
18 changes: 18 additions & 0 deletions src/example/mnist/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

cc_test(test_mnist
SRCS test_mnist.cc
DEPS tape tape_variable tape_function)
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def create_mnist_recordio_files():
],
place=fluid.CPUPlace())
fluid.recordio_writer.convert_reader_to_recordio_file(
'/tape/src/data/mnist.recordio', reader, feeder)
'/tmp/mnist.recordio', reader, feeder)


if __name__ == "__main__":
Expand Down
87 changes: 87 additions & 0 deletions src/example/mnist/test_mnist.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <fstream>

#include "gtest/gtest.h"
#include "src/function.h"

using paddle::tape::Linear;
using paddle::tape::SGD;
using paddle::tape::mean;
using paddle::tape::softmax;
using paddle::tape::cross_entropy;
using paddle::tape::reset_global_tape;
using paddle::tape::get_global_tape;

using paddle::tape::CreateRecordioFileReader;
using paddle::tape::ReadNext;

bool is_file_exist(const std::string& fileName) {
std::ifstream infile(fileName);
return infile.good();
}

TEST(Mnist, TestCPU) {
std::string filename = "/tmp/mnist.recordio";
PADDLE_ENFORCE(is_file_exist(filename),
"file doesn't exist; have you run create_mnist_recordio.py");
auto reader = CreateRecordioFileReader(
filename, {32, 1, 28, 28, 32, 1}, {4, 2}, {0, 0});

Linear linear1(784, 200, "relu");
Linear linear2(200, 200, "relu");
Linear linear3(200, 10, "relu");
SGD sgd(0.001);

int print_step = 100;
float avg_loss = 0.0;

for (int i = 0; i < 1000; ++i) {
reset_global_tape();
auto data_label = ReadNext(reader);
auto data = data_label[0];
auto label = data_label[1];

auto predict = softmax(linear3(linear2(linear1(data))));
auto loss = mean(cross_entropy(predict, label));
if (i % print_step == 0) {
avg_loss +=
loss->Value().Get<paddle::framework::LoDTensor>().data<float>()[0];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor thing: this avg_loss update step need to be outside if. I can fix this in my pr.

LOG(INFO) << avg_loss;
avg_loss = 0;
}

get_global_tape().Backward(loss);

for (auto w : linear1.Params()) {
sgd.Update(w);
}
for (auto w : linear2.Params()) {
sgd.Update(w);
}
for (auto w : linear3.Params()) {
sgd.Update(w);
}
}
}

int main(int argc, char** argv) {
std::vector<paddle::platform::Place> places;
places.emplace_back(paddle::platform::CPUPlace());
paddle::platform::DeviceContextPool::Init(places);

testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
46 changes: 27 additions & 19 deletions src/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,9 @@ class Fill {

void operator()(VariableHandle var) {
if (initializer_ == "fill_constant") {
// fill_constant is not OperatorWithKernel, so we can't add it to the tape
framework::OpDesc op_desc =
CreateOpDesc(initializer_, {}, {{"Out", {var}}}, attrs_);
ScopeWrapper scope({}, {{"Out", {var}}});
framework::OpRegistry::CreateOp(op_desc)->Run(scope,
platform::CPUPlace());
PADDLE_THROW(
"fill_constant is not supported, since it is not of type "
"OperatorWithKernel");
} else {
get_global_tape().AddOp(initializer_, {}, {{"Out", {var}}}, attrs_);
}
Expand All @@ -73,16 +70,11 @@ void init_params(VariableHandle v,
}
}

// TODO(tonyyang-svail): change this to a function
// https://github.com/PaddlePaddle/tape/issues/23
class Mean {
public:
VariableHandle operator()(VariableHandle var) {
VariableHandle out(new Variable("mean"));
get_global_tape().AddOp("mean", {{"X", {var}}}, {{"Out", {out}}}, {});
return out;
}
};
VariableHandle mean(VariableHandle x) {
VariableHandle out(new Variable("mean"));
get_global_tape().AddOp("mean", {{"X", {x}}}, {{"Out", {out}}}, {});
return out;
}

VariableHandle relu(VariableHandle x) {
VariableHandle out(new Variable("relu"));
Expand Down Expand Up @@ -253,11 +245,27 @@ VariableHandle CreateRecordioFileReader(std::string filename,
return reader;
}

void ReadNext(VariableHandle reader, VariableHandle data_holder) {
std::vector<VariableHandle> ReadNext(VariableHandle reader) {
PADDLE_ENFORCE(reader->Var().IsType<framework::ReaderHolder>());

reader->GetMutable<paddle::framework::ReaderHolder>()->ReadNext(
data_holder->GetMutable<paddle::framework::LoDTensorArray>());
paddle::framework::LoDTensorArray data_holder;
reader->GetMutable<paddle::framework::ReaderHolder>()->ReadNext(&data_holder);
if (data_holder.empty()) {
reader->GetMutable<paddle::framework::ReaderHolder>()->ReInit();
reader->GetMutable<paddle::framework::ReaderHolder>()->ReadNext(
&data_holder);
}
PADDLE_ENFORCE(!data_holder.empty(), "Error reading file.");

std::vector<VariableHandle> rval;
for (size_t i = 0; i < data_holder.size(); ++i) {
rval.emplace_back(new Variable("data" + std::to_string(i)));
auto *lod_tensor = rval.back()->GetMutable<framework::LoDTensor>();
lod_tensor->ShareDataWith(data_holder[i]);
lod_tensor->set_lod(data_holder[i].lod());
}

return rval;
}

} // namespace tape
Expand Down
4 changes: 2 additions & 2 deletions src/tape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ void Tape::AddOp(const std::string &type,
}

void Tape::Forward() {
LOG(INFO) << "Starting forward -------------------------";
VLOG(3) << "Starting forward -------------------------";
PADDLE_ENFORCE(!has_been_backwarded_);
while (current_position_ < tape_.size()) {
OpHandle &op = tape_[current_position_];
Expand All @@ -134,7 +134,7 @@ void Tape::Forward() {
current_position_++;
}

LOG(INFO) << "Finishing forward -------------------------";
VLOG(3) << "Finishing forward -------------------------";
}

void Tape::Backward(VariableHandle target) {
Expand Down
43 changes: 10 additions & 33 deletions src/test_tape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,53 +19,29 @@ using paddle::tape::VariableHandle;
using paddle::tape::Variable;
using paddle::tape::Linear;
using paddle::tape::Convolution2D;
using paddle::tape::Mean;
using paddle::tape::SGD;
using paddle::tape::Fill;
using paddle::tape::mean;
using paddle::tape::softmax;
using paddle::tape::cross_entropy;
using paddle::tape::reset_global_tape;
using paddle::tape::get_global_tape;
using paddle::tape::CreateRecordioFileReader;
using paddle::tape::ReadNext;

TEST(Tape, TestReader) {
VariableHandle data_label(new paddle::tape::Variable("data_label"));
VariableHandle reader = CreateRecordioFileReader(
"/tape/src/data/mnist.recordio", {32, 1, 28, 28, 32, 1}, {4, 2}, {0, 0});
ReadNext(reader, data_label);
LOG(INFO) << *data_label;
}

TEST(Tape, TestRelu) {
std::string initializer = "uniform_random";
paddle::framework::AttributeMap attrs;
attrs["min"] = -1.0f;
attrs["max"] = 1.0f;
attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
attrs["shape"] = std::vector<int>{10};
attrs["seed"] = 123;
Fill filler(initializer, attrs);

VariableHandle input(new Variable("input"));
filler(input);
auto loss = relu(input);
LOG(INFO) << input->Value();
LOG(INFO) << loss->Value();
}

TEST(Tape, TestConv) {
Convolution2D conv1(3, 16, 3, "relu");
Convolution2D conv2(16, 1, 3, "relu");
Mean mean;

SGD sgd(0.001);

std::string initializer = "fill_constant";
std::string initializer = "uniform_random";
paddle::framework::AttributeMap attrs;
attrs["min"] = -1.0f;
attrs["max"] = 1.0f;
attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
attrs["seed"] = 123;
attrs["shape"] = std::vector<int>{32, 3, 8, 8};
attrs["value"] = 1.0f;
Fill filler(initializer, attrs);

for (int i = 0; i < 2; ++i) {
Expand All @@ -90,15 +66,16 @@ TEST(Tape, TestConv) {
TEST(Tape, TestMLP) {
Linear linear1(3, 3, "relu");
Linear linear2(3, 3, "relu");
Mean mean;

SGD sgd(0.001);

std::string initializer = "fill_constant";
std::string initializer = "uniform_random";
paddle::framework::AttributeMap attrs;
attrs["min"] = -1.0f;
attrs["max"] = 1.0f;
attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
attrs["seed"] = 123;
attrs["shape"] = std::vector<int>{3, 3};
attrs["value"] = 1.0f;
Fill filler(initializer, attrs);

for (int i = 0; i < 2; ++i) {
Expand All @@ -121,7 +98,7 @@ TEST(Tape, TestMLP) {
}
}

int main(int argc, char **argv) {
int main(int argc, char** argv) {
std::vector<paddle::platform::Place> places;
places.emplace_back(paddle::platform::CPUPlace());
paddle::platform::DeviceContextPool::Init(places);
Expand Down
2 changes: 1 addition & 1 deletion src/variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class Variable {
: name_(pre_fix + (is_grad ? framework::kGradVarSuffix
: std::to_string(count()))) {}

~Variable() { LOG(INFO) << "Deleting " << Name(); }
~Variable() { VLOG(10) << "Deleting " << Name(); }

VariableHandle Grad() {
if (grad_.expired()) {
Expand Down