diff --git a/python_bindings/Makefile b/python_bindings/Makefile index ba2eaacdaaa9..73f60afbda84 100644 --- a/python_bindings/Makefile +++ b/python_bindings/Makefile @@ -151,7 +151,8 @@ GENPARAMS_complex=\ array_input.type=uint8 \ int_arg.size=2 \ simple_input.type=uint8 \ - untyped_buffer_input.type=uint8 + untyped_buffer_input.type=uint8 \ + untyped_buffer_output.type=uint8 GENPARAMS_simple=\ func_input.type=uint8 diff --git a/python_bindings/correctness/generators/CMakeLists.txt b/python_bindings/correctness/generators/CMakeLists.txt index 95c535c3c926..85590419a677 100644 --- a/python_bindings/correctness/generators/CMakeLists.txt +++ b/python_bindings/correctness/generators/CMakeLists.txt @@ -20,7 +20,8 @@ set(GENPARAMS_complex array_input.type=uint8 int_arg.size=2 simple_input.type=uint8 - untyped_buffer_input.type=uint8) + untyped_buffer_input.type=uint8 + untyped_buffer_output.type=uint8) set(GENPARAMS_simple func_input.type=uint8) diff --git a/python_bindings/correctness/generators/complex_generator.cpp b/python_bindings/correctness/generators/complex_generator.cpp index c1b72e50e93b..037bd5d71fc5 100644 --- a/python_bindings/correctness/generators/complex_generator.cpp +++ b/python_bindings/correctness/generators/complex_generator.cpp @@ -17,7 +17,6 @@ Halide::Buffer make_image(int extra) { class Complex : public Halide::Generator { public: - GeneratorParam untyped_buffer_output_type{"untyped_buffer_output_type", Float(32)}; GeneratorParam vectorize{"vectorize", true}; GeneratorParam intermediate_level{"intermediate_level", LoopLevel::root()}; @@ -52,7 +51,7 @@ class Complex : public Halide::Generator { // assert-fail, because there is no type constraint set: the type // will end up as whatever we infer from the values put into it. We'll use an // explicit GeneratorParam to allow us to set it. - untyped_buffer_output(x, y, c) = cast(untyped_buffer_output_type, untyped_buffer_input(x, y, c)); + untyped_buffer_output(x, y, c) = cast(untyped_buffer_output.output_type(), untyped_buffer_input(x, y, c)); // Gratuitous intermediate for the purpose of exercising // GeneratorParam diff --git a/python_bindings/correctness/pystub.py b/python_bindings/correctness/pystub.py index 79134ce2245e..ae96fbb89c08 100644 --- a/python_bindings/correctness/pystub.py +++ b/python_bindings/correctness/pystub.py @@ -38,21 +38,23 @@ def test_simple(gen): # ----------- Above set again, w/ GeneratorParam mixed in k = 42 + gp = { "offset": k } + # (positional) - f = gen(target, b_in, f_in, 3.5, offset=k) + f = gen(target, b_in, f_in, 3.5, generator_params=gp) _realize_and_check(f, k) # (keyword) - f = gen(target, offset=k, buffer_input=b_in, func_input=f_in, float_arg=3.5) + f = gen(target, generator_params=gp, buffer_input=b_in, func_input=f_in, float_arg=3.5) _realize_and_check(f, k) - f = gen(target, buffer_input=b_in, offset=k, func_input=f_in, float_arg=3.5) + f = gen(target, buffer_input=b_in, generator_params=gp, func_input=f_in, float_arg=3.5) _realize_and_check(f, k) - f = gen(target, buffer_input=b_in, func_input=f_in, offset=k, float_arg=3.5) + f = gen(target, buffer_input=b_in, func_input=f_in, generator_params=gp, float_arg=3.5) _realize_and_check(f, k) - f = gen(target, buffer_input=b_in, float_arg=3.5, func_input=f_in, offset=k) + f = gen(target, buffer_input=b_in, float_arg=3.5, func_input=f_in, generator_params=gp) _realize_and_check(f, k) # ----------- Test various failure modes @@ -104,19 +106,35 @@ def test_simple(gen): else: assert False, 'Did not see expected exception!' + try: + # generator_params is not a dict + f = gen(target, b_in, f_in, 3.5, generator_params=[1, 2, 3]) + except TypeError as e: + assert "cannot convert dictionary" in str(e) + else: + assert False, 'Did not see expected exception!' + + try: + # Bad gp name + f = gen(target, b_in, f_in, 3.5, generator_params={"foo": 0}) + except RuntimeError as e: + assert "has no GeneratorParam named: foo" in str(e) + else: + assert False, 'Did not see expected exception!' + try: # Bad input name - f = gen(target, buffer_input=b_in, float_arg=3.5, offset=k, funk_input=f_in) + f = gen(target, buffer_input=b_in, float_arg=3.5, generator_params=gp, funk_input=f_in) except RuntimeError as e: - assert "Expected exactly 3 keyword args for inputs, but saw 2." in str(e) + assert "Unknown input 'funk_input' specified via keyword argument." in str(e) else: assert False, 'Did not see expected exception!' try: # Bad gp name - f = gen(target, buffer_input=b_in, float_arg=3.5, offset=k, func_input=f_in, nonexistent_generator_param="wat") + f = gen(target, buffer_input=b_in, float_arg=3.5, generator_params=gp, func_input=f_in, nonexistent_generator_param="wat") except RuntimeError as e: - assert "has no GeneratorParam named: nonexistent_generator_param" in str(e) + assert "Unknown input 'nonexistent_generator_param' specified via keyword argument." in str(e) else: assert False, 'Did not see expected exception!' @@ -132,7 +150,9 @@ def test_looplevel(gen): simple_compute_at = hl.LoopLevel() simple = gen(target, buffer_input, func_input, 3.5, - compute_level=simple_compute_at) + generator_params = { + "compute_level": simple_compute_at + }) computed_output = hl.Func('computed_output') computed_output[x, y] = simple[x, y] + 3 @@ -171,9 +191,11 @@ def test_complex(gen): array_input=[ input, input ], float_arg=float_arg, int_arg=[ int_arg, int_arg ], - untyped_buffer_output_type="uint8", extra_func_input=func_input, - vectorize=True) + generator_params = { + "untyped_buffer_output.type": hl.UInt(8), + "vectorize": True + }) # return value is a tuple; unpack separately to avoid # making the callsite above unreadable diff --git a/python_bindings/stub/PyStubImpl.cpp b/python_bindings/stub/PyStubImpl.cpp index c0683ca8e42b..dc35a8535af1 100644 --- a/python_bindings/stub/PyStubImpl.cpp +++ b/python_bindings/stub/PyStubImpl.cpp @@ -114,7 +114,8 @@ py::object generate_impl(const GeneratorFactory &factory, const GeneratorContext // arg, they all must be specified that way; otherwise they must all be // positional, in the order declared in the Generator.) // - // GeneratorParams can only be specified by name, and are always optional. + // GeneratorParams are always specified as an optional named parameter + // called "generator_params", which is expected to be a python dict. std::map> kw_inputs; for (const auto &name : names.inputs) { @@ -127,34 +128,38 @@ py::object generate_impl(const GeneratorFactory &factory, const GeneratorContext // Process the kwargs first. for (auto kw : kwargs) { - // If the kwarg is the name of a known input, stick it in the input - // vector. If not, stick it in the GeneratorParamsMap (if it's invalid, - // an error will be reported further downstream). - std::string name = kw.first.cast(); - py::handle value = kw.second; - auto it = kw_inputs.find(name); - if (it != kw_inputs.end()) { - _halide_user_assert(it->second.empty()) - << "Generator Input named '" << it->first << "' was specified more than once."; - it->second = to_stub_inputs(py::cast(value)); - kw_inputs_specified++; - } else { - if (py::isinstance(value)) { - generator_params[name] = value.cast(); - } else if (py::isinstance(value)) { - // Convert [hl.UInt(8), hl.Int(16)] -> uint8,int16 - std::string v; - for (auto t : value) { - if (!v.empty()) { - v += ","; + const std::string name = kw.first.cast(); + const py::handle value = kw.second; + + if (name == "generator_params") { + py::dict gp = py::cast(value); + for (auto item : gp) { + const std::string gp_name = py::str(item.first).cast(); + const py::handle gp_value = item.second; + if (py::isinstance(gp_value)) { + generator_params[gp_name] = gp_value.cast(); + } else if (py::isinstance(gp_value)) { + // Convert [hl.UInt(8), hl.Int(16)] -> uint8,int16 + std::string v; + for (auto t : gp_value) { + if (!v.empty()) { + v += ","; + } + v += py::str(t).cast(); } - v += py::str(t).cast(); + generator_params[gp_name] = v; + } else { + generator_params[gp_name] = py::str(gp_value).cast(); } - generator_params[name] = v; - } else { - generator_params[name] = py::str(value).cast(); } + continue; } + + auto it = kw_inputs.find(name); + _halide_user_assert(it != kw_inputs.end()) << "Unknown input '" << name << "' specified via keyword argument."; + _halide_user_assert(it->second.empty()) << "Generator Input named '" << it->first << "' was specified more than once."; + it->second = to_stub_inputs(py::cast(value)); + kw_inputs_specified++; } std::vector> inputs; diff --git a/src/Func.cpp b/src/Func.cpp index d68e9f449841..c5bf309f7df7 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -211,7 +211,7 @@ const Type &Func::output_type() const { const std::vector &Func::output_types() const { const auto &types = defined() ? func.output_types() : func.required_types(); user_assert(!types.empty()) - << "Can't call Func::output_type on Func \"" << name() + << "Can't call Func::output_types on Func \"" << name() << "\" because it is undefined or has no type requirements.\n"; return types; } diff --git a/src/Generator.cpp b/src/Generator.cpp index b5c3fff0561b..66c1bf9f3128 100644 --- a/src/Generator.cpp +++ b/src/Generator.cpp @@ -79,6 +79,11 @@ bool is_valid_name(const std::string &n) { return false; } } + // prohibit this specific string so that we can use it for + // passing GeneratorParams in Python. + if (n == "generator_params") { + return false; + } return true; }