Skip to content

Commit

Permalink
Add BlockFrame (apache#34)
Browse files Browse the repository at this point in the history
* add BlockFrame

* upd

* add T::axis::Spatial/Reduce

* include dom in for-frame

* finish T.axis.remap
  • Loading branch information
junrushao committed Jun 25, 2022
1 parent 48bc108 commit 706ce82
Show file tree
Hide file tree
Showing 6 changed files with 283 additions and 48 deletions.
20 changes: 17 additions & 3 deletions src/script/builder/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class FrameNode : public runtime::Object {
TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, runtime::Object);

public:
virtual void EnterWithScope() {}

virtual void ExitWithScope() {}

virtual ~FrameNode() {
for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) {
(*it)();
Expand All @@ -48,6 +52,17 @@ class FrameNode : public runtime::Object {

class Frame : public runtime::ObjectRef {
public:
void EnterWithScope() {
ICHECK(data_ != nullptr);
static_cast<FrameNode*>(data_.get())->EnterWithScope();
}

void ExitWithScope() {
ICHECK(data_ != nullptr);
static_cast<FrameNode*>(data_.get())->ExitWithScope();
data_.reset();
}

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode);

protected:
Expand All @@ -67,15 +82,14 @@ class BuilderNode : public runtime::Object {

public:
template <typename TFrame>
TFrame FindFrame() const {
Optional<TFrame> FindFrame() const {
using TFrameNode = typename TFrame::ContainerType;
for (auto it = frames.rbegin(); it != frames.rend(); ++it) {
if (const TFrameNode* p = (*it).template as<TFrameNode>()) {
return GetRef<TFrame>(p);
}
}
LOG(FATAL) << "IndexError: Cannot find frame: " << TFrameNode::_type_key;
throw;
return NullOpt;
}
};

Expand Down
135 changes: 135 additions & 0 deletions src/script/builder/tir/block_frame.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "./block_frame.h"

#include "./for_frame.h"

namespace tvm {
namespace script {
namespace builder {
namespace tir {

BlockFrame::BlockFrame(String name) {
ObjectPtr<BlockFrameNode> n = make_object<BlockFrameNode>();
n->name = name;
n->iter_vars.clear();
n->reads = NullOpt;
n->writes = NullOpt;
n->init = NullOpt;
n->alloc_buffers.clear();
n->match_buffers.clear();
n->annotations.clear();
n->iter_values.clear();
n->predicate = NullOpt;
data_ = n;
}

namespace axis {

// TODO(@junrushao1994): figure out the Block syntax without BlockRealize

tvm::tir::IterVar PushBlockVar(tvm::tir::IterVar iter_var, PrimExpr binding) {
if (const BlockFrameNode* opt_frame = Builder::Current()->frames.back().as<BlockFrameNode>()) {
BlockFrame frame = GetRef<BlockFrame>(opt_frame);
frame->iter_vars.push_back(iter_var);
frame->iter_values.push_back(binding);
} else {
LOG(FATAL) << "TypeError: The last frame is not BlockFrame";
}
return iter_var;
}

tvm::tir::IterVar Spatial(Range dom, PrimExpr binding, DataType dtype) {
using namespace tvm::tir;
ICHECK(dom.defined()) << "Spatial axis must have a domain";
int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()});
return PushBlockVar(IterVar(/*dom=*/dom, //
/*var=*/Var("_", dtype.with_bits(bits)), //
/*iter_type=*/IterVarType::kDataPar, //
/*thread_tag=*/""),
binding);
}

tvm::tir::IterVar Reduce(Range dom, PrimExpr binding, DataType dtype) {
using namespace tvm::tir;
ICHECK(dom.defined()) << "Spatial axis must have a domain";
int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()});
return PushBlockVar(IterVar(/*dom=*/dom, //
/*var=*/Var("_", dtype.with_bits(bits)), //
/*iter_type=*/IterVarType::kCommReduce, //
/*thread_tag=*/""),
binding);
}

Array<tvm::tir::IterVar> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype) {
using namespace tvm::tir;
Array<tvm::tir::IterVar> results;
ICHECK_EQ(kinds.size(), bindings.size());
int n = bindings.size();
results.reserve(n);
for (int i = 0; i < n; ++i) {
char c = kinds.c_str()[i];
PrimExpr e = bindings[i];
const VarNode* v = e.as<VarNode>();
ICHECK(v) << "TypeError: Only Var is supported in T.axis.remap";
Range dom{nullptr};
for (const auto& frame : Builder::Current()->frames) {
if (const auto* for_frame = frame.as<ForFrameNode>()) {
ICHECK_EQ(for_frame->doms.size(), for_frame->vars.size());
int n = for_frame->doms.size();
for (int i = 0; i < n; ++i) {
if (for_frame->vars[i].get() == v) {
dom = for_frame->doms[i];
break;
}
}
if (dom.defined()) {
break;
}
}
}
ICHECK(dom.defined()) << "TypeError: Variable is not in the loop: " << GetRef<Var>(v);
DataType dtype = v->dtype;
if (c == 'S') {
results.push_back(PushBlockVar(IterVar(/*dom=*/dom,
/*var=*/Var("_", dtype),
/*iter_type=*/IterVarType::kDataPar,
/*thread_tag=*/""),
e));
} else if (c == 'R') {
results.push_back(PushBlockVar(IterVar(/*dom=*/dom,
/*var=*/Var("_", dtype),
/*iter_type=*/IterVarType::kCommReduce,
/*thread_tag=*/""),
e));
} else {
LOG(FATAL) << "Unknown axis kind: " << c;
}
}
return results;
}

} // namespace axis

TVM_REGISTER_NODE_TYPE(BlockFrameNode);

} // namespace tir
} // namespace builder
} // namespace script
} // namespace tvm
76 changes: 76 additions & 0 deletions src/script/builder/tir/block_frame.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SCRIPT_BUILDER_TIR_BLOCK_FRAME_H_
#define TVM_SCRIPT_BUILDER_TIR_BLOCK_FRAME_H_

#include "./tir.h"

namespace tvm {
namespace script {
namespace builder {
namespace tir {

class BlockFrameNode : public TIRFrameNode {
public:
String name;
Array<tvm::tir::IterVar> iter_vars;
Optional<Array<tvm::tir::BufferRegion>> reads;
Optional<Array<tvm::tir::BufferRegion>> writes;
Optional<tvm::tir::Stmt> init;
Array<tvm::tir::Buffer> alloc_buffers;
Array<tvm::tir::MatchBufferRegion> match_buffers;
Map<String, ObjectRef> annotations;

Array<PrimExpr> iter_values;
Optional<PrimExpr> predicate;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("iter_vars", &iter_vars);
v->Visit("reads", &reads);
v->Visit("writes", &writes);
v->Visit("init", &init);
v->Visit("alloc_buffers", &alloc_buffers);
v->Visit("match_buffers", &match_buffers);
v->Visit("annotations", &annotations);
v->Visit("iter_values", &iter_values);
v->Visit("predicate", &predicate);
}

static constexpr const char* _type_key = "script.builder.tir.BlockFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, TIRFrameNode);
};

class BlockFrame : public TIRFrame {
public:
explicit BlockFrame(String name);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode);
};

namespace axis {
tvm::tir::IterVar Spatial(Range dom, PrimExpr binding, DataType dtype);
tvm::tir::IterVar Reduce(Range dom, PrimExpr binding, DataType dtype);
Array<tvm::tir::IterVar> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype);
} // namespace axis
} // namespace tir
} // namespace builder
} // namespace script
} // namespace tvm

#endif // TVM_SCRIPT_BUILDER_TIR_BLOCK_FRAME_H_
69 changes: 39 additions & 30 deletions src/script/builder/tir/for_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,28 @@ namespace script {
namespace builder {
namespace tir {

ForFrame::ForFrame(Array<tvm::tir::Var> loop_vars, ForFrame::FMakeForLoop f_make_for_loop) {
ForFrame::ForFrame(Array<tvm::tir::Var> vars, Array<Range> doms,
ForFrameNode::FMakeForLoop f_make_for_loop) {
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
n->loop_vars = std::move(loop_vars);
n->vars = std::move(vars);
n->doms = std::move(doms);
n->f_make_for_loop = std::move(f_make_for_loop);
data_ = std::move(n);
}

#define TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Method, Kind) \
With<ForFrame> Method(PrimExpr min, PrimExpr extent, Map<String, ObjectRef> attrs) { \
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>(); \
int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \
n->loop_vars = {tvm::tir::Var("v", DataType::Int(bits))}; \
n->f_make_for_loop = [=](Array<tvm::tir::Var> vars, tvm::tir::Stmt body) -> tvm::tir::For { \
ICHECK_EQ(vars.size(), 1); \
return tvm::tir::For(/*loop_var=*/vars[0], min, extent, Kind, body, \
/*thread_binding=*/NullOpt, attrs); \
}; \
return With<ForFrame>(n); \
#define TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Method, Kind) \
ForFrame Method(PrimExpr min, PrimExpr extent, Map<String, ObjectRef> attrs) { \
using namespace tvm::tir; \
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>(); \
int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \
n->vars = {Var("v", DataType::Int(bits))}; \
n->doms = {Range(min, extent)}; \
n->f_make_for_loop = [attrs](Array<Var> vars, Array<Range> doms, Stmt body) { \
ICHECK_EQ(vars.size(), 1); \
ICHECK_EQ(doms.size(), 1); \
return For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, NullOpt, attrs); \
}; \
return ForFrame(n); \
}

TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Serial, tvm::tir::ForKind::kSerial);
Expand All @@ -50,39 +54,44 @@ TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Unroll, tvm::tir::ForKind::kUnrolled);

#undef TVM_SCRIPT_BUILDER_TIR_FOR_CREATE

With<ForFrame> ThreadBinding(PrimExpr min, PrimExpr extent, String thread,
Map<String, ObjectRef> attrs) {
ForFrame ThreadBinding(PrimExpr min, PrimExpr extent, String thread, Map<String, ObjectRef> attrs) {
using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
int bits = std::max(min.dtype().bits(), extent.dtype().bits());
n->loop_vars = {Var("v", DataType::Int(bits))};
n->f_make_for_loop = [=](Array<Var> vars, Stmt body) -> For {
n->vars = {Var("v", DataType::Int(bits))};
n->doms = {Range(min, extent)};
n->f_make_for_loop = [attrs, thread](Array<Var> vars, Array<Range> doms, Stmt body) -> For {
ICHECK_EQ(vars.size(), 1);
IterVar iter_var(Range(nullptr), Var(ObjectPtr<Object>(nullptr)), IterVarType::kThreadIndex,
thread);
return For(vars[0], min, extent, tvm::tir::ForKind::kThreadBinding, body, iter_var, attrs);
ICHECK_EQ(doms.size(), 1);
IterVar iter_var(Range(nullptr), NullValue<Var>(), IterVarType::kThreadIndex, thread);
return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var,
attrs);
};
return With<ForFrame>(n);
return ForFrame(n);
}

With<ForFrame> Grid(Array<PrimExpr> extents) {
ForFrame Grid(Array<PrimExpr> extents) {
using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
n->loop_vars.reserve(extents.size());
n->vars.reserve(extents.size());
n->doms.reserve(extents.size());
for (const auto& extent : extents) {
n->loop_vars.push_back(Var("v", extent.dtype()));
DataType dtype = extent.dtype();
n->vars.push_back(Var("v", extent.dtype()));
n->doms.push_back(Range(make_const(dtype, 0), extent));
}
n->f_make_for_loop = [=](Array<Var> vars, Stmt body) -> Stmt {
ICHECK_EQ(extents.size(), vars.size());
int n = extents.size();
n->f_make_for_loop = [](Array<Var> vars, Array<Range> doms, Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), doms.size());
int n = vars.size();
for (int i = n - 1; i >= 0; --i) {
Range dom = doms[i];
Var var = vars[i];
PrimExpr extent = extents[i];
body = For(var, Integer(0), extent, ForKind::kSerial, body, /*thread_binding=*/NullOpt, {});
body = For(var, dom->min, dom->extent, ForKind::kSerial, std::move(body),
/*thread_binding=*/NullOpt, /*annotations=*/{});
}
return body;
};
return With<ForFrame>(n);
return ForFrame(n);
}

TVM_REGISTER_NODE_TYPE(ForFrameNode);
Expand Down
Loading

0 comments on commit 706ce82

Please sign in to comment.