-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ClampUnsafeAccesses pass. (#6294)
* Add ClampUnsafeAccesses pass. Fixes #6131 Inject clamps around func calls h(...) when all the following conditions hold: 1. The call flows into an indexing context, such as: `f(x) = g(h(x))` or `let y = h(x) in f(x) = g(y)` 2. The FuncValueBounds of h are smaller than those of its type 3. h's allocation bounds might be wider than its compute bounds Condition (3) is not yet implemented see #6297.
- Loading branch information
1 parent
c6529ed
commit 2bfa567
Showing
7 changed files
with
164 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
#include "ClampUnsafeAccesses.h" | ||
#include "IREquality.h" | ||
#include "IRMutator.h" | ||
#include "IRPrinter.h" | ||
#include "Simplify.h" | ||
|
||
namespace Halide::Internal { | ||
|
||
namespace { | ||
|
||
struct ClampUnsafeAccesses : IRMutator { | ||
const std::map<std::string, Function> &env; | ||
FuncValueBounds &func_bounds; | ||
|
||
ClampUnsafeAccesses(const std::map<std::string, Function> &env, FuncValueBounds &func_bounds) | ||
: env(env), func_bounds(func_bounds) { | ||
} | ||
|
||
protected: | ||
using IRMutator::visit; | ||
|
||
Expr visit(const Let *let) override { | ||
return visit_let<Let, Expr>(let); | ||
} | ||
|
||
Stmt visit(const LetStmt *let) override { | ||
return visit_let<LetStmt, Stmt>(let); | ||
} | ||
|
||
Expr visit(const Variable *var) override { | ||
if (is_inside_indexing && let_var_inside_indexing.contains(var->name)) { | ||
let_var_inside_indexing.ref(var->name) = true; | ||
} | ||
return var; | ||
} | ||
|
||
Expr visit(const Call *call) override { | ||
if (call->call_type != Call::Halide) { | ||
return IRMutator::visit(call); | ||
} | ||
|
||
if (is_inside_indexing) { | ||
auto bounds = func_bounds.at({call->name, call->value_index}); | ||
if (bounds_smaller_than_type(bounds, call->type)) { | ||
// TODO(#6297): check that the clamped function's allocation bounds might be wider than its compute bounds | ||
|
||
auto [new_args, changed] = mutate_with_changes(call->args); | ||
Expr new_call = changed ? call : Call::make(call->type, call->name, new_args, call->call_type, call->func, call->value_index, call->image, call->param); | ||
return Max::make(Min::make(new_call, std::move(bounds.max)), std::move(bounds.min)); | ||
} | ||
} | ||
|
||
ScopedValue s(is_inside_indexing, true); | ||
return IRMutator::visit(call); | ||
} | ||
|
||
private: | ||
template<typename L, typename Body> | ||
Body visit_let(const L *let) { | ||
ScopedBinding<bool> binding(let_var_inside_indexing, let->name, false); | ||
Body body = mutate(let->body); | ||
|
||
ScopedValue s(is_inside_indexing, is_inside_indexing || let_var_inside_indexing.get(let->name)); | ||
Expr value = mutate(let->value); | ||
|
||
return L::make(let->name, std::move(value), std::move(body)); | ||
} | ||
|
||
bool bounds_smaller_than_type(const Interval &bounds, Type type) { | ||
return bounds.is_bounded() && !(equal(bounds.min, type.min()) && equal(bounds.max, type.max())); | ||
} | ||
|
||
/** | ||
* A let-var is marked "true" if is used somewhere in an indexing expression. | ||
* visit_let will process its value binding with is_inside_indexing set when | ||
* this is the case. | ||
*/ | ||
Scope<bool> let_var_inside_indexing; | ||
bool is_inside_indexing = false; | ||
}; | ||
|
||
} // namespace | ||
|
||
Stmt clamp_unsafe_accesses(const Stmt &s, const std::map<std::string, Function> &env, FuncValueBounds &func_bounds) { | ||
return ClampUnsafeAccesses(env, func_bounds).mutate(s); | ||
} | ||
|
||
} // namespace Halide::Internal |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#ifndef HALIDE_CLAMPUNSAFEACCESSES_H | ||
#define HALIDE_CLAMPUNSAFEACCESSES_H | ||
|
||
/** \file | ||
* Defines the clamp_unsafe_accesses lowering pass. | ||
*/ | ||
|
||
#include "Bounds.h" | ||
#include "Expr.h" | ||
#include "Function.h" | ||
|
||
namespace Halide::Internal { | ||
|
||
/** Inject clamps around func calls h(...) when all the following conditions hold: | ||
* 1. The call is in an indexing context, such as: f(x) = g(h(x)); | ||
* 2. The FuncValueBounds of h are smaller than those of its type | ||
* 3. The allocation bounds of h might be wider than its compute bounds. | ||
*/ | ||
Stmt clamp_unsafe_accesses(const Stmt &s, const std::map<std::string, Function> &env, FuncValueBounds &func_bounds); | ||
|
||
} // namespace Halide::Internal | ||
|
||
#endif // HALIDE_CLAMPUNSAFEACCESSES_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#include "Halide.h" | ||
|
||
using namespace Halide; | ||
|
||
// https://github.com/halide/Halide/issues/6131 | ||
// Prior to the ClampUnsafeAccesses pass, this test case would | ||
// crash as described in the comments below. | ||
|
||
int main(int argc, char **argv) { | ||
Var x{"x"}; | ||
|
||
Func f{"f"}, g{"g"}, h{"h"}, out{"out"}; | ||
|
||
const int min = -10000000; | ||
const int max = min + 20; | ||
|
||
h(x) = clamp(x, min, max); | ||
// Within its compute bounds, h's value will be within | ||
// [min,max]. Outside that, it's uninitialized memory. | ||
|
||
g(x) = sin(x); | ||
// Halide thinks g will be accessed within [min,max], so its | ||
// allocation bounds will be [min,max] | ||
|
||
f(x) = g(h(x)); | ||
f.vectorize(x, 64, TailStrategy::RoundUp); | ||
// f will access h at values outside its compute bounds, and get | ||
// garbage, and then use that garbage to access g outside its | ||
// allocation bounds. | ||
|
||
out(x) = f(x); | ||
|
||
h.compute_root(); | ||
g.compute_root(); | ||
f.compute_root(); | ||
|
||
out.realize({1}); | ||
|
||
printf("Success!\n"); | ||
return 0; | ||
} |