Skip to content

Commit

Permalink
add_requirement() maintenance (halide#7045)
Browse files Browse the repository at this point in the history
* add_requirement() maintenance

This PR started out as a quick fix to add Python bindings for the `add_requirements` methods on Pipeline and Generator (which were missing), but expanded a bit to fix other issues as well:
- The implementation of `Generator::add_requirement` was subtly wrong, in that it only worked if you called the method after everything else in your `generate()` method. Now we accumulate requirements and insert them at the end, so you can call the method anywhere.
- We had C++ methods that took both an explicit `vector<Expr>` and also a variadic-template version, but the former required a mutable vector... and fixing this to not require that ended up creating ambiguity about which overloaded call to use. Added an ugly enable_if thing to resolve this.

(Side note halide#1: overloading methods to have both templated and non-templated versions with the same name is probably something to avoid in the future.)

(Side note halide#2: we should probably thing more carefully about using variadic templates in our public API in the future; we currently use it pretty heavily, but it tends to be messy and hard to reason about IMHO.)

* tidy

* remove underscores
  • Loading branch information
steven-johnson authored and ardier committed Mar 3, 2024
1 parent ea64dbe commit cdf9324
Show file tree
Hide file tree
Showing 14 changed files with 155 additions and 38 deletions.
7 changes: 7 additions & 0 deletions python_bindings/src/halide/_generator_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,10 @@ def using_autoscheduler(self):
def natural_vector_size(self, type: Type) -> int:
return self.target().natural_vector_size(type)

def add_requirement(self, condition: Expr, *args) -> None:
assert self._stage < _Stage.pipeline_built
self._pipeline_requirements.append((condition, [*args]))

@classmethod
def call(cls, *args, **kwargs):
generator = cls()
Expand Down Expand Up @@ -475,6 +479,7 @@ def __init__(self, generator_params: dict = {}):
self._requirements = {}
self._replacements = {}
self._in_configure = 0
self._pipeline_requirements = []

self._advance_to_gp_created()
if generator_params:
Expand Down Expand Up @@ -699,6 +704,8 @@ def _build_pipeline(self) -> Pipeline:
funcs.append(f)

self._pipeline = Pipeline(funcs)
for condition, error_args in self._pipeline_requirements:
self._pipeline.add_requirement(condition, *error_args)
self._stage = _Stage.pipeline_built
return self._pipeline

Expand Down
16 changes: 16 additions & 0 deletions python_bindings/src/halide/halide_/PyHalide.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,21 @@ Expr double_to_expr_check(double v) {
return Expr(f);
}

std::vector<Expr> collect_print_args(const py::args &args) {
std::vector<Expr> v;
v.reserve(args.size());
for (size_t i = 0; i < args.size(); ++i) {
// No way to see if a cast will work: just have to try
// and fail. Normally we don't want string to be convertible
// to Expr, but in this unusual case we do.
try {
v.emplace_back(args[i].cast<std::string>());
} catch (...) {
v.push_back(args[i].cast<Expr>());
}
}
return v;
}

} // namespace PythonBindings
} // namespace Halide
1 change: 1 addition & 0 deletions python_bindings/src/halide/halide_/PyHalide.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ std::vector<T> args_to_vector(const py::args &args, size_t start_offset = 0, siz
return v;
}

std::vector<Expr> collect_print_args(const py::args &args);
Expr double_to_expr_check(double v);

} // namespace PythonBindings
Expand Down
28 changes: 2 additions & 26 deletions python_bindings/src/halide/halide_/PyIROperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,6 @@
namespace Halide {
namespace PythonBindings {

namespace {

// TODO: clever template usage could generalize this to list-of-types-to-try.
std::vector<Expr> args_to_vector_for_print(const py::args &args, size_t start_offset = 0) {
if (args.size() < start_offset) {
throw py::value_error("Not enough arguments");
}
std::vector<Expr> v;
v.reserve(args.size() - (start_offset));
for (size_t i = start_offset; i < args.size(); ++i) {
// No way to see if a cast will work: just have to try
// and fail. Normally we don't want string to be convertible
// to Expr, but in this unusual case we do.
try {
v.emplace_back(args[i].cast<std::string>());
} catch (...) {
v.push_back(args[i].cast<Expr>());
}
}
return v;
}

} // namespace

void define_operators(py::module &m) {
m.def("max", [](const py::args &args) -> Expr {
if (args.size() < 2) {
Expand Down Expand Up @@ -149,11 +125,11 @@ void define_operators(py::module &m) {
m.def("reinterpret", (Expr(*)(Type, Expr)) & reinterpret);
m.def("cast", (Expr(*)(Type, Expr)) & cast);
m.def("print", [](const py::args &args) -> Expr {
return print(args_to_vector_for_print(args));
return print(collect_print_args(args));
});
m.def(
"print_when", [](const Expr &condition, const py::args &args) -> Expr {
return print_when(condition, args_to_vector_for_print(args));
return print_when(condition, collect_print_args(args));
},
py::arg("condition"));
m.def(
Expand Down
7 changes: 7 additions & 0 deletions python_bindings/src/halide/halide_/PyPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,13 @@ void define_pipeline(py::module &m) {
.def("defined", &Pipeline::defined)
.def("invalidate_cache", &Pipeline::invalidate_cache)

.def(
"add_requirement", [](Pipeline &p, const Expr &condition, const py::args &error_args) -> void {
auto v = collect_print_args(error_args);
p.add_requirement(condition, v);
},
py::arg("condition"))

.def("__repr__", [](const Pipeline &p) -> std::string {
std::ostringstream o;
o << "<halide.Pipeline [";
Expand Down
44 changes: 43 additions & 1 deletion python_bindings/test/correctness/addconstant_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test(addconstant_impl_func, offset):
scalar_u64 = 5724968371
scalar_i8 = -7
scalar_i16 = -30712
scalar_i32 = -98901
scalar_i32 = 98901
scalar_i64 = -8163465847
scalar_float = 3.14159
scalar_double = 1.61803
Expand Down Expand Up @@ -93,6 +93,48 @@ def test(addconstant_impl_func, offset):
for z in range(input_3d.shape[2]):
assert output_3d[x, y, z] == input_3d[x, y, z] + scalar_i8 + offset

try:
# Expected requirement failure #1
scalar_i32 = 0
addconstant_impl_func(
scalar_u1,
scalar_u8, scalar_u16, scalar_u32, scalar_u64,
scalar_i8, scalar_i16, scalar_i32, scalar_i64,
scalar_float, scalar_double,
input_u8, input_u16, input_u32, input_u64,
input_i8, input_i16, input_i32, input_i64,
input_float, input_double, input_2d, input_3d,
output_u8, output_u16, output_u32, output_u64,
output_i8, output_i16, output_i32, output_i64,
output_float, output_double, output_2d, output_3d,
)
except RuntimeError as e:
assert str(e) == "Halide Runtime Error: -27", e
else:
assert False, 'Did not see expected exception!'

try:
# Expected requirement failure #2 -- note that for AOT-compiled
# code in Python, the error message is stricly numeric (the text
# of the error isn't currently propagated int he exception).
scalar_i32 = -1
addconstant_impl_func(
scalar_u1,
scalar_u8, scalar_u16, scalar_u32, scalar_u64,
scalar_i8, scalar_i16, scalar_i32, scalar_i64,
scalar_float, scalar_double,
input_u8, input_u16, input_u32, input_u64,
input_i8, input_i16, input_i32, input_i64,
input_float, input_double, input_2d, input_3d,
output_u8, output_u16, output_u32, output_u64,
output_i8, output_i16, output_i32, output_i64,
output_float, output_double, output_2d, output_3d,
)
except RuntimeError as e:
assert str(e) == "Halide Runtime Error: -27", e
else:
assert False, 'Did not see expected exception!'


if __name__ == "__main__":
for t, o in TESTS_AND_OFFSETS:
Expand Down
31 changes: 31 additions & 0 deletions python_bindings/test/correctness/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,36 @@ def test_typed_funcs():
assert False, 'Did not see expected exception!'


def test_requirements():
delta = hl.Param(hl.Int(32), 'delta')
x = hl.Var('x')
f = hl.Func('f_requirements')
f[x] = x + delta

# Add a requirement
p = hl.Pipeline([f])
p.add_requirement(delta != 0) # error_args omitted
p.add_requirement(delta > 0, "negative values are bad", delta)

delta.set(1)
p.realize([10])

try:
delta.set(0)
p.realize([10])
except hl.HalideError as e:
assert 'Requirement Failed: (false)' in str(e)
else:
assert False, 'Did not see expected exception!'

try:
delta.set(-1)
p.realize([10])
except hl.HalideError as e:
assert 'Requirement Failed: (false) negative values are bad -1' in str(e)
else:
assert False, 'Did not see expected exception!'

if __name__ == "__main__":
test_compiletime_error()
test_runtime_error()
Expand All @@ -402,3 +432,4 @@ def test_typed_funcs():
test_basics5()
test_scalar_funcs()
test_bool_conversion()
test_requirements()
3 changes: 3 additions & 0 deletions python_bindings/test/generators/addconstantcpp_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class AddConstantGenerator : public Halide::Generator<AddConstantGenerator> {
Var x, y, z;

void generate() {
add_requirement(scalar_int32 != 0); // error_args omitted for this case
add_requirement(scalar_int32 > 0, "negative values are bad", scalar_int32);

output_uint8(x) = input_uint8(x) + scalar_uint8;
output_uint16(x) = input_uint16(x) + scalar_uint16;
output_uint32(x) = input_uint32(x) + scalar_uint32;
Expand Down
3 changes: 3 additions & 0 deletions python_bindings/test/generators/addconstantpy_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class AddConstantGenerator:

def generate(self):
g = self
g.add_requirement(g.scalar_int32 != 0) # error_args omitted for this case
g.add_requirement(g.scalar_int32 > 0, "negative values are bad", g.scalar_int32)

g.output_uint8[x] = g.input_uint8[x] + g.scalar_uint8
g.output_uint16[x] = g.input_uint16[x] + g.scalar_uint16
g.output_uint32[x] = g.input_uint32[x] + g.scalar_uint32
Expand Down
8 changes: 8 additions & 0 deletions src/Generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1554,6 +1554,11 @@ void GeneratorBase::pre_schedule() {
void GeneratorBase::post_schedule() {
}

void GeneratorBase::add_requirement(const Expr &condition, const std::vector<Expr> &error_args) {
internal_assert(!pipeline.defined());
requirements.push_back({condition, error_args});
}

Pipeline GeneratorBase::get_pipeline() {
check_min_phase(GenerateCalled);
if (!pipeline.defined()) {
Expand Down Expand Up @@ -1584,6 +1589,9 @@ Pipeline GeneratorBase::get_pipeline() {
}
}
pipeline = Pipeline(funcs);
for (const auto &r : requirements) {
pipeline.add_requirement(r.condition, r.error_args);
}
}
return pipeline;
}
Expand Down
17 changes: 14 additions & 3 deletions src/Generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -3444,9 +3444,14 @@ class GeneratorBase : public NamesInterface, public AbstractGenerator {
return p;
}

template<typename... Args>
HALIDE_NO_USER_CODE_INLINE void add_requirement(Expr condition, Args &&...args) {
get_pipeline().add_requirement(condition, std::forward<Args>(args)...);
void add_requirement(const Expr &condition, const std::vector<Expr> &error_args);

template<typename... Args,
typename = typename std::enable_if<Internal::all_are_printable_args<Args...>::value>::type>
inline HALIDE_NO_USER_CODE_INLINE void add_requirement(const Expr &condition, Args &&...error_args) {
std::vector<Expr> collected_args;
Internal::collect_print_args(collected_args, std::forward<Args>(error_args)...);
add_requirement(condition, collected_args);
}

void trace_pipeline() {
Expand Down Expand Up @@ -3636,6 +3641,12 @@ class GeneratorBase : public NamesInterface, public AbstractGenerator {
std::string generator_registered_name, generator_stub_name;
Pipeline pipeline;

struct Requirement {
Expr condition;
std::vector<Expr> error_args;
};
std::vector<Requirement> requirements;

// Return our GeneratorParamInfo.
GeneratorParamInfo &param_info();

Expand Down
9 changes: 9 additions & 0 deletions src/IROperator.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,15 @@ Stmt remove_promises(const Stmt &s);
* the tagged expression. If not, returns the expression. */
Expr unwrap_tags(const Expr &e);

template<typename T>
struct is_printable_arg {
static constexpr bool value = std::is_convertible<T, const char *>::value ||
std::is_convertible<T, Halide::Expr>::value;
};

template<typename... Args>
struct all_are_printable_args : meta_and<is_printable_arg<Args>...> {};

// Secondary args to print can be Exprs or const char *
inline HALIDE_NO_USER_CODE_INLINE void collect_print_args(std::vector<Expr> &args) {
}
Expand Down
2 changes: 1 addition & 1 deletion src/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ Realization Pipeline::realize(JITUserContext *context,
return r;
}

void Pipeline::add_requirement(const Expr &condition, std::vector<Expr> &error_args) {
void Pipeline::add_requirement(const Expr &condition, const std::vector<Expr> &error_args) {
user_assert(defined()) << "Pipeline is undefined\n";

// It is an error for a requirement to reference a Func or a Var
Expand Down
17 changes: 10 additions & 7 deletions src/Pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -547,17 +547,20 @@ class Pipeline {
* with the remaining arguments, and return
* halide_error_code_requirement_failed. Requirements are checked
* in the order added. */
void add_requirement(const Expr &condition, std::vector<Expr> &error);

/** Generate begin_pipeline and end_pipeline tracing calls for this pipeline. */
void trace_pipeline();
// @{
void add_requirement(const Expr &condition, const std::vector<Expr> &error_args);

template<typename... Args>
inline HALIDE_NO_USER_CODE_INLINE void add_requirement(const Expr &condition, Args &&...args) {
template<typename... Args,
typename = typename std::enable_if<Internal::all_are_printable_args<Args...>::value>::type>
inline HALIDE_NO_USER_CODE_INLINE void add_requirement(const Expr &condition, Args &&...error_args) {
std::vector<Expr> collected_args;
Internal::collect_print_args(collected_args, std::forward<Args>(args)...);
Internal::collect_print_args(collected_args, std::forward<Args>(error_args)...);
add_requirement(condition, collected_args);
}
// @}

/** Generate begin_pipeline and end_pipeline tracing calls for this pipeline. */
void trace_pipeline();

private:
std::string generate_function_name() const;
Expand Down

0 comments on commit cdf9324

Please sign in to comment.