-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add cinn graph symbolization #36417
add cinn graph symbolization #36417
Conversation
Thanks for your contribution! |
@@ -2,8 +2,10 @@ cc_library(cinn_cache_key SRCS cinn_cache_key.cc DEPS boost graph graph_helper l | |||
cc_library(cinn_compiled_object SRCS cinn_compiled_object.cc DEPS feed_fetch_method graph lod_tensor proto_desc) | |||
cc_library(cinn_runner SRCS cinn_runner.cc DEPS cinn_cache_key cinn_compiled_object feed_fetch_method graph lod_tensor scope) | |||
cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector) | |||
cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph transform_desc cinnapi.so) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么要加.so后缀?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
按照其它cmake上的方法改成了:
cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph transform_desc)
target_link_libraries(cinn_graph_symbolization cinnapi.so)
namespace paddle2cinn { | ||
|
||
// An executor accept subgraph which is generated by BuildCinnPass, | ||
// run each op's CINN Op Mapper, finally return the graph's CINN NetBuilder. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// run each op's CINN Op Mapper, finally return the graph's CINN NetBuilder. | |
// run each op's CINN Op Mapper, finally return a frontend::Program object corresponding to the subgraph. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
::cinn::frontend::Program operator()() const; | ||
|
||
// return the internal variable map | ||
const auto& var_map() const { return var_map_; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
返回值尽量不要使用auto类型,比较难读懂。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改为了明确类型
const auto& var_map() const { return var_map_; } | ||
|
||
// return the map from the variable name in paddle model to cinn program. | ||
const auto& var_model_to_program_map() const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
builder_name.append(std::to_string(graph_id_)); | ||
builder_name.append("_of_"); | ||
static uint64_t unique_invoke_number = 0; | ||
builder_name.append(std::to_string(unique_invoke_number++)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unique_invoke_number
的意义是什么?这里永远都是0?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
的确,这里之前想着的是Symbolization
对象通用,参数都放在operator()
函数参数里,这样每调用一次生成的NetBuilder
名称都是唯一的,现在这样写参数都放在构造函数里,每个图都生成一个对象,也只会调用一次operator()
函数
} // namespace utils | ||
|
||
// get the graph's op input Parameter var name set | ||
auto CinnGraphSymbolization::GetGraphInputParameterNames() const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不要返回auto,看起来真心不易读。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
|
||
::cinn::frontend::NetBuilder builder(builder_name); | ||
|
||
auto target = ::cinn::common::DefaultHostTarget(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里就确定target?这个应该是外面传进来的。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改为通过构造函数传入
|
||
TransformOpDescToCinn(node->Op(), cinn_desc.get()); | ||
} | ||
return std::move(cinn_op_descs_); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不用加move。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
auto CinnGraphSymbolization::TransformAllGraphOpToCinn() const { | ||
std::vector<std::unique_ptr<CinnOpDesc>> cinn_op_descs_; | ||
|
||
const auto& sorted_ops = TopoSortGraph(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
拓扑排序建议使用现有的代码:
std::vector<ir::Node *> TopologySortOperations(const Graph &graph); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
// the subgraph is independently, so here we only need link | ||
// to the node in new subgraph, and discard the link to | ||
// out-graph. | ||
for (auto* op : cluster) { | ||
for (auto* var : op->inputs) { | ||
if (cluster_internals.count(var)) { | ||
old_op2new_op[op]->inputs.emplace_back(old_var2new_var[var]); | ||
} else if (cluster_inputs.count(var)) { | ||
if (var->Var()->IsParameter()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (var->Var()->IsParameter()) { | |
if (!var->Var()->IsParameter()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修复
// link feed var to cluster op | ||
for (auto* old_op : node->outputs) { | ||
if (cluster.count(old_op)) { | ||
var->outputs.emplace_back(old_op2new_op[old_op]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
old_op的输入应该去掉node,加上var。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
现在对于如下CINN支持的计算图:
v1 --
|
| --> mul --> v3 --
| |
v2 -- | --> add --> v5 --> relu --> v6
|
v4 --
生成的子图是不是以下这种形式:
feed --> v1 --
|
| --> mul --> v3 --
| |
v2 -- | --> add --> v5 --> relu
|
v4 --
(v2、v4为权重,且只分别单向作为mul和add的输入)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
生成的子图是这样的:
feed --> new_v1 --
|
| --> mul --> new_v3 --
| |
new_ v2 -- | --> add --> new_v5 --> relu --> new_var6
|
new_v4 --
|
||
std::vector<std::unique_ptr<CinnOpDesc>> | ||
CinnGraphSymbolization::TransformAllGraphOpToCinn() const { | ||
std::vector<std::unique_ptr<CinnOpDesc>> cinn_op_descs_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
局部变量名称的结尾不要加下划线。
} | ||
|
||
void CinnGraphSymbolization::RunGraph(const OpMapperContext& ctx) const { | ||
auto cinn_op_descs_ = TransformAllGraphOpToCinn(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上。
… add_cinn_graph_symbolization
… add_cinn_graph_symbolization
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for PADDLE_ENFORCE
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
APIs
Describe
增加CINN子图符号化函数类,该类接收一个通过
BuildCinnPass
类筛选出的CINN子图,并逐一运行拓扑排序后的每个op,最终返回一个::cinn::frontend::Program
对象。前置PR
当前状态
前置PR均未merge,且由于依赖文件较多,本地编译也不成功。
流程
输入:CINN子图
graph
,子图idgraph_id
,feed列表feed_targets
graph_id
构造出一个唯一的NetBuilder
::cinn::frontend::OpMapperContext
graph
并返回排序后的op列表OpDesc
转换为CINN中的OpDesc
::cinn::frontend::OpMapperRegistry
,找到对应的mapper函数,一一运行该函数builder.Build()
函数返回相应的::cinn::frontend::Program
对象