Skip to content

Commit

Permalink
Add ClampUnsafeAccesses pass. (#6294)
Browse files Browse the repository at this point in the history
* Add ClampUnsafeAccesses pass. Fixes #6131

Inject clamps around func calls h(...) when all the following conditions hold:
  1. The call flows into an indexing context, such as: `f(x) = g(h(x))` or `let y = h(x) in f(x) = g(y)`
  2. The FuncValueBounds of h are smaller than those of its type
  3. h's allocation bounds might be wider than its compute bounds

Condition (3) is not yet implemented see #6297.
  • Loading branch information
alexreinking authored Oct 8, 2021
1 parent c6529ed commit 2bfa567
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ SOURCE_FILES = \
Buffer.cpp \
CanonicalizeGPUVars.cpp \
Closure.cpp \
ClampUnsafeAccesses.cpp \
CodeGen_ARM.cpp \
CodeGen_C.cpp \
CodeGen_D3D12Compute_Dev.cpp \
Expand Down Expand Up @@ -590,6 +591,7 @@ HEADER_FILES = \
BoundSmallAllocations.h \
Buffer.h \
CanonicalizeGPUVars.h \
ClampUnsafeAccesses.h \
Closure.h \
CodeGen_C.h \
CodeGen_D3D12Compute_Dev.h \
Expand Down
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ set(HEADER_FILES
BoundSmallAllocations.h
Buffer.h
CanonicalizeGPUVars.h
ClampUnsafeAccesses.h
Closure.h
CodeGen_C.h
CodeGen_D3D12Compute_Dev.h
Expand Down Expand Up @@ -181,6 +182,7 @@ set(SOURCE_FILES
BoundSmallAllocations.cpp
Buffer.cpp
CanonicalizeGPUVars.cpp
ClampUnsafeAccesses.cpp
Closure.cpp
CodeGen_ARM.cpp
CodeGen_C.cpp
Expand Down
88 changes: 88 additions & 0 deletions src/ClampUnsafeAccesses.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#include "ClampUnsafeAccesses.h"
#include "IREquality.h"
#include "IRMutator.h"
#include "IRPrinter.h"
#include "Simplify.h"

namespace Halide::Internal {

namespace {

struct ClampUnsafeAccesses : IRMutator {
const std::map<std::string, Function> &env;
FuncValueBounds &func_bounds;

ClampUnsafeAccesses(const std::map<std::string, Function> &env, FuncValueBounds &func_bounds)
: env(env), func_bounds(func_bounds) {
}

protected:
using IRMutator::visit;

Expr visit(const Let *let) override {
return visit_let<Let, Expr>(let);
}

Stmt visit(const LetStmt *let) override {
return visit_let<LetStmt, Stmt>(let);
}

Expr visit(const Variable *var) override {
if (is_inside_indexing && let_var_inside_indexing.contains(var->name)) {
let_var_inside_indexing.ref(var->name) = true;
}
return var;
}

Expr visit(const Call *call) override {
if (call->call_type != Call::Halide) {
return IRMutator::visit(call);
}

if (is_inside_indexing) {
auto bounds = func_bounds.at({call->name, call->value_index});
if (bounds_smaller_than_type(bounds, call->type)) {
// TODO(#6297): check that the clamped function's allocation bounds might be wider than its compute bounds

auto [new_args, changed] = mutate_with_changes(call->args);
Expr new_call = changed ? call : Call::make(call->type, call->name, new_args, call->call_type, call->func, call->value_index, call->image, call->param);
return Max::make(Min::make(new_call, std::move(bounds.max)), std::move(bounds.min));
}
}

ScopedValue s(is_inside_indexing, true);
return IRMutator::visit(call);
}

private:
template<typename L, typename Body>
Body visit_let(const L *let) {
ScopedBinding<bool> binding(let_var_inside_indexing, let->name, false);
Body body = mutate(let->body);

ScopedValue s(is_inside_indexing, is_inside_indexing || let_var_inside_indexing.get(let->name));
Expr value = mutate(let->value);

return L::make(let->name, std::move(value), std::move(body));
}

bool bounds_smaller_than_type(const Interval &bounds, Type type) {
return bounds.is_bounded() && !(equal(bounds.min, type.min()) && equal(bounds.max, type.max()));
}

/**
* A let-var is marked "true" if is used somewhere in an indexing expression.
* visit_let will process its value binding with is_inside_indexing set when
* this is the case.
*/
Scope<bool> let_var_inside_indexing;
bool is_inside_indexing = false;
};

} // namespace

Stmt clamp_unsafe_accesses(const Stmt &s, const std::map<std::string, Function> &env, FuncValueBounds &func_bounds) {
return ClampUnsafeAccesses(env, func_bounds).mutate(s);
}

} // namespace Halide::Internal
23 changes: 23 additions & 0 deletions src/ClampUnsafeAccesses.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef HALIDE_CLAMPUNSAFEACCESSES_H
#define HALIDE_CLAMPUNSAFEACCESSES_H

/** \file
* Defines the clamp_unsafe_accesses lowering pass.
*/

#include "Bounds.h"
#include "Expr.h"
#include "Function.h"

namespace Halide::Internal {

/** Inject clamps around func calls h(...) when all the following conditions hold:
* 1. The call is in an indexing context, such as: f(x) = g(h(x));
* 2. The FuncValueBounds of h are smaller than those of its type
* 3. The allocation bounds of h might be wider than its compute bounds.
*/
Stmt clamp_unsafe_accesses(const Stmt &s, const std::map<std::string, Function> &env, FuncValueBounds &func_bounds);

} // namespace Halide::Internal

#endif // HALIDE_CLAMPUNSAFEACCESSES_H
7 changes: 7 additions & 0 deletions src/Lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "BoundsInference.h"
#include "CSE.h"
#include "CanonicalizeGPUVars.h"
#include "ClampUnsafeAccesses.h"
#include "CompilerLogger.h"
#include "Debug.h"
#include "DebugArguments.h"
Expand Down Expand Up @@ -162,6 +163,12 @@ void lower_impl(const vector<Function> &output_funcs,
debug(1) << "Computing bounds of each function's value\n";
FuncValueBounds func_bounds = compute_function_value_bounds(order, env);

// Clamp unsafe instances where a Func f accesses a Func g using
// an index which depends on a third Func h.
debug(1) << "Clamping unsafe data-dependent accesses\n";
s = clamp_unsafe_accesses(s, env, func_bounds);
log("Lowering after clamping unsafe data-dependent accesses", s);

// This pass injects nested definitions of variable names, so we
// can't simplify statements from here until we fix them up. (We
// can still simplify Exprs).
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ tests(GROUPS correctness
implicit_args.cpp
implicit_args_tests.cpp
in_place.cpp
indexing_access_undef.cpp
infer_arguments.cpp
inline_reduction.cpp
inlined_generator.cpp
Expand Down
41 changes: 41 additions & 0 deletions test/correctness/indexing_access_undef.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#include "Halide.h"

using namespace Halide;

// https://github.com/halide/Halide/issues/6131
// Prior to the ClampUnsafeAccesses pass, this test case would
// crash as described in the comments below.

int main(int argc, char **argv) {
Var x{"x"};

Func f{"f"}, g{"g"}, h{"h"}, out{"out"};

const int min = -10000000;
const int max = min + 20;

h(x) = clamp(x, min, max);
// Within its compute bounds, h's value will be within
// [min,max]. Outside that, it's uninitialized memory.

g(x) = sin(x);
// Halide thinks g will be accessed within [min,max], so its
// allocation bounds will be [min,max]

f(x) = g(h(x));
f.vectorize(x, 64, TailStrategy::RoundUp);
// f will access h at values outside its compute bounds, and get
// garbage, and then use that garbage to access g outside its
// allocation bounds.

out(x) = f(x);

h.compute_root();
g.compute_root();
f.compute_root();

out.realize({1});

printf("Success!\n");
return 0;
}

0 comments on commit 2bfa567

Please sign in to comment.