Skip to content

Commit

Permalink
Strip asserts right at the end of lowering (#8094)
Browse files Browse the repository at this point in the history
The simplifier exploits asserts to make simplification. When compiling
with NoAsserts, certain assertions aren't ever introduced, which means
that the simplifier can't exploit certain things that we know to be
true. Mostly this has a negative effect on code size. E.g. tail cases
get generated even though they are actually dead code.

This PR keeps all the assertions right until the end of lowering, when
it strips them in a dedicated pass.

This reduces object file size for a large production blob of Halide code
by ~10%, without measurably affecting runtime.
  • Loading branch information
abadams committed Feb 15, 2024
1 parent e6e1b6f commit 2855ca3
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 31 deletions.
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,7 @@ SOURCE_FILES = \
StorageFlattening.cpp \
StorageFolding.cpp \
StrictifyFloat.cpp \
StripAsserts.cpp \
Substitute.cpp \
Target.cpp \
Tracing.cpp \
Expand Down Expand Up @@ -785,6 +786,7 @@ HEADER_FILES = \
StorageFlattening.h \
StorageFolding.h \
StrictifyFloat.h \
StripAsserts.h \
Substitute.h \
Target.h \
Tracing.h \
Expand Down
39 changes: 13 additions & 26 deletions src/AddImageChecks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ Stmt add_image_checks_inner(Stmt s,
const FuncValueBounds &fb,
bool will_inject_host_copies) {

bool no_asserts = t.has_feature(Target::NoAsserts);
bool no_bounds_query = t.has_feature(Target::NoBoundsQuery);

// First hunt for all the referenced buffers
Expand Down Expand Up @@ -618,12 +617,9 @@ Stmt add_image_checks_inner(Stmt s,
replace_with_constrained[name] = constrained_var;
}

Expr error = 0;
if (!no_asserts) {
error = Call::make(Int(32), "halide_error_constraint_violated",
{name, var, constrained_var_str, constrained_var},
Call::Extern);
}
Expr error = Call::make(Int(32), "halide_error_constraint_violated",
{name, var, constrained_var_str, constrained_var},
Call::Extern);

// Check the var passed in equals the constrained version (when not in inference mode)
asserts_constrained.push_back(AssertStmt::make(var == constrained_var, error));
Expand Down Expand Up @@ -679,14 +675,12 @@ Stmt add_image_checks_inner(Stmt s,
}
};

if (!no_asserts) {
// Inject the code that checks the host pointers.
prepend_stmts(&asserts_host_non_null);
prepend_stmts(&asserts_host_alignment);
prepend_stmts(&asserts_device_not_dirty);
prepend_stmts(&dims_no_overflow_asserts);
prepend_lets(&lets_overflow);
}
// Inject the code that checks the host pointers.
prepend_stmts(&asserts_host_non_null);
prepend_stmts(&asserts_host_alignment);
prepend_stmts(&asserts_device_not_dirty);
prepend_stmts(&dims_no_overflow_asserts);
prepend_lets(&lets_overflow);

// Replace uses of the var with the constrained versions in the
// rest of the program. We also need to respect the existence of
Expand All @@ -698,25 +692,18 @@ Stmt add_image_checks_inner(Stmt s,
// all in reverse order compared to execution, as we incrementally
// prepending code.

// Inject the code that checks the constraints are correct. We
// need these regardless of how NoAsserts is set, because they are
// what gets Halide to actually exploit the constraint.
// Inject the code that checks the constraints are correct.
prepend_stmts(&asserts_constrained);

if (!no_asserts) {
prepend_stmts(&asserts_required);
prepend_stmts(&asserts_type_checks);
}
prepend_stmts(&asserts_required);
prepend_stmts(&asserts_type_checks);

// Inject the code that returns early for inference mode.
if (!no_bounds_query) {
s = IfThenElse::make(!maybe_return_condition, s);
prepend_stmts(&buffer_rewrites);
}

if (!no_asserts) {
prepend_stmts(&asserts_proposed);
}
prepend_stmts(&asserts_proposed);

// Inject the code that defines the proposed sizes.
prepend_lets(&lets_proposed);
Expand Down
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ set(HEADER_FILES
StorageFlattening.h
StorageFolding.h
StrictifyFloat.h
StripAsserts.h
Substitute.h
Target.h
Tracing.h
Expand Down Expand Up @@ -340,6 +341,7 @@ set(SOURCE_FILES
StorageFlattening.cpp
StorageFolding.cpp
StrictifyFloat.cpp
StripAsserts.cpp
Substitute.cpp
Target.cpp
Tracing.cpp
Expand Down
7 changes: 7 additions & 0 deletions src/Lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
#include "StorageFlattening.h"
#include "StorageFolding.h"
#include "StrictifyFloat.h"
#include "StripAsserts.h"
#include "Substitute.h"
#include "Tracing.h"
#include "TrimNoOps.h"
Expand Down Expand Up @@ -427,6 +428,12 @@ void lower_impl(const vector<Function> &output_funcs,
s = hoist_prefetches(s);
log("Lowering after hoisting prefetches:", s);

if (t.has_feature(Target::NoAsserts)) {
debug(1) << "Stripping asserts...\n";
s = strip_asserts(s);
log("Lowering after stripping asserts:", s);
}

debug(1) << "Lowering after final simplification:\n"
<< s << "\n\n";

Expand Down
6 changes: 1 addition & 5 deletions src/ScheduleFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1368,11 +1368,7 @@ class InjectFunctionRealization : public IRMutator {

// This is also the point at which we inject explicit bounds
// for this realization.
if (target.has_feature(Target::NoAsserts)) {
return s;
} else {
return inject_explicit_bounds(s, func);
}
return inject_explicit_bounds(s, func);
}

Stmt build_realize_function_from_group(Stmt s, int func_index) {
Expand Down
121 changes: 121 additions & 0 deletions src/StripAsserts.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#include "StripAsserts.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "IRVisitor.h"
#include <set>

namespace Halide {
namespace Internal {

namespace {

bool may_discard(const Expr &e) {
class MayDiscard : public IRVisitor {
using IRVisitor::visit;

void visit(const Call *op) override {
// Extern calls that are side-effecty in the sense that you can't
// move them around in the IR, but we're free to discard because
// they're just getters.
static const std::set<std::string> discardable{
Call::buffer_get_dimensions,
Call::buffer_get_min,
Call::buffer_get_extent,
Call::buffer_get_stride,
Call::buffer_get_max,
Call::buffer_get_host,
Call::buffer_get_device,
Call::buffer_get_device_interface,
Call::buffer_get_shape,
Call::buffer_get_host_dirty,
Call::buffer_get_device_dirty,
Call::buffer_get_type};

if (!(op->is_pure() ||
discardable.count(op->name))) {
result = false;
}
}

public:
bool result = true;
} d;
e.accept(&d);

return d.result;
}

class StripAsserts : public IRMutator {
using IRMutator::visit;

// We're going to track which symbols are used so that we can strip lets we
// don't need after removing the asserts.
std::set<std::string> used;

// Drop all assert stmts. Assumes that you don't want any side-effects from
// the condition.
Stmt visit(const AssertStmt *op) override {
return Evaluate::make(0);
}

Expr visit(const Variable *op) override {
used.insert(op->name);
return op;
}

Expr visit(const Load *op) override {
used.insert(op->name);
return IRMutator::visit(op);
}

Stmt visit(const Store *op) override {
used.insert(op->name);
return IRMutator::visit(op);
}

// Also dead-code eliminate any let stmts wrapped around asserts
Stmt visit(const LetStmt *op) override {
Stmt body = mutate(op->body);
if (is_no_op(body)) {
if (may_discard(op->value)) {
return body;
} else {
// We visit the value just to keep the used variable set
// accurate.
mutate(op->value);
return Evaluate::make(op->value);
}
} else if (body.same_as(op->body)) {
mutate(op->value);
return op;
} else if (may_discard(op->value) && !used.count(op->name)) {
return body;
} else {
mutate(op->value);
return LetStmt::make(op->name, op->value, body);
}
}

Stmt visit(const Block *op) override {
Stmt first = mutate(op->first);
Stmt rest = mutate(op->rest);
if (first.same_as(op->first) && rest.same_as(op->rest)) {
return op;
} else if (is_no_op(rest)) {
return first;
} else if (is_no_op(first)) {
return rest;
} else {
return Block::make(first, rest);
}
}
};

} // namespace

Stmt strip_asserts(const Stmt &s) {
return StripAsserts().mutate(s);
}

} // namespace Internal
} // namespace Halide
18 changes: 18 additions & 0 deletions src/StripAsserts.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef HALIDE_STRIP_ASSERTS_H
#define HALIDE_STRIP_ASSERTS_H

/** \file
* Defines the lowering pass that strips asserts when NoAsserts is set.
*/

#include "Expr.h"

namespace Halide {
namespace Internal {

Stmt strip_asserts(const Stmt &s);

} // namespace Internal
} // namespace Halide

#endif

0 comments on commit 2855ca3

Please sign in to comment.