Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN] Add ScheduleBlock graph #56122

Merged
merged 9 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/cinn/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ gather_srcs(
# Log the collected CINN sources so the build configuration can be debugged.
message(STATUS "srcs: ${cinnapi_src}")

# Unit tests for the generic graph-walker utilities in this directory.
cinn_cc_test(test_dfs_walker SRCS dfs_walker_test.cc DEPS gtest glog)
cinn_cc_test(test_dfs_topo_walker SRCS dfs_topo_walker_test.cc DEPS gtest glog)
cinn_cc_test(test_is_reachable_predicator SRCS is_reachable_predicator_test.cc
             DEPS gtest glog)
cinn_cc_test(test_topo_walker SRCS topo_walker_test.cc DEPS gtest glog)
Expand Down
90 changes: 90 additions & 0 deletions paddle/cinn/common/dfs_topo_walker.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <array>
#include <functional>
#include <stack>
#include <unordered_map>
#include <unordered_set>

namespace cinn {
namespace common {

// Walks a graph in depth-first topological order: a node is handed to the
// handler only after all of its predecessors have been handled; among the
// ready nodes, the most recently discovered one is visited first (stack-based
// DFS flavor).
//
// The graph is described by two visitor callbacks instead of an explicit data
// structure, so any hashable/equality-comparable node representation works.
template <typename NodeType,
          typename NodeHash = std::hash<NodeType>,
          typename NodeEqual = std::equal_to<NodeType>>
class DFSTopoWalker final {
 public:
  DFSTopoWalker(const DFSTopoWalker&) = delete;
  DFSTopoWalker(DFSTopoWalker&&) = delete;
  DFSTopoWalker& operator=(const DFSTopoWalker&) = delete;
  DFSTopoWalker& operator=(DFSTopoWalker&&) = delete;

  using NodeHandlerType = std::function<void(NodeType)>;
  using NodesVisitorType =
      std::function<void(NodeType, const NodeHandlerType&)>;

  // VisitPreNodes invokes its handler on every predecessor of a node;
  // VisitNextNodes invokes its handler on every successor of a node.
  DFSTopoWalker(const NodesVisitorType& VisitPreNodes,
                const NodesVisitorType& VisitNextNodes)
      : VisitPreNodes_(VisitPreNodes), VisitNextNodes_(VisitNextNodes) {}

  // Walks starting from a single node.
  void operator()(NodeType node, const NodeHandlerType& NodeHandler) const {
    std::array<NodeType, 1> nodes{node};
    (*this)(nodes.begin(), nodes.end(), NodeHandler);
  }

  // Walks starting from the nodes in [begin, end). Each node is handed to
  // NodeHandler at most once, and only after every one of its predecessors
  // has been handled.
  template <typename NodeIt>
  void operator()(NodeIt begin,
                  NodeIt end,
                  const NodeHandlerType& NodeHandler) const {
    std::stack<NodeType> node_stack;
    std::unordered_set<NodeType, NodeHash, NodeEqual> visited;
    std::unordered_map<NodeType, int, NodeHash, NodeEqual> in_degree;
    // Lazily records how many predecessors `node` has; computed exactly once
    // per node, the first time the node is touched.
    const auto& InitInDegree = [&](NodeType node) {
      if (in_degree.count(node) == 0) {
        in_degree[node] = 0;
        VisitPreNodes_(node, [&](NodeType in_node) { ++in_degree[node]; });
      }
    };
    // One predecessor of `node` has just been handled.
    const auto& UpdateInDegree = [&](NodeType node) {
      InitInDegree(node);
      --in_degree[node];
    };
    // Schedules `node` if it is ready (no pending predecessors) and has not
    // been scheduled before.
    const auto& TryPush = [&](NodeType node) {
      InitInDegree(node);
      if (visited.count(node) == 0 && in_degree[node] == 0) {
        node_stack.push(node);
        visited.insert(node);
      }
    };

    for (NodeIt iter = begin; iter != end; ++iter) {
      TryPush(*iter);
      while (!node_stack.empty()) {
        NodeType cur = node_stack.top();
        node_stack.pop();
        NodeHandler(cur);
        // First decrement every successor's in-degree, then schedule the
        // successors that became ready in a separate pass.
        VisitNextNodes_(cur, UpdateInDegree);
        VisitNextNodes_(cur, TryPush);
      }
    }
  }

 private:
  // Declared in the same order as the mem-initializer list above so members
  // are initialized in the order written (avoids -Wreorder).
  NodesVisitorType VisitPreNodes_;
  NodesVisitorType VisitNextNodes_;
};

} // namespace common
} // namespace cinn
54 changes: 54 additions & 0 deletions paddle/cinn/common/dfs_topo_walker_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <glog/logging.h>
#include <gtest/gtest.h>

#include "paddle/cinn/common/dfs_topo_walker.h"

namespace cinn {
namespace common {

// Builds the DAG
//   0 -> 1 -> 3 -> 4,  0 -> 3,  2 -> 3
// from an edge list and checks that a walk from the sources {0, 2} visits
// every node exactly once, in topological order.
TEST(DFSTopoWalker, simple) {
  std::vector<std::pair<int, int>> edges{
      {0, 1}, {2, 3}, {1, 3}, {0, 3}, {3, 4}};
  DFSTopoWalker<int> walker(
      // Predecessor visitor: handles the tail of every edge ending at `node`.
      [&](int node, const std::function<void(int)>& NodeHandler) {
        for (const auto& pair : edges) {
          if (pair.second == node) {
            NodeHandler(pair.first);
          }
        }
      },
      // Successor visitor: handles the head of every edge starting at `node`.
      [&](int node, const std::function<void(int)>& NodeHandler) {
        for (const auto& pair : edges) {
          if (pair.first == node) {
            NodeHandler(pair.second);
          }
        }
      });
  std::vector<int> sources{0, 2};
  std::vector<int> outputs;
  walker(sources.begin(), sources.end(), [&](int node) {
    outputs.push_back(node);
  });
  for (auto output : outputs) {
    LOG(INFO) << output;
  }
  std::vector<int> expected{0, 1, 2, 3, 4};
  // EXPECT_EQ prints both sequences on failure, unlike EXPECT_TRUE.
  EXPECT_EQ(outputs, expected);
}

} // namespace common
} // namespace cinn
3 changes: 2 additions & 1 deletion paddle/cinn/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ gather_srcs(
module.cc
lowered_func.cc
intrinsic_ops.cc
layout.cc)
layout.cc
schedule_block_graph.cc)

# Operator definitions and IR unit tests are built from their own subdirectories.
add_subdirectory(op)
add_subdirectory(test)
Expand Down
27 changes: 24 additions & 3 deletions paddle/cinn/ir/schedule/ir_schedule_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -860,11 +860,18 @@ std::vector<Expr> GetProducers(const Expr& block, const Expr& root) {
auto compute_body = block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->body;
std::string block_name = block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
ir::CollectIRNodesWithoutTensor(
compute_body, [&producer_tensor_names](const Expr* x) {
compute_body, [&producer_tensor_names, &block_name](const Expr* x) {
auto* load = x->As<ir::Load>();
if (load) {
producer_tensor_names.insert(load->tensor.as_tensor()->name);
if (load->tensor.as_tensor()->name == block_name) {
producer_tensor_names.insert(
GenReduceInitTensorNameOf(load->tensor.as_tensor()->name));
}
return true;
}
return false;
Expand Down Expand Up @@ -896,6 +903,18 @@ std::vector<Expr> GetConsumers(const Expr& block, const Expr& root) {
CHECK(root.As<ir::ScheduleBlockRealize>());
std::vector<Expr> consumers;
std::string block_tensor = GetTensor(block)->name;
if (IsReduceInitTensorName(block_tensor)) {
std::string consumer_name = GetOriginalReduceTensorName(block_tensor);
auto consumer = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::ScheduleBlockRealize>() &&
x->As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name == consumer_name;
});
CHECK_EQ(consumer.size(), 1);
return {*consumer.begin()};
}

auto find_block = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::ScheduleBlockRealize>() && *x != block && *x != root;
});
Expand Down Expand Up @@ -997,10 +1016,12 @@ std::vector<IterRange> CalculateRequiredRegions(
// deduce accessed regions of the provided tensor in block by itering each
// required block
for (const Expr& pro_node : provided_nodes) {
const std::string& provided_tensor_name =
std::string provided_tensor_name =
is_store_provided ? pro_node.As<ir::Store>()->tensor.as_tensor()->name
: pro_node.As<ir::Load>()->tensor.as_tensor()->name;

if (IsReduceInitTensorName(provided_tensor_name)) {
provided_tensor_name = GetOriginalReduceTensorName(provided_tensor_name);
}
for (const Expr& req_block : required_blocks) {
CHECK(req_block.As<ir::ScheduleBlockRealize>());
Expr block_body =
Expand Down
Loading