Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add cinn graph symbolization #36417

Merged
merged 21 commits into from
Oct 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
1e96ef6
add cinn graph symbolization
thisjiang Oct 13, 2021
f775779
fix some bug
thisjiang Oct 13, 2021
ac23d17
add paddle scope to cinn scope
thisjiang Oct 14, 2021
5bec7a8
fix conflict
thisjiang Oct 15, 2021
0d01f3e
add paddle scope to CINN scope in Symbolization, and add feed op when…
thisjiang Oct 15, 2021
c85696e
fix some bug
thisjiang Oct 15, 2021
64210e1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thisjiang Oct 18, 2021
5df7f48
fix some bug by review advices
thisjiang Oct 18, 2021
5f9535d
optimize code problem
thisjiang Oct 18, 2021
47478d3
revert build_cinn_pass and move the change to https://github.com/Padd…
thisjiang Oct 18, 2021
96075e3
fix some bug after co-compilation
thisjiang Oct 19, 2021
2d1abd9
perfect single test script
thisjiang Oct 19, 2021
b25bdea
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thisjiang Oct 19, 2021
2d9a34c
remove scope and rename feed_target to input_tensor
thisjiang Oct 19, 2021
7ed7efe
using std::unordered_map instead of absl::flat_hash_map
thisjiang Oct 20, 2021
cdc9904
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thisjiang Oct 21, 2021
a2a68ce
fix single test bug
thisjiang Oct 22, 2021
6428c7b
revert to preverion for WITH_CINN has add in later PR
thisjiang Oct 22, 2021
4b795f9
full error information for CI
thisjiang Oct 22, 2021
9f269c4
full enfore information for CI pass
thisjiang Oct 23, 2021
467a903
Merge branch 'develop' into add_cinn_graph_symbolization
thisjiang Oct 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions paddle/fluid/framework/paddle2cinn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector)

if (WITH_CINN)
cc_library(transform_desc SRCS transform_desc.cc DEPS proto_desc cinn)
cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph graph_helper transform_desc cinn)

cc_test(test_transform_desc SRCS transform_desc_test.cc DEPS transform_desc)
cc_test(test_cinn_graph_symbolization SRCS cinn_graph_symbolization_test.cc DEPS cinn_graph_symbolization)
endif()

cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key)
Expand Down
172 changes: 172 additions & 0 deletions paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h"

#include <algorithm>
#include <iterator>
#include <queue>
#include <vector>

#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/paddle2cinn/transform_desc.h"
#include "paddle/fluid/framework/variable.h"

#include "cinn/frontend/op_mappers/use_op_mappers.h"
#include "cinn/frontend/var_type_utils.h"

namespace paddle {
namespace framework {
namespace paddle2cinn {

using ir::Graph;
using ir::Node;
using CinnTensor = ::cinn::hlir::framework::Tensor;
using OpMapperContext = CinnGraphSymbolization::OpMapperContext;
using CinnOpDesc = CinnGraphSymbolization::CinnOpDesc;
using FeedInfoMap = CinnGraphSymbolization::FeedInfoMap;

namespace utils {

OpMapperContext::FeedInfo GetCinnFeedInfoFromTensor(const Tensor& tensor) {
OpMapperContext::FeedInfo info;
const auto& dim = tensor.dims();
for (int i = 0; i < dim.size(); i++) {
info.shape.emplace_back(static_cast<int>(dim[i]));
}

auto cinn_var_type = TransformVarDataTypeToCinn(tensor.type());
info.type = ::cinn::frontend::utils::CppVarType2CommonType(cinn_var_type);
return info;
}
} // namespace utils

FeedInfoMap CinnGraphSymbolization::GetFeedInfoMapFromInput() const {
FeedInfoMap feed_map;
for (auto& feed_pair : input_tensors_) {
const auto& feed_name = feed_pair.first;
const auto* tensor = feed_pair.second;

feed_map[feed_name] = utils::GetCinnFeedInfoFromTensor(*tensor);
}
return feed_map;
}

// get the graph's op input Parameter var name set
std::unordered_set<std::string>
CinnGraphSymbolization::GetGraphInputParameterNames() const {
std::unordered_set<std::string> names;

for (auto* node : graph_.Nodes()) {
if (node->IsOp()) {
for (auto* var : node->inputs) {
if (var->Var()->IsParameter()) {
// Only need preserve the input parameter var of graph,
// others do not.
names.insert(var->Name());
}
}
}
}

return names;
}

// Transform paddle scope to cinn, note that we only preserve the graph’s
// input parameter variable and ignore others.
std::shared_ptr<::cinn::hlir::framework::Scope>
CinnGraphSymbolization::CreateCinnScope(const FeedInfoMap& feed_map) const {
auto cinn_scope = ::cinn::hlir::framework::Scope::Create();

// get the graph's input parameter variable name list
auto parameter_names = GetGraphInputParameterNames();

for (const auto& param_name : parameter_names) {
VLOG(4) << "add param var [" << param_name << "] info scope";
// if cannot find var in graph input, skip.
// scope accepte the CINN format name, so here we need transform
// paddle format name to CINN format.
auto* cinn_var = cinn_scope->Var<CinnTensor>(
::cinn::utils::TransValidVarName(param_name));

auto& cinn_tensor = absl::get<CinnTensor>(*cinn_var);
// here we only need preserve dtype and shape, do not need preserve data
auto feed_info = feed_map.at(param_name);
cinn_tensor->set_type(feed_info.type);
cinn_tensor->Resize(::cinn::hlir::framework::Shape(feed_info.shape));
}

return cinn_scope;
}

std::vector<std::unique_ptr<CinnOpDesc>>
CinnGraphSymbolization::TransformAllGraphOpToCinn() const {
std::vector<std::unique_ptr<CinnOpDesc>> cinn_op_descs;

const auto& sorted_ops = ir::TopologySortOperations(graph_);
for (auto* node : sorted_ops) {
cinn_op_descs.emplace_back(std::make_unique<CinnOpDesc>());
auto& cinn_desc = cinn_op_descs.back();

TransformOpDescToCinn(node->Op(), cinn_desc.get());
}
return cinn_op_descs;
}

void CinnGraphSymbolization::RunOp(const CinnOpDesc& op_desc,
const OpMapperContext& ctx) const {
const auto& op_type = op_desc.Type();
auto* kernel = ::cinn::frontend::OpMapperRegistry::Global()->Find(op_type);
PADDLE_ENFORCE_NE(kernel, nullptr,
platform::errors::NotFound(
"Op %s is Not Supported by CINN, please register"
" this op in the CINN repo.",
op_type.c_str()));
VLOG(4) << "Running Op " << op_type;
kernel->Run(op_desc, ctx);
}

void CinnGraphSymbolization::RunGraph(const OpMapperContext& ctx) const {
auto cinn_op_descs = TransformAllGraphOpToCinn();
// run the CINN op one by one, note that all ops
// have been sorted at constructor.
for (auto& op_desc : cinn_op_descs) {
RunOp(*op_desc, ctx);
}
}

::cinn::frontend::Program CinnGraphSymbolization::operator()() {
std::string builder_name = "NetBuilder_of_graph_" + std::to_string(graph_id_);
VLOG(4) << "NetBuilder Name " << builder_name;

::cinn::frontend::NetBuilder builder(builder_name);

auto feed_map = GetFeedInfoMapFromInput();
auto cinn_scope = CreateCinnScope(feed_map);

OpMapperContext ctx(*cinn_scope, target_, &builder, &var_map_,
&var_model_to_program_map_);
// add all tensor's feed info into context
for (auto& feed_pair : feed_map) {
ctx.AddFeedInfo(feed_pair.first, feed_pair.second);
VLOG(4) << "add feed var [" << feed_pair.first << "] info context";
}
RunGraph(ctx);

return builder.Build();
}

} // namespace paddle2cinn
} // namespace framework
} // namespace paddle
128 changes: 128 additions & 0 deletions paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <map>
#include <unordered_map>
#include <unordered_set>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"

#include "cinn/frontend/net_builder.h"
#include "cinn/frontend/op_mapper_registry.h"

namespace paddle {
namespace framework {
namespace paddle2cinn {

// An executor accept subgraph which is generated by BuildCinnPass,
// run each op's CINN Op Mapper, finally return a frontend::Program object
// corresponding to the subgraph.
//
// Parameter:
// 1. graph_id:
// the unique graph id, used for generating unique NetBuilder name.
// 2. graph:
// the CINN subgraph whose op are all supported by CINN, and the
// graph is independently of other graph.
// 3. input_tensors:
// all input var nodes of CINN subgraph, they are necessary for
// we need pass the shape and data type into CINN, otherwise the
// NetBuilder may error for the shape not meet the precondition.
//
// Describe:
// The main function is operator(), it will run all op function by CINN
// OpMapper and finally return a program object.
// The executor operator() consisted by the following step:
// 1. create a NetBuilder, it's name is unique for each graph;
// 2. create OpMapperContext, contain scope, target, local var_map and
// local var_model_to_program_map;
// 3. add all feed var into OpMapperContext to pass the shape and type
// into CINN;
// 4. topological sorting graph op nodes;
// 5. transform all op from paddle opdesc format to cinn opdesc format;
// 5. run the CINN op in graph one by one. Note that the graph have been
// topo sorted;
// 6. return the NetBuilder.Build() after all op run.
class CinnGraphSymbolization {
public:
CinnGraphSymbolization(
int64_t graph_id, const ir::Graph& graph,
const ::cinn::common::Target& target,
const std::map<std::string, const LoDTensor*>& input_tensors)
: graph_id_(graph_id),
graph_(graph),
target_(target),
input_tensors_(input_tensors) {}

// run all CINN op in graph by topo sorting then return its NetBuilder
::cinn::frontend::Program operator()();

// return the internal variable map
const std::unordered_map<std::string, ::cinn::frontend::Variable>& var_map()
const {
return var_map_;
}

// return the map from the variable name in paddle model to cinn program.
const std::unordered_map<std::string, std::string>& var_model_to_program_map()
const {
return var_model_to_program_map_;
}

using OpMapperContext = ::cinn::frontend::OpMapperContext;
using FeedInfoMap =
std::unordered_map<std::string, OpMapperContext::FeedInfo>;
using CinnOpDesc = ::cinn::frontend::paddle::cpp::OpDesc;

private:
const int64_t graph_id_;
const ir::Graph& graph_;
const ::cinn::common::Target& target_;
const std::map<std::string, const LoDTensor*>& input_tensors_;

// preserve local variable map
std::unordered_map<std::string, ::cinn::frontend::Variable> var_map_;
std::unordered_map<std::string, std::string> var_model_to_program_map_;

// transform all paddle var desc in feed list into cinn_var_descs_
FeedInfoMap GetFeedInfoMapFromInput() const;

// transform all paddle op desc in graph into cinn op desc
std::vector<std::unique_ptr<CinnOpDesc>> TransformAllGraphOpToCinn() const;

// RunOp accept OpDesc and global run context then run
// it's kernel registered in OpMapper.
// called in RunGraph.
void RunOp(const CinnOpDesc& op_desc, const OpMapperContext& ctx) const;

// preserve var desc, run the op one by one.
void RunGraph(const OpMapperContext& ctx) const;

// create cinn scope and add parameter's feed info into scope
std::shared_ptr<::cinn::hlir::framework::Scope> CreateCinnScope(
const FeedInfoMap& feed_map) const;

// get the graph op's input persistable var name set
std::unordered_set<std::string> GetGraphInputParameterNames() const;

friend class CinnGraphSymbolizationForTest;
};

} // namespace paddle2cinn
} // namespace framework
} // namespace paddle
Loading