Commit

merge storage_alloc refactor branch
commit 0ca1924
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jun 29 18:21:41 2021 +0900

    remove alloc map usage in cuda codegen

commit fd07b35
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jun 29 17:51:33 2021 +0900

    remove storage_scope map from storage_access.cc

commit 0ba5c71
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jun 29 17:36:07 2021 +0900

    remove realize_scope

commit f392e47
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jun 29 17:31:13 2021 +0900

    fix passing storage scope

commit bf132cf
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jun 29 13:11:55 2021 +0900

    make global storage scope by default

commit 6630cf3
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jun 29 13:10:37 2021 +0900

    remove realize_scope from schedule_ops

commit 8688205
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jun 29 13:10:18 2021 +0900

    remove storage_scope attr from storage_rewrite

commit b382bb3
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jun 29 12:50:31 2021 +0900

    remove attr::realize_scope from storage_flatten

commit e20e195
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jun 29 12:46:13 2021 +0900

    begin removing storage_scope attr

commit d39a470
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jun 29 12:23:46 2021 +0900

    thread storage scope through pipeline to buffer creation

commit 496a215
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jun 29 11:46:36 2021 +0900

    adding storage_scope to ProducerRealize

commit c586834
Author: Swift.Sun <sunjiwei@yeah.net>
Date:   Mon Jun 28 12:32:19 2021 +0800

    [AutoScheduler]Simplify the code (apache#8351)

commit 4ff5cef
Author: Rafael Stahl <r.stahl@tum.de>
Date:   Sun Jun 27 19:18:34 2021 +0200

    ffi: add missing binding for FixedPointMultiplyAttrs (apache#8353)

commit b71b837
Author: Matthew Brookhart <mbrookhart@octoml.ai>
Date:   Sun Jun 27 02:00:32 2021 -0600

    Remove an extra print from the relay astext tests (apache#8342)
masa committed Jun 29, 2021
1 parent aa4b90d commit ca502ee
Showing 28 changed files with 148 additions and 197 deletions.
14 changes: 7 additions & 7 deletions include/tvm/te/operation.h
@@ -128,8 +128,8 @@ class TVM_DLL OperationNode : public Object {
* \return A realization statement that wraps body.
*/
virtual Stmt BuildRealize(const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const = 0;
const std::unordered_map<IterVar, Range>& realize_map, const Stmt& body,
String storage_scope = "") const = 0;
/*!
* \brief Build the statement that provide the output tensors.
* \param stage The schedule stage of the op.
@@ -168,7 +168,7 @@ class PlaceholderOpNode : public OperationNode {
void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
const Stmt& body, String storage_scope = "") const final;
Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;

@@ -212,7 +212,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode {
void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
const Stmt& body, String storage_scope = "") const final;
virtual size_t num_schedulable_dims() const = 0;

static constexpr const char* _type_key = "BaseComputeOp";
@@ -370,7 +370,7 @@ class ScanOpNode : public OperationNode {
void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
const Stmt& body, String storage_scope = "") const final;
Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;

@@ -433,7 +433,7 @@ class ExternOpNode : public OperationNode {
void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
const Stmt& body, String storage_scope = "") const final;
Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;

@@ -498,7 +498,7 @@ class HybridOpNode : public OperationNode {
void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
const Stmt& body, String storage_scope = "") const final;
Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;

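Every BuildRealize override now receives the storage scope of the stage it realizes. A condensed sketch of how an implementation might thread the new parameter into the realize statement it builds (illustrative only, simplified from the compute-op case rather than copied from this diff):

Stmt BaseComputeOpNode::BuildRealize(const Stage& stage,
                                     const std::unordered_map<IterVar, Range>& realize_map,
                                     const Stmt& body, String storage_scope) const {
  // Realize each output tensor over its inferred bounds and record the
  // storage scope directly on the ProducerRealize node, instead of wrapping
  // it in a separate attr::realize_scope AttrStmt.
  Region bounds;
  for (IterVar iv : this->axis) bounds.push_back(realize_map.at(iv));
  Stmt realize = body;
  for (int i = this->num_outputs(); i > 0; --i) {
    realize = tir::ProducerRealize(stage->op.output(i - 1), bounds, const_true(), realize,
                                   storage_scope);
  }
  return realize;
}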
4 changes: 3 additions & 1 deletion include/tvm/tir/buffer.h
@@ -195,7 +195,9 @@ class Buffer : public ObjectRef {
* \sa Buffer for complete constructor.
*/
TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
String name = "buffer", Span span = Span());
String name = "buffer", String storage_scope = "", Span span = Span());

TVM_DLL String GetStorageScope(Var buffer_var);

/*!
* \brief Base node for data producers.
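A hedged usage sketch of the updated decl_buffer and the new GetStorageScope helper (the buffer name, shape, and scope below are examples, not values from this diff):

#include <tvm/tir/buffer.h>

using namespace tvm;
using namespace tvm::tir;

void Example() {
  // Declare a 1024-element float32 buffer that lives in shared memory.
  Buffer buf = decl_buffer({1024}, DataType::Float(32), "A_shared", "shared");
  // The scope travels with the buffer's data variable, so later passes can
  // recover it without a separate attr::storage_scope statement.
  String scope = GetStorageScope(buf->data);  // expected to yield "shared"
}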
9 changes: 7 additions & 2 deletions include/tvm/tir/stmt.h
@@ -465,24 +465,29 @@ class ProducerRealizeNode : public StmtNode {
/*! \brief The body of realization. */
Stmt body;

String storage_scope;

void VisitAttrs(AttrVisitor* v) {
v->Visit("producer", &producer);
v->Visit("bounds", &bounds);
v->Visit("condition", &condition);
v->Visit("body", &body);
v->Visit("storage_scope", &storage_scope);
v->Visit("span", &span);
}

bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const {
return equal(producer, other->producer) && equal(bounds, other->bounds) &&
equal(condition, other->condition) && equal(body, other->body);
equal(condition, other->condition) && equal(body, other->body) &&
equal(storage_scope, other->storage_scope);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(producer);
hash_reduce(bounds);
hash_reduce(condition);
hash_reduce(body);
hash_reduce(storage_scope);
}

static constexpr const char* _type_key = "tir.ProducerRealize";
@@ -496,7 +501,7 @@ class ProducerRealizeNode : public StmtNode {
class ProducerRealize : public Stmt {
public:
TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body,
Span span = Span());
String storage_scope = "", Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode);
};
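As an illustration of the widened constructor (the producer, extents, and scope here are assumed, not taken from the diff):

// Wrap `body` in a realization of `producer` over a 2-D region placed in
// shared memory. Previously the scope had to be attached separately through
// an AttrStmt with key attr::realize_scope.
PrimExpr n = 128, m = 64;
Stmt body = Evaluate(0);
Stmt realize = ProducerRealize(producer, {Range(0, n), Range(0, m)}, const_true(), body, "shared");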
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -577,3 +577,8 @@ class UniformAttrs(Attrs):
@tvm._ffi.register_object("relay.attrs.NLLLossAttrs")
class NLLLossAttrs(Attrs):
"""Attributes for nn.nll_loss"""


@tvm._ffi.register_object("relay.attrs.FixedPointMultiplyAttrs")
class FixedPointMultiplyAttrs(Attrs):
"""Attributes used in fixed_point_multiply operators"""
2 changes: 1 addition & 1 deletion python/tvm/tir/ir_builder.py
@@ -416,7 +416,7 @@ def allocate(self, dtype, shape, name="buf", scope=None):
buffer : BufferVar
The buffer var representing the buffer.
"""
buffer_var = _expr.Var(name, PointerType(PrimType(dtype)))
buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope))
if not isinstance(shape, (list, tuple, _container.Array)):
shape = [shape]
if scope:
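The same convention is used throughout the C++ side of this change: the scope lives on the buffer variable's pointer type rather than in a side map or an AttrStmt. A minimal sketch (variable name and scope assumed):

// A float32 buffer variable annotated as shared memory. Any later pass can
// read the scope back from the type annotation instead of looking up an
// attr::storage_scope attribute.
tir::Var buffer_var("buf", PointerType(PrimType(DataType::Float(32)), "shared"));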
101 changes: 42 additions & 59 deletions src/auto_scheduler/search_policy/utils.cc
@@ -153,24 +153,21 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo
if (spatial_split_step_ids == nullptr) {
spatial_split_step_ids = &temp_split_step_ids;
}
spatial_split_step_ids->clear();

std::vector<std::vector<Iterator>> space_levels;
std::vector<std::vector<Iterator>> reduce_levels;
std::vector<Iterator> space_outer, space_inner, reduce_outer, reduce_inner;
Array<Iterator> split_res;

for (const auto c : format) {
if (tolower(c) == 's') {
space_levels.emplace_back();
} else if (tolower(c) == 'r') {
reduce_levels.emplace_back();
} else {
LOG(FATAL) << "Invalid multi-level tiling format: " << format;
}
size_t n_space =
std::count(format.begin(), format.end(), 's') + std::count(format.begin(), format.end(), 'S');
size_t n_reduce =
std::count(format.begin(), format.end(), 'r') + std::count(format.begin(), format.end(), 'R');
if (n_space + n_reduce != format.size()) {
LOG(FATAL) << "Invalid multi-level tiling format: " << format;
}
size_t n_space = space_levels.size();
size_t n_reduce = reduce_levels.size();

spatial_split_step_ids->clear();
space_levels.resize(n_space);
reduce_levels.resize(n_reduce);

State tmp_s = state;
const Stage& stage = state->stages[stage_id];
@@ -179,31 +176,28 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo
? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
: std::set<std::string>();

auto sr_levels = [&](int size, const Iterator& iter, std::vector<std::vector<Iterator>>& levels) {
ICHECK_GE(size, 1);
if (size == 1) {
levels[0].push_back(iter);
} else {
Array<Iterator> split_res =
tmp_s.split(stage_id, iter, Array<Optional<Integer>>(size - 1, NullOpt));
for (int i = 0; i < size; i++) {
levels[i].push_back(split_res[i]);
}
if (iter->iter_kind == IteratorKind::kSpatial) {
spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1);
}
}
};

for (const auto& iter : state->stages[stage_id]->iters) {
if (!no_split_at_inner_name_set.count(iter->name)) {
if (iter->iter_kind == IteratorKind::kSpatial) {
ICHECK_GE(n_space, 1);

if (n_space == 1) {
space_levels[0].push_back(iter);
} else {
split_res = tmp_s.split(stage_id, iter, Array<Optional<Integer>>(n_space - 1, NullOpt));
for (size_t i = 0; i < n_space; i++) {
space_levels[i].push_back(split_res[i]);
}
spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1);
}
sr_levels(n_space, iter, space_levels);
} else if (iter->iter_kind == IteratorKind::kReduction) {
ICHECK_GE(n_reduce, 1);

if (n_reduce == 1) {
reduce_levels[0].push_back(iter);
} else {
split_res = tmp_s.split(stage_id, iter, Array<Optional<Integer>>(n_reduce - 1, NullOpt));
for (size_t i = 0; i < n_reduce; i++) {
reduce_levels[i].push_back(split_res[i]);
}
}
sr_levels(n_reduce, iter, reduce_levels);
} else {
LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
}
@@ -218,40 +212,29 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo
}
}

if (!space_outer.empty()) {
ICHECK(!space_levels.empty());
space_levels.front().insert(space_levels.front().begin(),
std::make_move_iterator(space_outer.begin()),
std::make_move_iterator(space_outer.end()));
}
if (!space_inner.empty()) {
ICHECK(!space_levels.empty());
space_levels.back().insert(space_levels.back().begin(),
std::make_move_iterator(space_inner.begin()),
std::make_move_iterator(space_inner.end()));
}

if (!reduce_outer.empty()) {
ICHECK(!reduce_levels.empty());
reduce_levels.front().insert(reduce_levels.front().begin(),
std::make_move_iterator(reduce_outer.begin()),
std::make_move_iterator(reduce_outer.end()));
auto fill_levels = [&](std::vector<Iterator>& levels_iter, std::vector<Iterator>& fill) {
if (!fill.empty()) {
levels_iter.insert(levels_iter.begin(), std::make_move_iterator(fill.begin()),
std::make_move_iterator(fill.end()));
}
};
if (!space_levels.empty()) {
fill_levels(space_levels.front(), space_outer);
fill_levels(space_levels.back(), space_inner);
}
if (!reduce_inner.empty()) {
ICHECK(!reduce_levels.empty());
reduce_levels.back().insert(reduce_levels.back().begin(),
std::make_move_iterator(reduce_inner.begin()),
std::make_move_iterator(reduce_inner.end()));
if (!reduce_levels.empty()) {
fill_levels(reduce_levels.front(), reduce_outer);
fill_levels(reduce_levels.back(), reduce_inner);
}

Array<Iterator> order;
int space_ct = 0, reduce_ct = 0;
for (const auto c : format) {
if (tolower(c) == 's') {
if (c == 's' || c == 'S') {
order.insert(order.end(), std::make_move_iterator(space_levels[space_ct].begin()),
std::make_move_iterator(space_levels[space_ct].end()));
space_ct++;
} else if (tolower(c) == 'r') {
} else if (c == 'r' || c == 'R') {
order.insert(order.end(), std::make_move_iterator(reduce_levels[reduce_ct].begin()),
std::make_move_iterator(reduce_levels[reduce_ct].end()));
reduce_ct++;
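For intuition, the tiling format string simply spells out the loop levels, and the refactor now counts them up front instead of growing the level vectors one character at a time. A toy illustration with an assumed format string:

// With format = "SSSRRSRS" there are 5 spatial levels and 3 reduction levels,
// so each spatial iterator is split into 5 parts and each reduction iterator
// into 3, matching the counting logic above.
std::string format = "SSSRRSRS";
size_t n_space = std::count(format.begin(), format.end(), 's') +
                 std::count(format.begin(), format.end(), 'S');   // 5
size_t n_reduce = std::count(format.begin(), format.end(), 'r') +
                  std::count(format.begin(), format.end(), 'R');  // 3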
4 changes: 3 additions & 1 deletion src/runtime/thread_storage_scope.h
@@ -122,7 +122,9 @@ struct StorageScope {
*/
static StorageScope Create(const std::string& s) {
StorageScope r;
if (s.compare(0, 6, "global") == 0) {
if (s == "") {
r.rank = StorageRank::kGlobal;
} else if (s.compare(0, 6, "global") == 0) {
r.rank = StorageRank::kGlobal;
r.tag = s.substr(6, std::string::npos);
} else if (s.compare(0, 6, "shared") == 0) {
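With this change an empty scope string is treated as global memory, matching the global-by-default behavior introduced in the allocation passes. A small hedged check of the mapping:

using tvm::runtime::StorageRank;
using tvm::runtime::StorageScope;

// "" and "global" now map to the same rank; tagged scopes keep their suffix.
ICHECK(StorageScope::Create("").rank == StorageRank::kGlobal);
ICHECK(StorageScope::Create("global").rank == StorageRank::kGlobal);
ICHECK(StorageScope::Create("shared").rank == StorageRank::kShared);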
7 changes: 4 additions & 3 deletions src/target/llvm/codegen_amdgpu.cc
@@ -74,7 +74,8 @@ class CodeGenAMDGPU : public CodeGenLLVM {
llvm::Value* buf = nullptr;
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];

if (info.scope.rank == runtime::StorageRank::kDynShared) {
auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var));
if (storage_scope.rank == runtime::StorageRank::kDynShared) {
buf = AllocateSharedMemory(op->dtype, 0, 3, std::min(info.alignment, 16),
llvm::GlobalValue::ExternalLinkage);
} else {
@@ -88,7 +89,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
if (info.alignment > 16) {
info.alignment = 16;
}
if (info.scope.rank == runtime::StorageRank::kLocal) {
if (storage_scope.rank == runtime::StorageRank::kLocal) {
// const int local_address_space = 5;
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
@@ -103,7 +104,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
}
buf = alloca;
} else {
ICHECK(info.scope.rank == runtime::StorageRank::kShared)
ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment,
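Both LLVM codegens now ask the buffer variable for its scope instead of consulting a map filled in from AttrStmt nodes. A plausible sketch of what GetStorageScope does (its actual definition lives elsewhere in this change set and may differ):

String GetStorageScope(Var buffer_var) {
  // The storage scope is carried on the pointer type annotation of the
  // buffer variable, so no side table is needed.
  const auto* ptr_type = buffer_var->type_annotation.as<PointerTypeNode>();
  ICHECK(ptr_type) << "The provided variable is not of pointer type";
  return ptr_type->storage_scope;
}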
8 changes: 2 additions & 6 deletions src/target/llvm/codegen_llvm.cc
@@ -501,7 +501,8 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExp
auto it = alloc_storage_info_.find(buf_var);
if (it != alloc_storage_info_.end()) {
const StorageInfo& info = it->second;
*p_native_bits = NativeVectorBits(info.scope);
*p_native_bits =
NativeVectorBits(runtime::StorageScope::Create(GetStorageScope(GetRef<Var>(buf_var))));
max_align_bits = info.alignment * 8;
} else {
*p_native_bits = native_vector_bits_;
@@ -1407,11 +1408,6 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value));
}
}
} else if (op->attr_key == tir::attr::storage_scope) {
const VarNode* v = op->node.as<VarNode>();
ICHECK(v);
alloc_storage_info_[v].scope =
runtime::StorageScope::Create(op->value.as<StringImmNode>()->value);
} else if (op->attr_key == tir::attr::storage_alignment) {
const VarNode* v = op->node.as<VarNode>();
ICHECK(v);
2 changes: 0 additions & 2 deletions src/target/llvm/codegen_llvm.h
@@ -163,8 +163,6 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
protected:
/*! \brief The storage information */
struct StorageInfo {
/*! \brief The storage scope */
runtime::StorageScope scope;
/*! \brief The alignment of allocation */
int alignment{0};
};
8 changes: 4 additions & 4 deletions src/target/llvm/codegen_nvptx.cc
@@ -53,8 +53,8 @@ class CodeGenNVPTX : public CodeGenLLVM {
if (info.alignment > 16) {
info.alignment = 16;
}

if (info.scope.rank == runtime::StorageRank::kDynShared) {
auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var));
if (storage_scope.rank == runtime::StorageRank::kDynShared) {
buf =
AllocateSharedMemory(op->dtype, 0, 3, info.alignment, llvm::GlobalValue::ExternalLinkage);
} else {
@@ -64,7 +64,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
if (constant_size % 4 == 0 && info.alignment == 0) {
info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
}
if (info.scope.rank == runtime::StorageRank::kLocal) {
if (storage_scope.rank == runtime::StorageRank::kLocal) {
// const int local_address_space = 5;
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
@@ -79,7 +79,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
}
buf = alloca;
} else {
ICHECK(info.scope.rank == runtime::StorageRank::kShared)
ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
<< "Can only allocate shared or local memory inside kernel";
buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment,
llvm::GlobalValue::PrivateLinkage);
5 changes: 1 addition & 4 deletions src/target/source/codegen_cuda.cc
@@ -705,12 +705,9 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
std::string vid = AllocVarID(op->buffer_var.get());

this->PrintIndent();
std::string scope = GetStorageScope(op->buffer_var);
const VarNode* buffer = op->buffer_var.as<VarNode>();
auto it = alloc_storage_scope_.find(buffer);
ICHECK(it != alloc_storage_scope_.end())
<< "Buffer " << op->buffer_var << " is missing an AttrStmt with a \"storage_scope\" key";

std::string scope = it->second;
if (scope.find("wmma.") == 0) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||