Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 14 additions & 26 deletions src/CodeGen_ARM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,6 @@ class CodeGen_ARM : public CodeGen_Posix {
void compile_func(const LoweredFunc &f,
const std::string &simple_name, const std::string &extern_name) override;

void begin_func(LinkageType linkage, const std::string &simple_name,
const std::string &extern_name, const std::vector<LoweredArgument> &args) override;

/** Nodes for which we want to emit specific ARM vector intrinsics */
// @{
void visit(const Cast *) override;
Expand Down Expand Up @@ -1137,6 +1134,20 @@ void CodeGen_ARM::compile_func(const LoweredFunc &f,

LoweredFunc func = f;

// Make sure run-time vscale is equal to compile-time vscale.
// Avoiding the assert on inner functions is both an efficiency and a correctness issue
// as the assertion code may not compile in all contexts.
if (f.linkage != LinkageType::Internal) {
int effective_vscale = target_vscale();
if (effective_vscale != 0 && !target.has_feature(Target::NoAsserts)) {
Expr runtime_vscale = Call::make(Int(32), Call::get_runtime_vscale, {}, Call::PureIntrinsic);
Expr compiletime_vscale = Expr(effective_vscale);
Expr error = Call::make(Int(32), "halide_error_vscale_invalid",
{simple_name, runtime_vscale, compiletime_vscale}, Call::Extern);
func.body = Block::make(AssertStmt::make(runtime_vscale == compiletime_vscale, error), func.body);
}
}

if (target.os != Target::IOS && target.os != Target::OSX) {
// Substitute in strided loads to get vld2/3/4 emission. We don't do it
// on Apple silicon, because doing a dense load and then shuffling is
Expand All @@ -1150,29 +1161,6 @@ void CodeGen_ARM::compile_func(const LoweredFunc &f,
CodeGen_Posix::compile_func(func, simple_name, extern_name);
}

/** Begin code generation for a function via CodeGen_Posix, then, for
 * non-internal functions compiled with a non-zero target vscale and with
 * assertions enabled, add an assertion that the run-time vscale equals the
 * compile-time vscale. */
void CodeGen_ARM::begin_func(LinkageType linkage, const std::string &simple_name,
                             const std::string &extern_name, const std::vector<LoweredArgument> &args) {
    CodeGen_Posix::begin_func(linkage, simple_name, extern_name, args);

    // TODO(https://github.com/halide/Halide/issues/8092): There is likely a
    // better way to restrict this check to the outermost function being
    // compiled. Skipping inner functions matters for both efficiency and
    // correctness, because the assertion code may not compile in all contexts.
    if (linkage == LinkageType::Internal) {
        return;
    }

    // Nothing to check when no fixed vscale is in effect or asserts are off.
    const int expected_vscale = target_vscale();
    if (expected_vscale == 0 || target.has_feature(Target::NoAsserts)) {
        return;
    }

    // Compare the vscale observed at run time against the compile-time value.
    Expr vscale_at_runtime = Call::make(Int(32), Call::get_runtime_vscale, {}, Call::PureIntrinsic);
    Value *observed = codegen(vscale_at_runtime);
    Value *expected = ConstantInt::get(i32_t, expected_vscale);
    Value *vscale_matches = builder->CreateICmpEQ(observed, expected);

    // Error call used as the failure message of the assertion.
    Expr on_mismatch = Call::make(Int(32), "halide_error_vscale_invalid",
                                  {simple_name, vscale_at_runtime, Expr(expected_vscale)}, Call::Extern);
    create_assertion(vscale_matches, on_mismatch);
}

void CodeGen_ARM::visit(const Cast *op) {
if (!simd_intrinsics_disabled() && op->type.is_vector()) {
vector<Expr> matches;
Expand Down
5 changes: 3 additions & 2 deletions src/CodeGen_Internal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -719,8 +719,9 @@ void set_function_attributes_from_halide_target_options(llvm::Function &fn) {
// inaccurate even for us.
fn.addFnAttr("reciprocal-estimates", "none");

// If a fixed vscale is asserted, add it as an attribute on the function.
if (vscale_range != 0) {
// If a fixed vscale is asserted, add it as an attribute on the function
// except for those which already have vscale_range for some purpose
if (vscale_range != 0 && !fn.hasFnAttribute(llvm::Attribute::VScaleRange)) {
fn.addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
module.getContext(), vscale_range, vscale_range));
}
Expand Down
23 changes: 21 additions & 2 deletions src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3333,11 +3333,30 @@ void CodeGen_LLVM::visit(const Call *op) {
} else if (op->is_intrinsic(Call::concat_bits)) {
value = codegen(lower_concat_bits(op));
} else if (op->is_intrinsic(Call::get_runtime_vscale)) {
// This intrinsic must be emitted as its own, separately defined function.
// Otherwise the vscale_range(n, n) attribute is added to the calling function and
// the LLVM compiler optimizes away the runtime query, which makes the runtime
// assertion on vscale useless.
llvm::Function *fn = module->getFunction(op->name);
if (!fn) {
FunctionType *func_t = FunctionType::get(i32_t, {}, false);
fn = llvm::Function::Create(func_t, llvm::Function::InternalLinkage, op->name, module.get());
llvm::BasicBlock *block = llvm::BasicBlock::Create(module->getContext(), "entry", fn);
IRBuilderBase::InsertPoint here = builder->saveIP();
builder->SetInsertPoint(block);
#if LLVM_VERSION >= 210
value = builder->CreateVScale(i32_t);
Value *ret = builder->CreateVScale(i32_t);
#else
value = builder->CreateVScale(ConstantInt::get(i32_t, 1));
Value *ret = builder->CreateVScale(ConstantInt::get(i32_t, 1));
#endif
builder->CreateRet(ret);

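// Give this function its own vscale_range attribute so that CodeGen_Internal,
// which skips functions that already have one, does not add vscale_range(n, n).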
fn->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(*context, 1, 16));
fn->addFnAttr(llvm::Attribute::NoInline);
internal_assert(!verifyFunction(*fn, &llvm::errs()));
builder->restoreIP(here);
}
value = builder->CreateCall(fn, {});
} else if (op->is_intrinsic()) {
Expr lowered = lower_intrinsic(op);
if (!lowered.defined()) {
Expand Down
1 change: 1 addition & 0 deletions test/error/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ tests(GROUPS error
memoize_redefine_eviction_key.cpp
metal_threads_too_large.cpp
metal_vector_too_large.cpp
mismatch_runtime_vscale.cpp
missing_args.cpp
no_default_device.cpp
nonexistent_update_stage.cpp
Expand Down
27 changes: 27 additions & 0 deletions test/error/mismatch_runtime_vscale.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

int main(int argc, char **argv) {
auto target = get_host_target();
if (!target.features_any_of({Target::SVE, Target::SVE2})) {
printf("[SKIP] Scalable vector is not supported on this target.\n");
_halide_user_assert(0);
return 1;
}

Func f("f");
Var x("x");

f(x) = x;

const int wrong_vector_bits = target.vector_bits == 128 ? 256 : 128;
target.vector_bits = wrong_vector_bits;

// Compile with wrong vscale and run on host, which should end up with assertion failure.
Buffer<int> out = f.realize({100}, target);

printf("Success!\n");
return 0;
}
Loading