diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index 03b5496c244c..c70a1ab47990 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -426,6 +426,17 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
 #endif
       fcompute->addFnAttr(llvm::Attribute::NoInline);
     }
+    // Add alignment attribute if needed.
+#if TVM_LLVM_VERSION >= 50
+    auto f = alloc_storage_info_.find(var.get());
+    if (f != alloc_storage_info_.end()) {
+      unsigned align = f->second.alignment;
+      if (align > 1) {
+        auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align);
+        fcompute->addParamAttr(idx, attr);
+      }
+    }
+#endif
   }
   std::swap(function_, fcompute);
   std::swap(new_vmap, var_map_);
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index f664532b2dc1..b43e9889a4ae 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -156,6 +156,21 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
   builder_->SetInsertPoint(entry);
   this->VisitStmt(f->body);
 
+  // Add alignment attribute if needed.
+#if TVM_LLVM_VERSION >= 50
+  for (size_t i = 0; i < f->params.size(); ++i) {
+    const Var& var = f->params[i];
+    auto info_it = alloc_storage_info_.find(var.get());
+    if (info_it != alloc_storage_info_.end()) {
+      unsigned align = info_it->second.alignment;
+      if (align > 1) {
+        auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align);
+        function_->addParamAttr(i, attr);
+      }
+    }
+  }
+#endif
+
   if (ret_void) {
     builder_->CreateRetVoid();
   } else {
@@ -1250,6 +1265,10 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
     const VarNode* v = op->node.as<VarNode>();
     CHECK(v);
     alloc_storage_info_[v].alignment = static_cast<int>(op->value.as<IntImmNode>()->value);
+    if (var_map_.count(v) && alloc_storage_info_[v].alignment > 1) {
+      builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v),
+                                          alloc_storage_info_[v].alignment);
+    }
   } else if (op->attr_key == tir::attr::volatile_scope) {
     const VarNode* v = op->node.as<VarNode>();
     CHECK(v);
@@ -1264,14 +1283,19 @@ void CodeGenLLVM::VisitStmt_(const AssertStmtNode* op) {
 }
 
 void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) {
-  CHECK(!var_map_.count(op->var.get()));
-  if (op->var.dtype().is_handle()) {
+  const VarNode* v = op->var.get();
+  CHECK(!var_map_.count(v));
+  if (v->dtype.is_handle()) {
     if (!is_restricted_) {
-      alias_var_set_.insert(op->var.get());
+      alias_var_set_.insert(v);
     }
   }
-  var_map_[op->var.get()] = MakeValue(op->value);
+  var_map_[v] = MakeValue(op->value);
   analyzer_->Bind(op->var, op->value);
+  if (alloc_storage_info_.count(v) && alloc_storage_info_[v].alignment > 1) {
+    builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v),
+                                        alloc_storage_info_[v].alignment);
+  }
   this->VisitStmt(op->body);
 }
 
diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py
index c6591721d247..ff35de026127 100644
--- a/tests/python/unittest/test_target_codegen_llvm.py
+++ b/tests/python/unittest/test_target_codegen_llvm.py
@@ -21,6 +21,7 @@
 import numpy as np
 import ctypes
 import math
+import re
 
 
 def test_llvm_intrin():
@@ -462,12 +463,39 @@ def test_alignment():
     s = te.create_schedule(B.op)
     bx, tx = s[B].split(B.op.axis[0], factor=8)
     s[B].vectorize(tx)
-    f = tvm.build(s, [A, B], "llvm")
+    f = tvm.build(s, [A, B], "llvm", name="test_alignment")
 
-    for l in f.get_source().split("\n"):
+    lines = f.get_source().split("\n")
+
+    # Check alignment on load/store.
+    for l in lines:
         if "align" in l and "4 x float" in l:
             assert "align 32" in l
 
+    # Check parameter alignment. This looks for the definition of the
+    # outlined "compute_" function to see if there is an "align" attribute
+    # listed there.
+    def has_param_alignment():
+        for l in lines:
+            if re.search(r'test_alignment_compute_\([^(]*align [0-9]', l):
+                return True
+        return False
+
+    if tvm.target.codegen.llvm_version_major() >= 5:
+        assert has_param_alignment()
+
+    # Check for assume intrinsics. This isn't 100% accurate, since it just
+    # checks if the llvm.assume is there, but a detailed check would require
+    # a much more detailed analysis of the LLVM IR.
+    def has_call_to_assume():
+        for l in lines:
+            if re.search(r'call.*llvm\.assume', l):
+                return True
+        return False
+
+    assert has_call_to_assume()
+
+
 def test_llvm_div():
     """Check that the semantics of div and mod is correct"""
     def check(start, end, dstart, dend, dtype, floor_div=False):
@@ -625,7 +653,6 @@ def check_llvm_object():
         temp = util.tempdir()
         o_path = temp.relpath("temp.o")
         m.save(o_path)
-        import re
         import shutil
         import subprocess
         import sys