Skip to content

Commit

Permalink
add load_params test and opt (PaddlePaddle#430)
Browse files Browse the repository at this point in the history
Co-authored-by: wangyue50 <wangyue50@baidu.com>
  • Loading branch information
wenming2014 and wangyue50 authored Aug 25, 2021
1 parent eb8f4f7 commit e347c22
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 13 deletions.
1 change: 1 addition & 0 deletions cinn/hlir/pe/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ core_gather_srcs(SRCS
cc_test(test_cinn_pe_elementwise SRCS pe_elementwise_test.cc DEPS cinncore)
cc_test(test_cinn_pe_broadcast SRCS pe_broadcast_test.cc DEPS cinncore)
cc_test(test_cinn_pe_transform SRCS pe_transform_test.cc DEPS cinncore)
cc_test(test_load_params SRCS load_params_test.cc DEPS cinncore)
52 changes: 52 additions & 0 deletions cinn/hlir/pe/load_params_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#include <gtest/gtest.h>

#include "cinn/hlir/pe/schedule.h"

namespace cinn {
namespace hlir {
namespace pe {
using ir::Tensor;

// Checks that the x86 conv schedule parameters can be loaded from the
// serialized data, and that GetConv2dFactors resolves the expected tiling
// factors for a 64x64 3x3 conv configuration found via a generated key.
TEST(load_x86_params, load_x86_params) {
  auto &x86_params = ScheduleParam::get_x86_instance().GetParam();
  // Lazily create and load the serialized schedule data on first use.
  if (x86_params.empty()) {
    CreateX86SerialData();
    LoadSerialData(&x86_params);
  }
  std::string schedule_key =
      "X86ScheduleConv input 1 3 224 224 weight 64 3 7 7 stride 2 2 padding 3 3 dilation 1 1";
  ASSERT_EQ(x86_params.count(schedule_key), 1);

  // Build a lookup key for a different conv configuration and query its factors.
  const std::vector<int> input_shape   = {1, 64, 56, 56};
  const std::vector<int> weights_shape = {64, 64, 3, 3};
  const std::vector<int> stride        = {1, 1};
  const std::vector<int> padding       = {1, 1};
  const std::vector<int> dilation      = {1, 1};
  schedule_key = GenerateX86ConvKey(input_shape, weights_shape, stride, padding, dilation);

  std::unordered_map<std::string, int> factors;
  auto target = common::DefaultHostTarget();
  // -1 arguments request that all factors be resolved from the stored params.
  GetConv2dFactors(&factors, -1, -1, -1, -1, -1, Float(32), target, schedule_key);
  ASSERT_EQ(factors["ic_bn"], 64);
  ASSERT_EQ(factors["fc_bn"], 64);
  ASSERT_EQ(factors["oc_bn"], 32);
  ASSERT_EQ(factors["ow_bn"], 7);
  ASSERT_EQ(factors["unroll_kw"], 1);
}

// Checks that the CUDA conv schedule parameters can be loaded from the
// serialized data and contain the expected conv configuration key.
TEST(load_cuda_params, load_cuda_params) {
  auto &cuda_params = ScheduleParam::get_cuda_instance().GetParam();
  // Lazily create and load the serialized schedule data on first use.
  if (cuda_params.empty()) {
    CreateCudaSerialData();
    LoadSerialData(&cuda_params);
  }
  const std::string schedule_key = "CudaScheduleConv 1 3 230 230 64 3 7 7 1 64 112 112";
  ASSERT_EQ(cuda_params.count(schedule_key), 1);
}

} // namespace pe
} // namespace hlir
} // namespace cinn
20 changes: 11 additions & 9 deletions cinn/hlir/pe/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,10 @@ void GetConv2dFactors(std::unordered_map<std::string, int> *factors,
const std::string &key,
bool import_params) {
if (import_params) {
auto &params = ScheduleParam::get_instance().GetParam();
auto &params = ScheduleParam::get_x86_instance().GetParam();
if (params.empty()) {
CreateX86SerialData();
LoadSerialData();
LoadSerialData(&params);
}
if (params.count(key)) {
CHECK(!params[key]["oc_bn"].empty());
Expand Down Expand Up @@ -727,6 +727,7 @@ void Conv2d_NCHWc_1X1_Schedule_CPU(poly::StageMap stages,
stages[res]->Reorder({oh_outer1, ow_outer1, oh_inner1, ow_inner1, oc_inner1});
// stages[res]->Fuse({0, 1, 2});
// Todo: computeAt according to forloops' range
// stages[packed_out]->ComputeAt2(stages[res], 2);
VLOG(3) << "stages[res]->transformed_domain()" << stages[res]->transformed_domain();
}
}
Expand Down Expand Up @@ -1035,6 +1036,7 @@ void Conv2d_NCHWc_Schedule_CPU(poly::StageMap stages,
stages[res]->Reorder({oh1, ow_outer1, ow_inner1, oc_inner1});
// stages[res]->Fuse({0, 1, 2});
// Todo: computeAt according to forloops' range
// stages[packed_out]->ComputeAt2(stages[res], 2);
}
}

Expand Down Expand Up @@ -1210,7 +1212,8 @@ int GetMaxSplitter(int a, int b) {
return b;
}

void LoadSerialData(const std::string &file_name) {
void LoadSerialData(std::unordered_map<std::string, std::unordered_map<std::string, std::vector<int>>> *params,
const std::string &file_name) {
proto::ModelData read_model_data;
std::fstream input(file_name, std::ios::in | std::ios::binary);
if (!read_model_data.ParseFromIstream(&input)) {
Expand All @@ -1221,7 +1224,6 @@ void LoadSerialData(const std::string &file_name) {
std::string test_write3;
read_model_data.SerializeToString(&test_write3);
auto read_model_map = read_model_data.data();
auto &res = ScheduleParam::get_instance().GetParam();
for (auto &i : read_model_map) {
auto read_schedule_map = i.second.data();
std::unordered_map<std::string, std::vector<int>> param_data;
Expand All @@ -1232,7 +1234,7 @@ void LoadSerialData(const std::string &file_name) {
}
param_data[j.first] = temp_data;
}
res[i.first] = param_data;
(*params)[i.first] = param_data;
}
}

Expand Down Expand Up @@ -1270,10 +1272,10 @@ void CudaScheduleConv(poly::StageMap stages,
ir::Tensor &weights,
ir::Tensor &output,
const common::Target &target) {
auto &res = ScheduleParam::get_instance().GetParam();
if (res.empty() || res.count("CudaScheduleConv 1 3 230 230 64 3 7 7 1 64 112 112") == 0) {
auto &res = ScheduleParam::get_cuda_instance().GetParam();
if (res.empty()) {
CreateCudaSerialData();
LoadSerialData();
LoadSerialData(&res);
}

int n = output->shape[0].as_int32();
Expand Down Expand Up @@ -1342,7 +1344,7 @@ void CudaScheduleConv2(poly::StageMap stages,
ir::Tensor &output,
const common::Target &target,
const std::string &key) {
auto &res = ScheduleParam::get_instance().GetParam();
auto &res = ScheduleParam::get_cuda_instance().GetParam();
stages[input_pad]->ComputeInline();
optim::Simplify(&(output->shape[2]));
optim::Simplify(&(output->shape[3]));
Expand Down
13 changes: 9 additions & 4 deletions cinn/hlir/pe/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@ class ScheduleParam {
~ScheduleParam();
ScheduleParam(const ScheduleParam &) = delete;
ScheduleParam &operator=(const ScheduleParam &) = delete;
static ScheduleParam &get_instance() {
static ScheduleParam instance;
return instance;
static ScheduleParam &get_cuda_instance() {
static ScheduleParam cuda_instance;
return cuda_instance;
}
static ScheduleParam &get_x86_instance() {
static ScheduleParam x86_instance;
return x86_instance;
}
std::unordered_map<std::string, std::unordered_map<std::string, std::vector<int>>> &GetParam() { return param_data; }
std::unordered_map<std::string, std::vector<int>> &operator[](const std::string &key) { return param_data[key]; }
Expand Down Expand Up @@ -156,7 +160,8 @@ std::string GenerateX86ConvKey(const std::vector<int> &input_shape,
const std::vector<int> &dilations);
void CreateX86SerialData(const std::string &file_name = "default_serial.log");

void LoadSerialData(const std::string &file_name = "default_serial.log");
void LoadSerialData(std::unordered_map<std::string, std::unordered_map<std::string, std::vector<int>>> *params,
const std::string &file_name = "default_serial.log");

void SaveSerialData(
const std::unordered_map<std::string, std::unordered_map<std::string, std::vector<int>>> &model_data,
Expand Down

0 comments on commit e347c22

Please sign in to comment.