Skip to content

Commit

Permalink
delete_repeated_ops_pass and reshape_unstack_concat_fuse_pass
Browse files · Browse the repository at this point in the history
  • Loading branch information
zhupengyang committed Jun 25, 2023
1 parent e66beb0 commit f471937
Show file tree
Hide file tree
Showing 9 changed files with 854 additions and 121 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ if(WITH_XPU)
pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(reshape_unstack_concat_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(multi_encoder_xpu_adaptive_seqlen_fuse_pass inference DIR xpu
Expand Down
185 changes: 79 additions & 106 deletions paddle/fluid/framework/ir/delete_repeated_ops_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,68 +101,86 @@ class DeleteRepeatedOpsPass : public FusePassBase {
void ApplyImpl(ir::Graph* graph) const override;

private:
int DeleteShapePass(ir::Graph* graph) const;

int DeleteSlicePass(ir::Graph* graph) const;
void DeleteRepeatedOps(
ir::Graph* graph,
const std::string& op_type,
std::function<std::string(OpDesc*)> gen_op_key_fn) const;

const std::string name_scope_{"delete_repeated_ops_pass"};
};

int DeleteRepeatedOpsPass::DeleteShapePass(ir::Graph* graph) const {
void DeleteRepeatedOpsPass::DeleteRepeatedOps(
ir::Graph* graph,
const std::string& op_type,
std::function<std::string(OpDesc*)> gen_op_key_fn) const {
GraphPatternDetector gpd;
patterns::VarWithRepeatedOpsPattern pattern(
gpd.mutable_pattern(), name_scope_, "shape");
gpd.mutable_pattern(), name_scope_, op_type);

int delete_counts = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle DeleteShapePass";
VLOG(4) << "handle DeleteRepeatedOps";
GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern);

std::vector<std::string> invalid_shape_out_ops{"while",
"conditional_block"};
std::vector<Node*> shapes;
std::vector<std::string> invalid_out_ops{
"while", "conditional_block", "fetch"};
std::map<std::string, std::vector<Node*>> ops_map;
for (auto* next_op : in_var->outputs) {
if (next_op->Name() != "shape") continue;
bool shape_out_op_is_invalid = false;
for (auto* shape_out_op : next_op->outputs[0]->outputs) {
if (std::count(invalid_shape_out_ops.begin(),
invalid_shape_out_ops.end(),
shape_out_op->Name()) > 0 ||
HasOutVarName(shape_out_op, next_op->outputs[0]->Name())) {
shape_out_op_is_invalid = true;
if (next_op->Name() != op_type) continue;
auto* op = next_op;
bool out_op_is_invalid = false;
for (auto* out_op : op->outputs[0]->outputs) {
if (std::count(invalid_out_ops.begin(),
invalid_out_ops.end(),
out_op->Name()) > 0 ||
HasOutVarName(out_op, op->outputs[0]->Name())) {
out_op_is_invalid = true;
break;
}
}
if (!shape_out_op_is_invalid) {
shapes.push_back(next_op);
if (out_op_is_invalid) continue;
auto attr_key = gen_op_key_fn(op->Op());
ops_map[attr_key].push_back(op);
}
for (auto iter = ops_map.begin(); iter != ops_map.end();) {
if (iter->second.size() <= 1) {
iter = ops_map.erase(iter);
} else {
iter++;
}
}
if (shapes.size() <= 1) return;

auto* first_shape_out = shapes[0]->outputs[0];
auto first_shape_out_name = first_shape_out->Name();
std::unordered_set<const Node*> delete_nodes;
for (size_t i = 1; i < shapes.size(); i++) {
auto* cur_shape = shapes[i];
auto* cur_shape_out = cur_shape->outputs[0];
auto cur_shape_out_name = cur_shape_out->Name();
for (auto* shape_out_op : cur_shape_out->outputs) {
shape_out_op->Op()->Rename(cur_shape_out_name, first_shape_out_name);
IR_NODE_LINK_TO(first_shape_out, shape_out_op);
for (auto iter : ops_map) {
auto ops = iter.second;
auto* first_op_out = ops[0]->outputs[0];
auto first_op_out_name = first_op_out->Name();
std::unordered_set<const Node*> delete_nodes;
for (size_t i = 1; i < ops.size(); i++) {
auto* cur_op = ops[i];
auto* cur_op_out = cur_op->outputs[0];
auto cur_op_out_name = cur_op_out->Name();
for (auto* out_op : cur_op_out->outputs) {
out_op->Op()->RenameInput(cur_op_out_name, first_op_out_name);
IR_NODE_LINK_TO(first_op_out, out_op);
}
delete_nodes.insert(cur_op);
delete_nodes.insert(cur_op_out);
delete_counts++;
}
delete_nodes.insert(cur_shape);
delete_nodes.insert(cur_shape_out);
delete_counts++;
GraphSafeRemoveNodes(graph, delete_nodes);
}

GraphSafeRemoveNodes(graph, delete_nodes);
};

gpd(graph, handler);
return delete_counts;
if (delete_counts > 0) {
LOG(INFO) << "--- delete " << delete_counts << " repeated " << op_type
<< " ops";
}
}

std::string GenShapeAttrKey(OpDesc* slice_op_desc) { return ""; }

std::string GenSliceAttrKey(OpDesc* slice_op_desc) {
std::string attr_key;
auto starts = slice_op_desc->GetAttrIfExists<std::vector<int>>("starts");
Expand All @@ -189,85 +207,40 @@ std::string GenSliceAttrKey(OpDesc* slice_op_desc) {
return attr_key;
}

int DeleteRepeatedOpsPass::DeleteSlicePass(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::VarWithRepeatedOpsPattern pattern(
gpd.mutable_pattern(), name_scope_, "slice");

int delete_counts = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle DeleteSlicePass";
GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern);

std::vector<std::string> invalid_slice_out_ops{"while",
"conditional_block"};
std::map<std::string, std::vector<Node*>> slice_ops;
for (auto* next_op : in_var->outputs) {
if (next_op->Name() != "slice") continue;
auto* slice = next_op;
bool slice_out_op_is_invalid = false;
for (auto* slice_out_op : slice->outputs[0]->outputs) {
if (std::count(invalid_slice_out_ops.begin(),
invalid_slice_out_ops.end(),
slice_out_op->Name()) > 0 ||
HasOutVarName(slice_out_op, slice->outputs[0]->Name())) {
slice_out_op_is_invalid = true;
break;
}
}
if (slice_out_op_is_invalid) continue;
auto attr_key = GenSliceAttrKey(slice->Op());
slice_ops[attr_key].push_back(slice);
}
for (auto iter = slice_ops.begin(); iter != slice_ops.end();) {
if (iter->second.size() <= 1) {
iter = slice_ops.erase(iter);
} else {
iter++;
}
}
// Builds the dedup key for a "cast" op from the two attributes that
// determine its result: the source and destination dtypes.
std::string GenCastAttrKey(OpDesc* cast_op_desc) {
  const int src_dtype = cast_op_desc->GetAttrIfExists<int>("in_dtype");
  const int dst_dtype = cast_op_desc->GetAttrIfExists<int>("out_dtype");
  std::string key = "in_dtype_";
  key += std::to_string(src_dtype);
  key += "_out_dtype_";
  key += std::to_string(dst_dtype);
  return key;
}

for (auto iter : slice_ops) {
auto slices = iter.second;
auto* first_slice_out = slices[0]->outputs[0];
auto first_slice_out_name = first_slice_out->Name();
std::unordered_set<const Node*> delete_nodes;
for (size_t i = 1; i < slices.size(); i++) {
auto* cur_slice = slices[i];
auto* cur_slice_out = cur_slice->outputs[0];
auto cur_slice_out_name = cur_slice_out->Name();
for (auto* slice_out_op : cur_slice_out->outputs) {
slice_out_op->Op()->RenameInput(cur_slice_out_name,
first_slice_out_name);
IR_NODE_LINK_TO(first_slice_out, slice_out_op);
}
delete_nodes.insert(cur_slice);
delete_nodes.insert(cur_slice_out);
delete_counts++;
}
GraphSafeRemoveNodes(graph, delete_nodes);
}
};
// Builds the dedup key for an "elementwise_add" op: two adds are duplicates
// only if they read the same X and Y variables with the same broadcast axis.
std::string GenAddAttrKey(OpDesc* add_op_desc) {
  std::string key = add_op_desc->Input("X")[0];
  key += "_";
  key += add_op_desc->Input("Y")[0];
  key += "_axis_";
  key += std::to_string(add_op_desc->GetAttrIfExists<int>("axis"));
  return key;
}

gpd(graph, handler);
return delete_counts;
// Builds the dedup key for a "scale" op from the attributes that determine
// its result: scale factor, bias, and whether bias is applied after scaling.
std::string GenScaleAttrKey(OpDesc* scale_op_desc) {
  const float scale_val = scale_op_desc->GetAttrIfExists<float>("scale");
  const float bias_val = scale_op_desc->GetAttrIfExists<float>("bias");
  const bool after = scale_op_desc->GetAttrIfExists<bool>("bias_after_scale");
  std::string key = "scale_";
  key += std::to_string(scale_val);
  key += "_bias_";
  key += std::to_string(bias_val);
  key += "_bias_after_scale_";
  key += std::to_string(after);
  return key;
}

// Entry point: runs duplicate-op elimination once per supported op type.
void DeleteRepeatedOpsPass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);

  DeleteRepeatedOps(graph, "shape", GenShapeAttrKey);
  DeleteRepeatedOps(graph, "slice", GenSliceAttrKey);
  DeleteRepeatedOps(graph, "cast", GenCastAttrKey);
  DeleteRepeatedOps(graph, "elementwise_add", GenAddAttrKey);
  DeleteRepeatedOps(graph, "scale", GenScaleAttrKey);
  // NOTE(review): "cast" runs a second time — presumably removing duplicate
  // add/scale ops above can expose newly-identical cast ops; confirm this is
  // intentional rather than a copy-paste.
  DeleteRepeatedOps(graph, "cast", GenCastAttrKey);
}

} // namespace ir
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ static const std::vector<std::string> xpu_support_subgraph_passes = {
"xpu_delete_cast_op_pass",
"fc_xpu_fuse_pass",
"link_xpu_op_max_pass",
"xpu_delete_cast_op_pass",
};

Graph *Pass::Apply(Graph *graph) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ Reshape2MatmulPattern::Reshape2MatmulPattern(PDPattern* pattern,
->assert_more([](Node* node) {
auto reshape2_in_x_shape = node->Var()->GetShape();
size_t reshape2_in_rank = reshape2_in_x_shape.size();
bool nice_shape =
(reshape2_in_x_shape[2] == 1 && reshape2_in_x_shape[3] == 1) ||
(reshape2_in_x_shape[1] == 1 && reshape2_in_x_shape[3] == 1);
return (reshape2_in_rank == 4 && nice_shape);
return reshape2_in_rank == 4 && ((reshape2_in_x_shape[2] == 1 &&
reshape2_in_x_shape[3] == 1) ||
(reshape2_in_x_shape[1] == 1 &&
reshape2_in_x_shape[3] == 1));
});
auto* reshape2 =
pattern->NewNode(reshape2_repr())
Expand Down
Loading

0 comments on commit f471937

Please sign in to comment.