Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New Interface #17

Open
wants to merge 5 commits into
base: pass_api
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions cinn/api/fuse_pass_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,24 @@

#pragma once

#include <absl/types/any.h>
#include "cinn/api/op_group_interface.h"

namespace cinn {
namespace api {

using any = absl::any;

class FusePassContext {
public:
FusePassContext() = default;

std::shared_ptr<OpGroupInterface> PickGroup();
virtual void EnableFuse(const OpGroupInterface& first_op_group, const OpGroupInterface& second_op_group) = 0;

void EnableRecompute(const OpGroupInterface& op_group);
// User can cache some group info in context by using this function.
// The group info can be any data and need to create by create_fn.
virtual any* FindOrCreateCachedGroupInfo(const OpGroupInterface& op_group, const std::function<any(const OpGroupInterface& op_group)>& create_fn) = 0;

void EnableVerticalFuse(const OpGroupInterface& first_op_group, const OpGroupInterface& second_op_group);

void EnableHorizontalFuse(const OpGroupInterface& first_op_group, const OpGroupInterface& second_op_group);
protected:
FusePassContext() = default;
};

} // namespace api
Expand Down
127 changes: 127 additions & 0 deletions cinn/api/general_fuse_group.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "cinn/hlir/framework/op_group_interface.h"
#include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/framework/tensor_interface_impl.h"


namespace cinn {
namespace hlir {
namespace framework {


class GeneralFuseGroup : public OpGroupInterface {
public:
GeneralFuseGroup(GroupPtr group) : group_(group) {
TensorNodeFactory
}

const TensorInterfaceList& input_tensors() const {
return input_tensors_;
}

const TensorInterfaceList& output_tensors() const {
return output_tensors_;
}

const std::unordered_set<std::shared_ptr<OpGroupInterface>>& producers() const {
return producers_;
}

const std::unordered_set<std::shared_ptr<OpGroupInterface>>& consumers() const {
return consumers_;
}

private:
bool set_producers(std::unordered_set<std::shared_ptr<OpGroupInterface>> producers) {
producers_ = std::move(producers);
return true;
}

bool set_consumers(std::unordered_set<std::shared_ptr<OpGroupInterface>> consumers) {
consumers_ = std::move(consumers);
return true;
}

friend std::vector<std::shared_ptr<OpGroupInterface>> GroupLite2GeneralFuseGroups(const std::vector<std::shared_ptr<Group>>& fusion_groups);
friend std::vector<std::shared_ptr<Group>> GeneralFuseGroups2GroupList(const std::vector<OpGroupInterface>& op_groups)

private:
std::shared_ptr<Graph::Group> group_;

TensorInterfaceList input_tensors_;
TensorInterfaceList output_tensors_;

std::unordered_set<std::shared_ptr<OpGroupInterface>> producers_;
std::unordered_set<std::shared_ptr<OpGroupInterface>> consumers_;
};

// std::vector<std::shared_ptr<OpGroupInterface>> Graph2GeneralFuseGroups(const Graph& graph) {
// const auto& fusion_groups = graph.fusion_groups;
// return GroupLite2GeneralFuseGroups(fusion_groups)
// }

std::vector<std::shared_ptr<OpGroupInterface>> GroupLite2GeneralFuseGroups(const std::vector<std::shared_ptr<Group>>& fusion_groups) {
std::vector<std::shared_ptr<OpGroupInterface>> result;
result.reserve(fusion_groups.size());
std::unordered_map<const Group*, std::shared_ptr<OpGroupInterface>> op_group_map;

// 1. Create GeneralFuseGroup by Graph::Group
for (const auto& group : fusion_groups) {
result.push_back(std::make_shared<OpGroup>(group));
op_group_map[group.get()] = result.back();
}

// 2. TODO: Set input and output tensors for OpGroup

// 3. Set producers and consumers for OpGroup
for (const auto& group : fusion_groups) {
const auto& producer_groups = group->producer_groups;
std::unordered_set<std::shared_ptr<OpGroupInterface>> producers;
for (const auto& producer_group : producer_groups) {
producers.insert(op_group_map[producer_group.get()]);
}
op_group_map[group.get()]->set_producers(producers);

const auto& consumer_groups = group->consumer_groups;
std::unordered_set<std::shared_ptr<OpGroupInterface>> consumers;
for (const auto& consumer_group : consumer_groups) {
consumers.insert(op_group_map[consumer_groups.get()]);
}
op_group_map[group.get()]->set_producers(consumers);
}

return result;
}

std::vector<std::shared_ptr<Group>> GeneralFuseGroups2GroupList(const std::vector<OpGroupInterface>& op_groups) {
std::vector<std::shared_ptr<Group>> result;
result.reserve(op_groups.size());

// 1. Create Graph::Group by OpGroup
for (const auto& op_group : op_groups) {
result.push_back(op_group.group_);
}

return result;
}



} // namespace framework
} // namespace hlir
} // namespace cinn
113 changes: 113 additions & 0 deletions cinn/api/general_fuse_pass_context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "cinn/hlir/framework/fuse_pass_context.h"
#include "cinn/hlir/framework/graph.h"

namespace cinn {
namespace hlir {
namespace framework {


class GeneralFusePassContext {
public:
GeneralFusePassContext(Graph* graph) = default;

Graph* graph() {
return graph;
}

std::shared_ptr<OpGroupInterface> PickGroup() {
CHECK(0 <= current_group_index_ && current_group_index_ < op_groups_.size())
<< "Can't find group with current index: " << current_group_index_ << ", the groups size is " << op_groups_.size();
return op_groups_[i++];
}

void EnableRecompute(const OpGroupInterface& op_group) {
op_group_fuse_tag_[&op_group] = std::make_shared<Tag>(Tag::Recompute);
}

bool CanRecompute(const OpGroupInterface& op_group) const{
auto iter = op_group_fuse_tag_.find(&op_group);
if (iter != op_group_fuse_tag_.end()) {
return *iter->second == Tag::Recompute;
}
return false;
}

void EnableVerticalFuse(const OpGroupInterface& first_op_group, const OpGroupInterface& second_op_group) {
auto tag = std::make_shared<Tag>(Tag::VerticalFuse);
op_group_fuse_tag_[&first_op_group] = tag;
op_group_fuse_tag_[&second_op_group] = tag;
}

bool CanVerticalFuse(const OpGroupInterface& first_op_group, const OpGroupInterface& second_op_group) const{
auto iter1 = op_group_fuse_tag_.find(&first_op_group);
if (iter1 != op_group_fuse_tag_.end()) {
auto iter2 = op_group_fuse_tag_.find(&second_op_group);
if (iter2 != op_group_fuse_tag_.end()) {
return iter1->second == iter2->second && *iter1->second == Tag::VerticalFuse;
}
}
return false;
}

void EnableHorizontalFuse(const OpGroupInterface& first_op_group, const OpGroupInterface& second_op_group) {
auto tag = std::make_shared<Tag>(Tag::HorizontalFuse);
op_group_fuse_tag_[&first_op_group] = tag;
op_group_fuse_tag_[&second_op_group] = tag;
}

bool CanHorizontalFuse(const OpGroupInterface& first_op_group, const OpGroupInterface& second_op_group) {
auto iter1 = op_group_fuse_tag_.find(&first_op_group);
if (iter1 != op_group_fuse_tag_.end()) {
auto iter2 = op_group_fuse_tag_.find(&second_op_group);
if (iter2 != op_group_fuse_tag_.end()) {
return iter1->second == iter2->second && *iter1->second == Tag::HorizontalFuse;
}
}
return false;
}

void InsertOpGroup(const std::shared_ptr<OpGroupInterface>& op_group) {
op_groups_.push_back(op_group);
op_groups_set_.insert(op_groups_.back().get());
}

void DeleteOpGroup(const std::shared_ptr<OpGroupInterface>& op_group) {
op_groups_set_.erase(op_group.get());
op_group_fuse_tag_.erase(op_group.get());
}

private:

enum class Tag {
Recompute,
VerticalFuse,
HorizontalFuse
};

Graph* graph_;

std::vector<std::shared_ptr<OpGroupInterface>> op_groups_;
int64_t current_group_index_ = -1;
std::unordered_set<const OpGroupInterface*> op_groups_set_;
std::unordered_map<const OpGroupInterface*, std::shared_ptr<Tag>> op_group_fuse_tag_;
};

} // namespace framework
} // namespace hlir
} // namespace cinn
Loading