Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Pass] Meta-Schedule-Layout-Rewrite #11845

Merged
merged 1 commit into from
Jun 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,12 @@ TVM_DLL Pass AlterOpLayout();
*/
TVM_DLL Pass AutoSchedulerLayoutRewrite();

/*!
* \brief Do layout rewrite according to the tile structure created by meta-schedule.
* \return The pass
*/
TVM_DLL Pass MetaScheduleLayoutRewrite();

/*!
* \brief Given a dest layout, this pass transforms the expr such that most of the ops input data
* layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one
Expand Down
14 changes: 14 additions & 0 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,20 @@ class RelayBuildModule : public runtime::ModuleNode {
relay_module = transform::FuseOps()(relay_module);
}
}
if (backend::IsMetaScheduleEnabled() && config_->optional_homogeneous_target.defined()) {
Pass major_pass = transform::MetaScheduleLayoutRewrite();
bool enable_layout_rewrite_targets =
config_->optional_homogeneous_target->kind->device_type == kDLCPU ||
config_->optional_homogeneous_target->GetAttr<String>("device", "") == "mali";
if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) {
With<Target> tctx(config_->optional_homogeneous_target);
relay_module = major_pass(relay_module);
// Defuse ops to fold constants, then fuse them again
relay_module = transform::DefuseOps()(relay_module);
relay_module = transform::FoldConstant()(relay_module);
relay_module = transform::FuseOps()(relay_module);
}
}

relay_module = transform::InferType()(relay_module);

Expand Down
14 changes: 14 additions & 0 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,20 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
pass_seqs.push_back(transform::FuseOps());
}
}
if (backend::IsMetaScheduleEnabled() && config_->optional_homogeneous_target.defined()) {
Pass major_pass = transform::MetaScheduleLayoutRewrite();
bool enable_layout_rewrite_targets =
config_->optional_homogeneous_target->kind->device_type == kDLCPU ||
config_->optional_homogeneous_target->GetAttr<String>("device", "") == "mali";
if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) {
With<Target> tctx(config_->optional_homogeneous_target);
pass_seqs.push_back(major_pass);
// Defuse ops to fold constants, then fuse them again
pass_seqs.push_back(transform::DefuseOps());
pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FuseOps());
}
}

pass_seqs.push_back(transform::ToANormalForm());
pass_seqs.push_back(transform::InferType());
Expand Down
175 changes: 175 additions & 0 deletions src/relay/transforms/meta_schedule_layout_rewrite.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "./meta_schedule_layout_rewrite.h"

#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <deque>
#include <mutex>
#include <vector>

#include "../backend/te_compiler.h"

namespace tvm {
namespace relay {

class LayoutIndexQueue {
public:
static LayoutIndexQueue* Global() {
static LayoutIndexQueue inst;
return &inst;
}

void Clear() {
std::lock_guard<std::mutex> lock(mutex_);
queue_.clear();
}

private:
friend class MetaScheduleLayoutRewriter;
std::mutex mutex_;
std::deque<tir::IndexMap> queue_;
};

void MetaScheduleLayoutRewriter::LayoutQueuePush(const tir::IndexMap& index_map) {
LayoutIndexQueue* self = LayoutIndexQueue::Global();
{
std::lock_guard<std::mutex> lock(self->mutex_);
self->queue_.push_back(index_map);
}
}

bool IsSupportedOp(const OpNode* op) {
static std::vector<std::string> target_ops{
"nn.conv2d", //
"nn.contrib_conv2d_winograd_without_weight_transform",
"nn.conv3d",
"nn.matmul",
"nn.dense",
"nn.batch_matmul",
};
return std::find(target_ops.begin(), target_ops.end(), op->name) != target_ops.end();
}

#define TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(Attr, AttrType, OriginalShape, Result) \
if (const AttrType* ptr = Attr.as<AttrType>()) { \
ObjectPtr<AttrType> n = make_object<AttrType>(*ptr); \
n->meta_schedule_original_shape = OriginalShape; \
Result = Attrs(n); \
}

// Mutate ops in a function
class MetaScheduleFuncMutator : public ExprMutator {
public:
explicit MetaScheduleFuncMutator(std::deque<tir::IndexMap>&& layout_queue)
: layout_queue_(std::move(layout_queue)) {}

Expr VisitExpr_(const CallNode* call) {
Expr expr = ExprMutator::VisitExpr_(call);
if (layout_queue_.empty()) {
return expr;
}
if (const auto* call = expr.as<CallNode>()) {
if (const auto* op = call->op.as<OpNode>()) {
if (IsSupportedOp(op)) {
ICHECK_EQ(call->args.size(), 2);
tir::IndexMap index_map = layout_queue_.front();
layout_queue_.pop_front();
Var var = Downcast<Var>(call->args[1]);
Array<PrimExpr> shape = Downcast<TensorType>(var->type_annotation)->shape;
Attrs attrs{nullptr};
TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(call->attrs, Conv2DAttrs, shape, attrs);
TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(call->attrs, Conv2DWinogradAttrs, shape, attrs);
TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(call->attrs, Conv3DAttrs, shape, attrs);
TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(call->attrs, MatmulAttrs, shape, attrs);
TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(call->attrs, DenseAttrs, shape, attrs);
TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(call->attrs, BatchMatmulAttrs, shape, attrs);
ICHECK(attrs.defined()) << "TypeError: Unknown attribute: " << call->attrs;
expr = Call(call->op,
{call->args[0], MakeMetaScheduleLayoutTransform(call->args[1], index_map)},
attrs);
}
}
}
return expr;
}

private:
std::deque<tir::IndexMap> layout_queue_;
};

#undef TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE

Expr MetaScheduleLayoutRewriter::VisitExpr_(const CallNode* call) {
Expr expr = ExprMutator::VisitExpr_(call);
call = expr.as<CallNode>();
if (call != nullptr) {
if (const auto* func = call->op.as<FunctionNode>()) {
LayoutIndexQueue* self = LayoutIndexQueue::Global();
self->queue_.clear();
tec::PrimFuncFor(GetRef<Function>(func), Target::Current(),
[](std::string name) { return name; });
if (!self->queue_.empty()) {
std::deque<tir::IndexMap> queue = std::move(self->queue_);
self->queue_.clear();
return MetaScheduleFuncMutator(std::move(queue)).VisitExpr(expr);
}
}
}
return expr;
}

namespace transform {

Pass MetaScheduleLayoutRewrite() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) -> Function {
return Downcast<Function>(MetaScheduleLayoutRewriter().Mutate(std::move(f)));
};
return CreateFunctionPass(pass_func, 3, "MetaScheduleLayoutRewrite", {"InferType"});
}

#define TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(Attrs, AttrType) \
if (const auto* p = Attrs.as<AttrType>()) { \
return p->meta_schedule_original_shape; \
}

TVM_REGISTER_GLOBAL("relay.attrs.get_meta_schedule_original_shape")
.set_body_typed([](const Attrs& attrs) -> Array<PrimExpr> {
TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, Conv2DAttrs);
TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, Conv2DWinogradAttrs);
TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, Conv3DAttrs);
TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, MatmulAttrs);
TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, DenseAttrs);
TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, BatchMatmulAttrs);
LOG(FATAL) << "TypeError: Unknown attribute: " << attrs;
throw;
});
TVM_REGISTER_GLOBAL("relay._transform.MetaScheduleLayoutRewrite")
.set_body_typed(MetaScheduleLayoutRewrite);

#undef TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE

} // namespace transform
} // namespace relay
} // namespace tvm
38 changes: 38 additions & 0 deletions src/relay/transforms/meta_schedule_layout_rewrite.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_RELAY_TRANSFORMS_META_SCHEDULE_LAYOUT_REWRITE_H_
#define TVM_RELAY_TRANSFORMS_META_SCHEDULE_LAYOUT_REWRITE_H_

#include <tvm/relay/expr_functor.h>
#include <tvm/tir/index_map.h>

namespace tvm {
namespace relay {

class MetaScheduleLayoutRewriter : public ExprMutator {
public:
Expr VisitExpr_(const CallNode* n) final;

static void LayoutQueuePush(const tir::IndexMap& index_map);
};

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_TRANSFORMS_META_SCHEDULE_LAYOUT_REWRITE_H_