Skip to content

Commit

Permalink
[LLVM] Represent alignment information in LLVM IR (#5598)
Browse files Browse the repository at this point in the history
  • Loading branch information
Krzysztof Parzyszek authored May 15, 2020
1 parent e7e3c58 commit 8d988df
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 7 deletions.
11 changes: 11 additions & 0 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,17 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
#endif
fcompute->addFnAttr(llvm::Attribute::NoInline);
}
// Add alignment attribute if needed.
#if TVM_LLVM_VERSION >= 50
auto f = alloc_storage_info_.find(var.get());
if (f != alloc_storage_info_.end()) {
unsigned align = f->second.alignment;
if (align > 1) {
auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align);
fcompute->addParamAttr(idx, attr);
}
}
#endif
}
std::swap(function_, fcompute);
std::swap(new_vmap, var_map_);
Expand Down
32 changes: 28 additions & 4 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,21 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
builder_->SetInsertPoint(entry);
this->VisitStmt(f->body);

// Add alignment attribute if needed.
#if TVM_LLVM_VERSION >= 50
for (size_t i = 0; i < f->params.size(); ++i) {
const Var& var = f->params[i];
auto f = alloc_storage_info_.find(var.get());
if (f != alloc_storage_info_.end()) {
unsigned align = f->second.alignment;
if (align > 1) {
auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align);
function_->addParamAttr(i, attr);
}
}
}
#endif

if (ret_void) {
builder_->CreateRetVoid();
} else {
Expand Down Expand Up @@ -1250,6 +1265,10 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
const VarNode* v = op->node.as<VarNode>();
CHECK(v);
alloc_storage_info_[v].alignment = static_cast<int>(op->value.as<IntImmNode>()->value);
if (var_map_.count(v) && alloc_storage_info_[v].alignment > 1) {
builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v),
alloc_storage_info_[v].alignment);
}
} else if (op->attr_key == tir::attr::volatile_scope) {
const VarNode* v = op->node.as<VarNode>();
CHECK(v);
Expand All @@ -1264,14 +1283,19 @@ void CodeGenLLVM::VisitStmt_(const AssertStmtNode* op) {
}

void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) {
CHECK(!var_map_.count(op->var.get()));
if (op->var.dtype().is_handle()) {
const VarNode* v = op->var.get();
CHECK(!var_map_.count(v));
if (v->dtype.is_handle()) {
if (!is_restricted_) {
alias_var_set_.insert(op->var.get());
alias_var_set_.insert(v);
}
}
var_map_[op->var.get()] = MakeValue(op->value);
var_map_[v] = MakeValue(op->value);
analyzer_->Bind(op->var, op->value);
if (alloc_storage_info_.count(v) && alloc_storage_info_[v].alignment > 1) {
builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v),
alloc_storage_info_[v].alignment);
}
this->VisitStmt(op->body);
}

Expand Down
33 changes: 30 additions & 3 deletions tests/python/unittest/test_target_codegen_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy as np
import ctypes
import math
import re


def test_llvm_intrin():
Expand Down Expand Up @@ -462,12 +463,39 @@ def test_alignment():
s = te.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=8)
s[B].vectorize(tx)
f = tvm.build(s, [A, B], "llvm")
f = tvm.build(s, [A, B], "llvm", name="test_alignment")

for l in f.get_source().split("\n"):
lines = f.get_source().split("\n")

# Check alignment on load/store.
for l in lines:
if "align" in l and "4 x float" in l:
assert "align 32" in l

# Check parameter alignment. This looks for the definition of the
# outlined "compute_" function to see if there is an "align" attribute
# listed there.
def has_param_alignment():
for l in lines:
if re.search(r'test_alignment_compute_\([^(]*align [0-9]', l):
return True
return False

if tvm.target.codegen.llvm_version_major() >= 5:
assert has_param_alignment()

# Check for assume intrinsics. This isn't 100% accurate, since it just
# checks if the llvm.assume is there, but detailed check would require
# a much more detailed analysis of the LLVM IR.
def has_call_to_assume():
for l in lines:
if re.search(r'call.*llvm.assume', l):
return True
return False

assert has_call_to_assume()


def test_llvm_div():
"""Check that the semantics of div and mod is correct"""
def check(start, end, dstart, dend, dtype, floor_div=False):
Expand Down Expand Up @@ -625,7 +653,6 @@ def check_llvm_object():
temp = util.tempdir()
o_path = temp.relpath("temp.o")
m.save(o_path)
import re
import shutil
import subprocess
import sys
Expand Down

0 comments on commit 8d988df

Please sign in to comment.