diff --git a/apps/fft/fft.cpp b/apps/fft/fft.cpp index 79382129c763..862b3f3e81e5 100644 --- a/apps/fft/fft.cpp +++ b/apps/fft/fft.cpp @@ -107,7 +107,7 @@ ComplexExpr mul(ComplexExpr a, float re_b, float im_b) { // Specializations for some small DFTs of the first dimension of a // Func f. ComplexFunc dft2(ComplexFunc f, const string &prefix) { - Type type = f.output_types()[0]; + Type type = f.types()[0]; ComplexFunc F(prefix + "X2"); F(f.args()) = undef_z(type); @@ -122,7 +122,7 @@ ComplexFunc dft2(ComplexFunc f, const string &prefix) { } ComplexFunc dft4(ComplexFunc f, int sign, const string &prefix) { - Type type = f.output_types()[0]; + Type type = f.types()[0]; ComplexFunc F(prefix + "X4"); F(f.args()) = undef_z(type); @@ -156,7 +156,7 @@ ComplexFunc dft6(ComplexFunc f, int sign, const string &prefix) { ComplexExpr W2_3(re_W1_3, -im_W1_3); ComplexExpr W4_3 = W1_3; - Type type = f.output_types()[0]; + Type type = f.types()[0]; ComplexFunc F(prefix + "X8"); F(f.args()) = undef_z(type); @@ -187,7 +187,7 @@ ComplexFunc dft6(ComplexFunc f, int sign, const string &prefix) { ComplexFunc dft8(ComplexFunc f, int sign, const string &prefix) { const float sqrt2_2 = 0.70710678f; - Type type = f.output_types()[0]; + Type type = f.types()[0]; ComplexFunc F(prefix + "X8"); F(f.args()) = undef_z(type); @@ -346,7 +346,7 @@ ComplexFunc fft_dim1(ComplexFunc x, // The vector width is the least common multiple of the previous vector // width and the natural vector size for this stage. - vector_width = lcm(vector_width, target.natural_vector_size(v.output_types()[0])); + vector_width = lcm(vector_width, target.natural_vector_size(v.types()[0])); // Compute the R point DFT of the subtransform. ComplexFunc V = dft1d_c2c(v, R, sign, prefix); @@ -355,7 +355,7 @@ ComplexFunc fft_dim1(ComplexFunc x, // pass. Since the pure stage is undef, we explicitly generate the // arg list (because we can't use placeholders in an undef // definition). - exchange(A({n0, n1}, args)) = undef_z(V.output_types()[0]); + exchange(A({n0, n1}, args)) = undef_z(V.types()[0]); RDom rs(0, R, 0, N / R); r_ = rs.x; @@ -444,7 +444,7 @@ std::pair tiled_transpose(FuncType f, int max_tile_size, } const int tile_size = - std::min(max_tile_size, target.natural_vector_size(f.output_types()[0])); + std::min(max_tile_size, target.natural_vector_size(f.types()[0])); vector args = f.args(); Var x(args[0]), y(args[1]); @@ -685,7 +685,7 @@ ComplexFunc fft2d_r2c(Func r, int N0 = product(R0); int N1 = product(R1); - const int natural_vector_size = target.natural_vector_size(r.output_types()[0]); + const int natural_vector_size = target.natural_vector_size(r.types()[0]); // If this FFT is small, the logic related to zipping and unzipping // the FFT may be expensive compared to just brute forcing with a complex @@ -705,7 +705,7 @@ ComplexFunc fft2d_r2c(Func r, result(A({n0, n1}, args)) = dft(A({n0, n1}, args)); result.bound(n0, 0, N0); result.bound(n1, 0, (N1 + 1) / 2 + 1); - result.vectorize(n0, std::min(N0, target.natural_vector_size(result.output_types()[0]))); + result.vectorize(n0, std::min(N0, target.natural_vector_size(result.types()[0]))); dft.compute_at(result, outer); return result; } @@ -731,7 +731,7 @@ ComplexFunc fft2d_r2c(Func r, ComplexFunc zipped(prefix + "zipped"); int zip_width = desc.vector_width; if (zip_width <= 0) { - zip_width = target.natural_vector_size(r.output_types()[0]); + zip_width = target.natural_vector_size(r.types()[0]); } // Ensure the zip width divides the zipped extent. zip_width = gcd(zip_width, N0 / 2); @@ -911,7 +911,7 @@ Func fft2d_c2r(ComplexFunc c, // If this FFT is small, the logic related to zipping and unzipping // the FFT may be expensive compared to just brute forcing with a complex // FFT. - const int natural_vector_size = target.natural_vector_size(c.output_types()[0]); + const int natural_vector_size = target.natural_vector_size(c.types()[0]); bool skip_zip = N0 < natural_vector_size * 2; @@ -967,7 +967,7 @@ Func fft2d_c2r(ComplexFunc c, // The vector width of the zipping performed below. int zip_width = desc.vector_width; if (zip_width <= 0) { - zip_width = gcd(target.natural_vector_size(dft0T.output_types()[0]), N1 / 2); + zip_width = gcd(target.natural_vector_size(dft0T.types()[0]), N1 / 2); } // transpose so we can take the DFT of the columns again. diff --git a/python_bindings/correctness/basics.py b/python_bindings/correctness/basics.py index 45996e4465bf..75581378e6d3 100644 --- a/python_bindings/correctness/basics.py +++ b/python_bindings/correctness/basics.py @@ -316,7 +316,7 @@ def test_typed_funcs(): f = hl.Func('f') assert not f.defined() try: - assert f.output_type() == Int(32) + assert f.type() == Int(32) except hl.HalideError as e: assert 'it is undefined' in str(e) else: @@ -339,21 +339,21 @@ def test_typed_funcs(): f = hl.Func(hl.Int(32), 2, 'f') assert not f.defined() - assert f.output_type() == hl.Int(32) - assert f.output_types() == [hl.Int(32)] + assert f.type() == hl.Int(32) + assert f.types() == [hl.Int(32)] assert f.outputs() == 1 assert f.dimensions() == 2 f = hl.Func([hl.Int(32), hl.Float(64)], 3, 'f') assert not f.defined() try: - assert f.output_type() == hl.Int(32) + assert f.type() == hl.Int(32) except hl.HalideError as e: assert 'it returns a Tuple' in str(e) else: assert False, 'Did not see expected exception!' - assert f.output_types() == [hl.Int(32), hl.Float(64)] + assert f.types() == [hl.Int(32), hl.Float(64)] assert f.outputs() == 2 assert f.dimensions() == 3 diff --git a/python_bindings/correctness/generators/complex_generator.cpp b/python_bindings/correctness/generators/complex_generator.cpp index 037bd5d71fc5..f39cc9e4a31b 100644 --- a/python_bindings/correctness/generators/complex_generator.cpp +++ b/python_bindings/correctness/generators/complex_generator.cpp @@ -51,7 +51,7 @@ class Complex : public Halide::Generator { // assert-fail, because there is no type constraint set: the type // will end up as whatever we infer from the values put into it. We'll use an // explicit GeneratorParam to allow us to set it. - untyped_buffer_output(x, y, c) = cast(untyped_buffer_output.output_type(), untyped_buffer_input(x, y, c)); + untyped_buffer_output(x, y, c) = cast(untyped_buffer_output.type(), untyped_buffer_input(x, y, c)); // Gratuitous intermediate for the purpose of exercising // GeneratorParam diff --git a/python_bindings/src/PyFunc.cpp b/python_bindings/src/PyFunc.cpp index bb8e88061fd0..1e90b6e221d9 100644 --- a/python_bindings/src/PyFunc.cpp +++ b/python_bindings/src/PyFunc.cpp @@ -161,8 +161,25 @@ void define_func(py::module &m) { }) .def("defined", &Func::defined) .def("outputs", &Func::outputs) - .def("output_type", &Func::output_type) - .def("output_types", &Func::output_types) + + .def("output_type", [](Func &f) { + // HALIDE_ATTRIBUTE_DEPRECATED("Func::output_type() is deprecated; call Func::type() instead.") + PyErr_WarnEx(PyExc_DeprecationWarning, + "Func.output_type() is deprecated; use Func.type() instead.", + 1); + return f.type(); + }) + + .def("output_types", [](Func &f) { + // HALIDE_ATTRIBUTE_DEPRECATED("Func::output_types() is deprecated; call Func::types() instead.") + PyErr_WarnEx(PyExc_DeprecationWarning, + "Func.output_types() is deprecated; use Func.types() instead.", + 1); + return f.types(); + }) + + .def("type", &Func::type) + .def("types", &Func::types) .def("bound", &Func::bound, py::arg("var"), py::arg("min"), py::arg("extent")) diff --git a/python_bindings/tutorial/lesson_14_types.py b/python_bindings/tutorial/lesson_14_types.py index 927995ff3694..b61a7b90ddf1 100644 --- a/python_bindings/tutorial/lesson_14_types.py +++ b/python_bindings/tutorial/lesson_14_types.py @@ -64,11 +64,11 @@ def main(): # You can also query any defined hl.Func for the types it produces. f1 = hl.Func("f1") f1[x] = hl.cast(hl.UInt(8), x) - assert f1.output_types()[0] == hl.UInt(8) + assert f1.types()[0] == hl.UInt(8) f2 = hl.Func("f2") f2[x] = (x, hl.sin(x)) - assert f2.output_types()[0] == hl.Int(32) and f2.output_types()[1] == hl.Float(32) + assert f2.types()[0] == hl.Int(32) and f2.types()[1] == hl.Float(32) # Type promotion rules. if True: diff --git a/src/Func.cpp b/src/Func.cpp index c5bf309f7df7..f29125adfce5 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -196,22 +196,22 @@ void Func::define_extern(const std::string &function_name, } /** Get the types of the buffers returned by an extern definition. */ -const Type &Func::output_type() const { +const Type &Func::type() const { const auto &types = defined() ? func.output_types() : func.required_types(); if (types.empty()) { - user_error << "Can't call Func::output_type on Func \"" << name() + user_error << "Can't call Func::type on Func \"" << name() << "\" because it is undefined or has no type requirements.\n"; } else if (types.size() > 1) { - user_error << "Can't call Func::output_type on Func \"" << name() + user_error << "Can't call Func::type on Func \"" << name() << "\" because it returns a Tuple.\n"; } return types[0]; } -const std::vector &Func::output_types() const { +const std::vector &Func::types() const { const auto &types = defined() ? func.output_types() : func.required_types(); user_assert(!types.empty()) - << "Can't call Func::output_types on Func \"" << name() + << "Can't call Func::types on Func \"" << name() << "\" because it is undefined or has no type requirements.\n"; return types; } @@ -981,7 +981,7 @@ Func Stage::rfactor(vector> preserved) { } if (!prover_result.xs[i].var.empty()) { - Expr prev_val = Call::make(intm.output_types()[i], func_name, + Expr prev_val = Call::make(intm.types()[i], func_name, f_store_args, Call::CallType::Halide, FunctionPtr(), i); replacements.emplace(prover_result.xs[i].var, prev_val); diff --git a/src/Func.h b/src/Func.h index 039a76b3ba2d..19c5d4c1ab1e 100644 --- a/src/Func.h +++ b/src/Func.h @@ -1204,13 +1204,28 @@ class Func { // @} /** Get the type(s) of the outputs of this Func. - * If the Func isn't yet defined, but was specified with required types, + * + * It is not legal to call type() unless the Func has non-Tuple elements. + * + * If the Func isn't yet defined, and was not specified with required types, + * a runtime error will occur. + * + * If the Func isn't yet defined, but *was* specified with required types, * the requirements will be returned. */ // @{ - const Type &output_type() const; - const std::vector &output_types() const; + const Type &type() const; + const std::vector &types() const; // @} + HALIDE_ATTRIBUTE_DEPRECATED("Func::output_type() is deprecated; use Func::type() instead.") + const Type &output_type() const { + return type(); + } + HALIDE_ATTRIBUTE_DEPRECATED("Func::output_types() is deprecated; use Func::types() instead.") + const std::vector &output_types() const { + return types(); + } + /** Get the number of outputs of this Func. Corresponds to the * size of the Tuple this Func was defined to return. * If the Func isn't yet defined, but was specified with required types, diff --git a/src/Generator.cpp b/src/Generator.cpp index 4dbdd73a72ab..eaca8a38578e 100644 --- a/src/Generator.cpp +++ b/src/Generator.cpp @@ -1301,7 +1301,7 @@ GeneratorParamInfo::GeneratorParamInfo(GeneratorBase *generator, const size_t si const std::string &n = gio->name(); const std::string &gn = generator->generator_registered_name; - owned_synthetic_params.push_back(GeneratorParam_Synthetic::make(generator, gn, n + ".type", *gio, SyntheticParamType::Type, gio->types_defined())); + owned_synthetic_params.push_back(GeneratorParam_Synthetic::make(generator, gn, n + ".type", *gio, SyntheticParamType::Type, gio->gio_types_defined())); filter_generator_params.push_back(owned_synthetic_params.back().get()); if (gio->kind() != IOKind::Scalar) { @@ -1569,13 +1569,13 @@ Pipeline GeneratorBase::get_pipeline() { << "\" requires dimensions=" << output->dims() << " but was defined as dimensions=" << f.dimensions() << ".\n"; } - if (output->types_defined()) { - user_assert((int)f.outputs() == (int)output->types().size()) << "Output \"" << f.name() - << "\" requires a Tuple of size " << output->types().size() - << " but was defined as Tuple of size " << f.outputs() << ".\n"; - for (size_t i = 0; i < f.output_types().size(); ++i) { - Type expected = output->types().at(i); - Type actual = f.output_types()[i]; + if (output->gio_types_defined()) { + user_assert((int)f.outputs() == (int)output->gio_types().size()) << "Output \"" << f.name() + << "\" requires a Tuple of size " << output->gio_types().size() + << " but was defined as Tuple of size " << f.outputs() << ".\n"; + for (size_t i = 0; i < f.types().size(); ++i) { + Type expected = output->gio_types().at(i); + Type actual = f.types()[i]; user_assert(expected == actual) << "Output \"" << f.name() << "\" requires type " << expected << " but was defined as type " << actual << ".\n"; @@ -1616,7 +1616,7 @@ Module GeneratorBase::build_module(const std::string &function_name, for (size_t i = 0; i < output->funcs().size(); ++i) { auto from = output->funcs()[i].name(); auto to = output->array_name(i); - size_t tuple_size = output->types_defined() ? output->types().size() : 1; + size_t tuple_size = output->gio_types_defined() ? output->gio_types().size() : 1; for (size_t t = 0; t < tuple_size; ++t) { std::string suffix = (tuple_size > 1) ? ("." + std::to_string(t)) : ""; result.remap_metadata_name(from + suffix, to + suffix); @@ -1660,12 +1660,12 @@ Module GeneratorBase::build_gradient_module(const std::string &function_name) { // support for Tupled outputs could be added with some effort, so if this // is somehow deemed critical, go for it) for (const auto *input : pi.inputs()) { - const size_t tuple_size = input->types_defined() ? input->types().size() : 1; + const size_t tuple_size = input->gio_types_defined() ? input->gio_types().size() : 1; // Note: this should never happen internal_assert(tuple_size == 1) << "Tuple Inputs are not yet supported by build_gradient_module()"; } for (const auto *output : pi.outputs()) { - const size_t tuple_size = output->types_defined() ? output->types().size() : 1; + const size_t tuple_size = output->gio_types_defined() ? output->gio_types().size() : 1; internal_assert(tuple_size == 1) << "Tuple Outputs are not yet supported by build_gradient_module"; } @@ -1699,7 +1699,7 @@ Module GeneratorBase::build_gradient_module(const std::string &function_name) { const std::string grad_in_name = replace_all(grad_input_pattern, "$OUT$", output_name); // TODO(srj): does it make sense for gradient to be a non-float type? // For now, assume it's always float32 (unless the output is already some float). - const Type grad_in_type = output->type().is_float() ? output->type() : Float(32); + const Type grad_in_type = output->gio_type().is_float() ? output->gio_type() : Float(32); const int grad_in_dimensions = f.dimensions(); const ArgumentEstimates grad_in_estimates = f.output_buffer().parameter().get_argument_estimates(); internal_assert((int)grad_in_estimates.buffer_estimates.size() == grad_in_dimensions); @@ -1856,34 +1856,34 @@ IOKind GIOBase::kind() const { return kind_; } -bool GIOBase::types_defined() const { +bool GIOBase::gio_types_defined() const { return !types_.empty(); } -const std::vector &GIOBase::types() const { +const std::vector &GIOBase::gio_types() const { // If types aren't defined, but we have one Func that is, // we probably just set an Output and should propagate the types. - if (!types_defined()) { + if (!gio_types_defined()) { // use funcs_, not funcs(): the latter could give a much-less-helpful error message // in this case. const auto &f = funcs_; if (f.size() == 1 && f.at(0).defined()) { - check_matching_types(f.at(0).output_types()); + check_matching_types(f.at(0).types()); } } - user_assert(types_defined()) << "Type is not defined for " << input_or_output() << " '" << name() << "'; you may need to specify '" << name() << ".type' as a GeneratorParam, or call set_type() from the configure() method.\n"; + user_assert(gio_types_defined()) << "Type is not defined for " << input_or_output() << " '" << name() << "'; you may need to specify '" << name() << ".type' as a GeneratorParam, or call set_type() from the configure() method.\n"; return types_; } -Type GIOBase::type() const { - const auto &t = types(); +Type GIOBase::gio_type() const { + const auto &t = gio_types(); internal_assert(t.size() == 1) << "Expected types_.size() == 1, saw " << t.size() << " for " << name() << "\n"; return t.at(0); } void GIOBase::set_type(const Type &type) { generator->check_exact_phase(GeneratorBase::ConfigureCalled); - user_assert(!types_defined()) << "set_type() may only be called on an Input or Output that has no type specified."; + user_assert(!gio_types_defined()) << "set_type() may only be called on an Input or Output that has no type specified."; types_ = {type}; } @@ -1942,20 +1942,20 @@ void GIOBase::verify_internals() { << "Expected outputs() == " << 1 << " but got " << f.outputs() << " for " << name() << "\n"; - user_assert(f.output_types().size() == 1) - << "Expected output_types().size() == " << 1 + user_assert(f.types().size() == 1) + << "Expected types().size() == " << 1 << " but got " << f.outputs() << " for " << name() << "\n"; - user_assert(f.output_types()[0] == type()) - << "Expected type " << type() - << " but got " << f.output_types()[0] + user_assert(f.types()[0] == gio_type()) + << "Expected type " << gio_type() + << " but got " << f.types()[0] << " for " << name() << "\n"; } } else { for (const Expr &e : exprs()) { user_assert(e.defined()) << "Input/Ouput " << name() << " is not defined.\n"; - user_assert(e.type() == type()) - << "Expected type " << type() + user_assert(e.type() == gio_type()) + << "Expected type " << gio_type() << " but got " << e.type() << " for " << name() << "\n"; } @@ -1973,10 +1973,10 @@ std::string GIOBase::array_name(size_t i) const { // If our type(s) are defined, ensure it matches the ones passed in, asserting if not. // If our type(s) are not defined, just set to the ones passed in. void GIOBase::check_matching_types(const std::vector &t) const { - if (types_defined()) { - user_assert(types().size() == t.size()) << "Type mismatch for " << name() << ": expected " << types().size() << " types but saw " << t.size(); + if (gio_types_defined()) { + user_assert(gio_types().size() == t.size()) << "Type mismatch for " << name() << ": expected " << gio_types().size() << " types but saw " << t.size(); for (size_t i = 0; i < t.size(); ++i) { - user_assert(types().at(i) == t.at(i)) << "Type mismatch for " << name() << ": expected " << types().at(i) << " saw " << t.at(i); + user_assert(gio_types().at(i) == t.at(i)) << "Type mismatch for " << name() << ": expected " << gio_types().at(i) << " saw " << t.at(i); } } else { types_ = t; @@ -2054,7 +2054,7 @@ void GeneratorInputBase::verify_internals() { void GeneratorInputBase::init_internals() { // Call these for the side-effect of asserting if the values aren't defined. (void)array_size(); - (void)types(); + (void)gio_types(); (void)dims(); parameters_.clear(); @@ -2062,13 +2062,13 @@ void GeneratorInputBase::init_internals() { funcs_.clear(); for (size_t i = 0; i < array_size(); ++i) { auto name = array_name(i); - parameters_.emplace_back(type(), kind() != IOKind::Scalar, dims(), name); + parameters_.emplace_back(gio_type(), kind() != IOKind::Scalar, dims(), name); auto &p = parameters_[i]; if (kind() != IOKind::Scalar) { internal_assert(dims() == p.dimensions()); funcs_.push_back(make_param_func(p, name)); } else { - Expr e = Internal::Variable::make(type(), name, p); + Expr e = Internal::Variable::make(gio_type(), name, p); exprs_.push_back(e); } } @@ -2089,10 +2089,10 @@ void GeneratorInputBase::set_inputs(const std::vector &inputs) { if (kind() == IOKind::Function) { auto f = in.func(); user_assert(f.defined()) << "The input for " << name() << " is an undefined Func. Please define it.\n"; - check_matching_types(f.output_types()); + check_matching_types(f.types()); check_matching_dims(f.dimensions()); funcs_.push_back(f); - parameters_.emplace_back(f.output_types().at(0), true, f.dimensions(), array_name(i)); + parameters_.emplace_back(f.types().at(0), true, f.dimensions(), array_name(i)); } else if (kind() == IOKind::Buffer) { auto p = in.parameter(); user_assert(p.defined()) << "The input for " << name() << " is an undefined Buffer. Please define it.\n"; @@ -2177,7 +2177,7 @@ void GeneratorOutputBase::init_internals() { exprs_.clear(); funcs_.clear(); if (array_size_defined()) { - const auto t = types_defined() ? types() : std::vector{}; + const auto t = gio_types_defined() ? gio_types() : std::vector{}; const int d = dims_defined() ? dims() : -1; for (size_t i = 0; i < array_size(); ++i) { funcs_.emplace_back(t, d, array_name(i)); diff --git a/src/Generator.h b/src/Generator.h index cdb8e8416fe0..6ab342219765 100644 --- a/src/Generator.h +++ b/src/Generator.h @@ -1427,6 +1427,15 @@ class StubInput { */ class GIOBase { public: + virtual ~GIOBase() = default; + + // These should only be called from configure() methods. + // TODO: find a way to enforce this. Better yet, find a way to remove these. + void set_type(const Type &type); + void set_dimensions(int dims); + void set_array_size(int size); + +protected: bool array_size_defined() const; size_t array_size() const; virtual bool is_array() const; @@ -1434,9 +1443,9 @@ class GIOBase { const std::string &name() const; IOKind kind() const; - bool types_defined() const; - const std::vector &types() const; - Type type() const; + bool gio_types_defined() const; + const std::vector &gio_types() const; + Type gio_type() const; bool dims_defined() const; int dims() const; @@ -1444,13 +1453,6 @@ class GIOBase { const std::vector &funcs() const; const std::vector &exprs() const; - virtual ~GIOBase() = default; - - void set_type(const Type &type); - void set_dimensions(int dims); - void set_array_size(int size); - -protected: GIOBase(size_t array_size, const std::string &name, IOKind kind, @@ -1499,6 +1501,7 @@ class GIOBase { private: template friend class GeneratorParam_Synthetic; + friend class GeneratorStub; public: GIOBase(const GIOBase &) = delete; @@ -1800,6 +1803,7 @@ class GeneratorInput_Buffer : public GeneratorInputImpl { HALIDE_FORWARD_METHOD_CONST(ImageParam, channels) HALIDE_FORWARD_METHOD_CONST(ImageParam, trace_loads) HALIDE_FORWARD_METHOD_CONST(ImageParam, add_trace_tag) + HALIDE_FORWARD_METHOD_CONST(ImageParam, type) // }@ }; @@ -1912,12 +1916,23 @@ class GeneratorInput_Func : public GeneratorInputImpl { // @{ HALIDE_FORWARD_METHOD_CONST(Func, args) HALIDE_FORWARD_METHOD_CONST(Func, defined) + HALIDE_FORWARD_METHOD_CONST(Func, dimensions) HALIDE_FORWARD_METHOD_CONST(Func, has_update_definition) HALIDE_FORWARD_METHOD_CONST(Func, num_update_definitions) - HALIDE_FORWARD_METHOD_CONST(Func, output_type) - HALIDE_FORWARD_METHOD_CONST(Func, output_types) + HALIDE_ATTRIBUTE_DEPRECATED("Func::output_type() is deprecated; use Func::type() instead.") + const Type &output_type() const { + this->check_gio_access(); + return this->as().type(); + } + HALIDE_ATTRIBUTE_DEPRECATED("Func::output_types() is deprecated; use Func::types() instead.") + const std::vector &output_types() const { + this->check_gio_access(); + return this->as().types(); + } HALIDE_FORWARD_METHOD_CONST(Func, outputs) HALIDE_FORWARD_METHOD_CONST(Func, rvars) + HALIDE_FORWARD_METHOD_CONST(Func, type) + HALIDE_FORWARD_METHOD_CONST(Func, types) HALIDE_FORWARD_METHOD_CONST(Func, update_args) HALIDE_FORWARD_METHOD_CONST(Func, update_value) HALIDE_FORWARD_METHOD_CONST(Func, update_values) @@ -1964,6 +1979,10 @@ class GeneratorInput_DynamicScalar : public GeneratorInputImpl { p.set_estimate(value); } } + + Type type() const { + return Expr(*this).type(); + } }; template @@ -2066,6 +2085,10 @@ class GeneratorInput_Scalar : public GeneratorInputImpl { } this->parameters_.at(index).set_estimate(e); } + + Type type() const { + return Expr(*this).type(); + } }; template @@ -2235,6 +2258,7 @@ class GeneratorOutputBase : public GIOBase { static_assert(std::is_same::value, "Only Func allowed here"); internal_assert(kind() != IOKind::Scalar); internal_assert(exprs_.empty()); + user_assert(!funcs_.empty()) << "No funcs_ are defined yet"; user_assert(funcs_.size() == 1) << "Use [] to access individual Funcs in Output"; return funcs_[0]; } @@ -2257,6 +2281,7 @@ class GeneratorOutputBase : public GIOBase { HALIDE_FORWARD_METHOD(Func, copy_to_host) HALIDE_FORWARD_METHOD(Func, define_extern) HALIDE_FORWARD_METHOD_CONST(Func, defined) + HALIDE_FORWARD_METHOD_CONST(Func, dimensions) HALIDE_FORWARD_METHOD(Func, fold_storage) HALIDE_FORWARD_METHOD(Func, fuse) HALIDE_FORWARD_METHOD(Func, gpu) @@ -2269,8 +2294,16 @@ class GeneratorOutputBase : public GIOBase { HALIDE_FORWARD_METHOD(Func, in) HALIDE_FORWARD_METHOD(Func, memoize) HALIDE_FORWARD_METHOD_CONST(Func, num_update_definitions) - HALIDE_FORWARD_METHOD_CONST(Func, output_type) - HALIDE_FORWARD_METHOD_CONST(Func, output_types) + HALIDE_ATTRIBUTE_DEPRECATED("Func::output_type() is deprecated; use Func::type() instead.") + const Type &output_type() const { + this->check_gio_access(); + return this->as().type(); + } + HALIDE_ATTRIBUTE_DEPRECATED("Func::output_types() is deprecated; use Func::types() instead.") + const std::vector &output_types() const { + this->check_gio_access(); + return this->as().types(); + } HALIDE_FORWARD_METHOD_CONST(Func, outputs) HALIDE_FORWARD_METHOD(Func, parallel) HALIDE_FORWARD_METHOD(Func, prefetch) @@ -2288,6 +2321,8 @@ class GeneratorOutputBase : public GIOBase { HALIDE_FORWARD_METHOD(Func, store_root) HALIDE_FORWARD_METHOD(Func, tile) HALIDE_FORWARD_METHOD(Func, trace_stores) + HALIDE_FORWARD_METHOD_CONST(Func, type) + HALIDE_FORWARD_METHOD_CONST(Func, types) HALIDE_FORWARD_METHOD(Func, unroll) HALIDE_FORWARD_METHOD(Func, update) HALIDE_FORWARD_METHOD_CONST(Func, update_args) @@ -2296,6 +2331,7 @@ class GeneratorOutputBase : public GIOBase { HALIDE_FORWARD_METHOD_CONST(Func, value) HALIDE_FORWARD_METHOD_CONST(Func, values) HALIDE_FORWARD_METHOD(Func, vectorize) + // }@ #undef HALIDE_OUTPUT_FORWARD @@ -2438,24 +2474,24 @@ class GeneratorOutput_Buffer : public GeneratorOutputImpl { internal_assert(f.defined()); - if (this->types_defined()) { - const auto &my_types = this->types(); - user_assert(my_types.size() == f.output_types().size()) + if (this->gio_types_defined()) { + const auto &my_types = this->gio_types(); + user_assert(my_types.size() == f.types().size()) << "Cannot assign Func \"" << f.name() << "\" to Output \"" << this->name() << "\"\n" << "Output " << this->name() << " is declared to have " << my_types.size() << " tuple elements" << " but Func " << f.name() - << " has " << f.output_types().size() << " tuple elements.\n"; + << " has " << f.types().size() << " tuple elements.\n"; for (size_t i = 0; i < my_types.size(); i++) { - user_assert(my_types[i] == f.output_types().at(i)) + user_assert(my_types[i] == f.types().at(i)) << "Cannot assign Func \"" << f.name() << "\" to Output \"" << this->name() << "\"\n" << (my_types.size() > 1 ? "In tuple element " + std::to_string(i) + ", " : "") << "Output " << this->name() << " has declared type " << my_types[i] << " but Func " << f.name() - << " has type " << f.output_types().at(i) << "\n"; + << " has type " << f.types().at(i) << "\n"; } } if (this->dims_defined()) { @@ -2563,9 +2599,9 @@ class GeneratorOutput_Buffer : public GeneratorOutputImpl { << "Cannot assign to the Output \"" << this->name() << "\": the expression is not convertible to the same Buffer type and/or dimensions.\n"; - if (this->types_defined()) { - user_assert(Type(buffer.type()) == this->type()) - << "Output " << this->name() << " should have type=" << this->type() << " but saw type=" << Type(buffer.type()) << "\n"; + if (this->gio_types_defined()) { + user_assert(Type(buffer.type()) == this->gio_type()) + << "Output " << this->name() << " should have type=" << this->gio_type() << " but saw type=" << Type(buffer.type()) << "\n"; } if (this->dims_defined()) { user_assert(buffer.dimensions() == this->dims()) diff --git a/test/correctness/typed_func.cpp b/test/correctness/typed_func.cpp index cd55ac5cdf6a..8eeca0829e3e 100644 --- a/test/correctness/typed_func.cpp +++ b/test/correctness/typed_func.cpp @@ -12,7 +12,7 @@ int main(int argc, char **argv) { assert(!f.defined()); // undefined funcs assert-fail for these calls. // but return 0 for outputs() and dimensions(). - // assert(f.output_type() == Int(32)); + // assert(f.type() == Int(32)); // assert(f.outputs() == 0); // assert(f.dimensions() == 0); } @@ -24,8 +24,8 @@ int main(int argc, char **argv) { assert(!f.defined()); const std::vector expected = {Int(32)}; - assert(f.output_type() == expected[0]); - assert(f.output_types() == expected); + assert(f.type() == expected[0]); + assert(f.types() == expected); assert(f.outputs() == 1); assert(f.dimensions() == 2); } @@ -36,8 +36,8 @@ int main(int argc, char **argv) { const std::vector expected = {Int(32), Float(64)}; assert(!f.defined()); - // assert(f.output_type() == expected[0]); // will assert-fail - assert(f.output_types() == expected); + // assert(f.type() == expected[0]); // will assert-fail + assert(f.types() == expected); assert(f.outputs() == 2); assert(f.dimensions() == 3); } diff --git a/test/generator/stubtest_generator.cpp b/test/generator/stubtest_generator.cpp index 08d7a6e6751e..f8889162ba82 100644 --- a/test/generator/stubtest_generator.cpp +++ b/test/generator/stubtest_generator.cpp @@ -78,7 +78,7 @@ class StubTest : public Halide::Generator { // Verify that Output::type() and ::dims() are well-defined after we define the Func assert(tuple_output.types()[0] == Float(32)); assert(tuple_output.types()[1] == Float(32)); - assert(tuple_output.dims() == 3); + assert(tuple_output.dimensions() == 3); array_output.resize(array_input.size()); for (size_t i = 0; i < array_input.size(); ++i) { diff --git a/tutorial/lesson_15_generators.cpp b/tutorial/lesson_15_generators.cpp index 969158b4926b..eee00ed05e60 100644 --- a/tutorial/lesson_15_generators.cpp +++ b/tutorial/lesson_15_generators.cpp @@ -155,7 +155,7 @@ class MySecondGenerator : public Halide::Generator { if (rotation != Rotation::None) { rotated .compute_at(output, y) - .vectorize(x, natural_vector_size(rotated.output_types()[0])); + .vectorize(x, natural_vector_size(rotated.types()[0])); } } };