Skip to content

Commit

Permalink
Revise PyStub calling convention for GeneratorParams (#6742)
Browse files Browse the repository at this point in the history
This is a rethink of #6661, trying to make it saner in anticipation of the ongoing Python Generator work.

TL;DR: instead of mixing GeneratorParams in with the rest of the keywords, segregate them into an optional `generator_params` keyword argument, which is a plain Python dict. This neatly solves a couple of problems:
- synthetic params with funky names aren't a problem anymore.
- error reporting is simpler because before an unknown keyword could have been intended to be a GP or an Input.
- GP values are now clear and distinct from Inputs, which is IMHO a good thing.

This is technically a breaking change, but I doubt anyone will notice; this is mainly here to get a sane convention in place for use with Python Generators as well.

Also, a drive-by change to Func::output_types() to fix the assertion error message.
  • Loading branch information
steven-johnson committed May 4, 2022
1 parent 92dfb61 commit 1606039
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 42 deletions.
3 changes: 2 additions & 1 deletion python_bindings/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ GENPARAMS_complex=\
array_input.type=uint8 \
int_arg.size=2 \
simple_input.type=uint8 \
untyped_buffer_input.type=uint8
untyped_buffer_input.type=uint8 \
untyped_buffer_output.type=uint8

GENPARAMS_simple=\
func_input.type=uint8
Expand Down
3 changes: 2 additions & 1 deletion python_bindings/correctness/generators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ set(GENPARAMS_complex
array_input.type=uint8
int_arg.size=2
simple_input.type=uint8
untyped_buffer_input.type=uint8)
untyped_buffer_input.type=uint8
untyped_buffer_output.type=uint8)

set(GENPARAMS_simple
func_input.type=uint8)
Expand Down
3 changes: 1 addition & 2 deletions python_bindings/correctness/generators/complex_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ Halide::Buffer<Type, 3> make_image(int extra) {

class Complex : public Halide::Generator<Complex> {
public:
GeneratorParam<Type> untyped_buffer_output_type{"untyped_buffer_output_type", Float(32)};
GeneratorParam<bool> vectorize{"vectorize", true};
GeneratorParam<LoopLevel> intermediate_level{"intermediate_level", LoopLevel::root()};

Expand Down Expand Up @@ -52,7 +51,7 @@ class Complex : public Halide::Generator<Complex> {
// assert-fail, because there is no type constraint set: the type
// will end up as whatever we infer from the values put into it. We'll use an
// explicit GeneratorParam to allow us to set it.
untyped_buffer_output(x, y, c) = cast(untyped_buffer_output_type, untyped_buffer_input(x, y, c));
untyped_buffer_output(x, y, c) = cast(untyped_buffer_output.output_type(), untyped_buffer_input(x, y, c));

// Gratuitous intermediate for the purpose of exercising
// GeneratorParam<LoopLevel>
Expand Down
46 changes: 34 additions & 12 deletions python_bindings/correctness/pystub.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,23 @@ def test_simple(gen):
# ----------- Above set again, w/ GeneratorParam mixed in
k = 42

gp = { "offset": k }

# (positional)
f = gen(target, b_in, f_in, 3.5, offset=k)
f = gen(target, b_in, f_in, 3.5, generator_params=gp)
_realize_and_check(f, k)

# (keyword)
f = gen(target, offset=k, buffer_input=b_in, func_input=f_in, float_arg=3.5)
f = gen(target, generator_params=gp, buffer_input=b_in, func_input=f_in, float_arg=3.5)
_realize_and_check(f, k)

f = gen(target, buffer_input=b_in, offset=k, func_input=f_in, float_arg=3.5)
f = gen(target, buffer_input=b_in, generator_params=gp, func_input=f_in, float_arg=3.5)
_realize_and_check(f, k)

f = gen(target, buffer_input=b_in, func_input=f_in, offset=k, float_arg=3.5)
f = gen(target, buffer_input=b_in, func_input=f_in, generator_params=gp, float_arg=3.5)
_realize_and_check(f, k)

f = gen(target, buffer_input=b_in, float_arg=3.5, func_input=f_in, offset=k)
f = gen(target, buffer_input=b_in, float_arg=3.5, func_input=f_in, generator_params=gp)
_realize_and_check(f, k)

# ----------- Test various failure modes
Expand Down Expand Up @@ -104,19 +106,35 @@ def test_simple(gen):
else:
assert False, 'Did not see expected exception!'

try:
# generator_params is not a dict
f = gen(target, b_in, f_in, 3.5, generator_params=[1, 2, 3])
except TypeError as e:
assert "cannot convert dictionary" in str(e)
else:
assert False, 'Did not see expected exception!'

try:
# Bad gp name
f = gen(target, b_in, f_in, 3.5, generator_params={"foo": 0})
except RuntimeError as e:
assert "has no GeneratorParam named: foo" in str(e)
else:
assert False, 'Did not see expected exception!'

try:
# Bad input name
f = gen(target, buffer_input=b_in, float_arg=3.5, offset=k, funk_input=f_in)
f = gen(target, buffer_input=b_in, float_arg=3.5, generator_params=gp, funk_input=f_in)
except RuntimeError as e:
assert "Expected exactly 3 keyword args for inputs, but saw 2." in str(e)
assert "Unknown input 'funk_input' specified via keyword argument." in str(e)
else:
assert False, 'Did not see expected exception!'

try:
# Bad gp name
f = gen(target, buffer_input=b_in, float_arg=3.5, offset=k, func_input=f_in, nonexistent_generator_param="wat")
f = gen(target, buffer_input=b_in, float_arg=3.5, generator_params=gp, func_input=f_in, nonexistent_generator_param="wat")
except RuntimeError as e:
assert "has no GeneratorParam named: nonexistent_generator_param" in str(e)
assert "Unknown input 'nonexistent_generator_param' specified via keyword argument." in str(e)
else:
assert False, 'Did not see expected exception!'

Expand All @@ -132,7 +150,9 @@ def test_looplevel(gen):

simple_compute_at = hl.LoopLevel()
simple = gen(target, buffer_input, func_input, 3.5,
compute_level=simple_compute_at)
generator_params = {
"compute_level": simple_compute_at
})

computed_output = hl.Func('computed_output')
computed_output[x, y] = simple[x, y] + 3
Expand Down Expand Up @@ -171,9 +191,11 @@ def test_complex(gen):
array_input=[ input, input ],
float_arg=float_arg,
int_arg=[ int_arg, int_arg ],
untyped_buffer_output_type="uint8",
extra_func_input=func_input,
vectorize=True)
generator_params = {
"untyped_buffer_output.type": hl.UInt(8),
"vectorize": True
})

# return value is a tuple; unpack separately to avoid
# making the callsite above unreadable
Expand Down
55 changes: 30 additions & 25 deletions python_bindings/stub/PyStubImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ py::object generate_impl(const GeneratorFactory &factory, const GeneratorContext
// arg, they all must be specified that way; otherwise they must all be
// positional, in the order declared in the Generator.)
//
// GeneratorParams can only be specified by name, and are always optional.
// GeneratorParams are always specified as an optional named parameter
// called "generator_params", which is expected to be a python dict.

std::map<std::string, std::vector<StubInput>> kw_inputs;
for (const auto &name : names.inputs) {
Expand All @@ -127,34 +128,38 @@ py::object generate_impl(const GeneratorFactory &factory, const GeneratorContext

// Process the kwargs first.
for (auto kw : kwargs) {
// If the kwarg is the name of a known input, stick it in the input
// vector. If not, stick it in the GeneratorParamsMap (if it's invalid,
// an error will be reported further downstream).
std::string name = kw.first.cast<std::string>();
py::handle value = kw.second;
auto it = kw_inputs.find(name);
if (it != kw_inputs.end()) {
_halide_user_assert(it->second.empty())
<< "Generator Input named '" << it->first << "' was specified more than once.";
it->second = to_stub_inputs(py::cast<py::object>(value));
kw_inputs_specified++;
} else {
if (py::isinstance<LoopLevel>(value)) {
generator_params[name] = value.cast<LoopLevel>();
} else if (py::isinstance<py::list>(value)) {
// Convert [hl.UInt(8), hl.Int(16)] -> uint8,int16
std::string v;
for (auto t : value) {
if (!v.empty()) {
v += ",";
const std::string name = kw.first.cast<std::string>();
const py::handle value = kw.second;

if (name == "generator_params") {
py::dict gp = py::cast<py::dict>(value);
for (auto item : gp) {
const std::string gp_name = py::str(item.first).cast<std::string>();
const py::handle gp_value = item.second;
if (py::isinstance<LoopLevel>(gp_value)) {
generator_params[gp_name] = gp_value.cast<LoopLevel>();
} else if (py::isinstance<py::list>(gp_value)) {
// Convert [hl.UInt(8), hl.Int(16)] -> uint8,int16
std::string v;
for (auto t : gp_value) {
if (!v.empty()) {
v += ",";
}
v += py::str(t).cast<std::string>();
}
v += py::str(t).cast<std::string>();
generator_params[gp_name] = v;
} else {
generator_params[gp_name] = py::str(gp_value).cast<std::string>();
}
generator_params[name] = v;
} else {
generator_params[name] = py::str(value).cast<std::string>();
}
continue;
}

auto it = kw_inputs.find(name);
_halide_user_assert(it != kw_inputs.end()) << "Unknown input '" << name << "' specified via keyword argument.";
_halide_user_assert(it->second.empty()) << "Generator Input named '" << it->first << "' was specified more than once.";
it->second = to_stub_inputs(py::cast<py::object>(value));
kw_inputs_specified++;
}

std::vector<std::vector<StubInput>> inputs;
Expand Down
2 changes: 1 addition & 1 deletion src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ const Type &Func::output_type() const {
const std::vector<Type> &Func::output_types() const {
const auto &types = defined() ? func.output_types() : func.required_types();
user_assert(!types.empty())
<< "Can't call Func::output_type on Func \"" << name()
<< "Can't call Func::output_types on Func \"" << name()
<< "\" because it is undefined or has no type requirements.\n";
return types;
}
Expand Down
5 changes: 5 additions & 0 deletions src/Generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ bool is_valid_name(const std::string &n) {
return false;
}
}
// prohibit this specific string so that we can use it for
// passing GeneratorParams in Python.
if (n == "generator_params") {
return false;
}
return true;
}

Expand Down

0 comments on commit 1606039

Please sign in to comment.