Skip to content

Commit

Permalink
Add Func::type()/types(), deprecate Func::output_type()/output_types() (
Browse files Browse the repository at this point in the history
halide#6772)

* rename GIOBase::type() and friends

* Func::output_type() -> Func::type()

* Add type() forwarders for inputs

* Add Func::dimensions() wrapper

* Update Func.h
  • Loading branch information
steven-johnson authored and ardier committed Mar 3, 2024
1 parent d6d28aa commit 10d5410
Show file tree
Hide file tree
Showing 12 changed files with 165 additions and 97 deletions.
24 changes: 12 additions & 12 deletions apps/fft/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ ComplexExpr mul(ComplexExpr a, float re_b, float im_b) {
// Specializations for some small DFTs of the first dimension of a
// Func f.
ComplexFunc dft2(ComplexFunc f, const string &prefix) {
Type type = f.output_types()[0];
Type type = f.types()[0];

ComplexFunc F(prefix + "X2");
F(f.args()) = undef_z(type);
Expand All @@ -122,7 +122,7 @@ ComplexFunc dft2(ComplexFunc f, const string &prefix) {
}

ComplexFunc dft4(ComplexFunc f, int sign, const string &prefix) {
Type type = f.output_types()[0];
Type type = f.types()[0];

ComplexFunc F(prefix + "X4");
F(f.args()) = undef_z(type);
Expand Down Expand Up @@ -156,7 +156,7 @@ ComplexFunc dft6(ComplexFunc f, int sign, const string &prefix) {
ComplexExpr W2_3(re_W1_3, -im_W1_3);
ComplexExpr W4_3 = W1_3;

Type type = f.output_types()[0];
Type type = f.types()[0];

ComplexFunc F(prefix + "X8");
F(f.args()) = undef_z(type);
Expand Down Expand Up @@ -187,7 +187,7 @@ ComplexFunc dft6(ComplexFunc f, int sign, const string &prefix) {
ComplexFunc dft8(ComplexFunc f, int sign, const string &prefix) {
const float sqrt2_2 = 0.70710678f;

Type type = f.output_types()[0];
Type type = f.types()[0];

ComplexFunc F(prefix + "X8");
F(f.args()) = undef_z(type);
Expand Down Expand Up @@ -346,7 +346,7 @@ ComplexFunc fft_dim1(ComplexFunc x,

// The vector width is the least common multiple of the previous vector
// width and the natural vector size for this stage.
vector_width = lcm(vector_width, target.natural_vector_size(v.output_types()[0]));
vector_width = lcm(vector_width, target.natural_vector_size(v.types()[0]));

// Compute the R point DFT of the subtransform.
ComplexFunc V = dft1d_c2c(v, R, sign, prefix);
Expand All @@ -355,7 +355,7 @@ ComplexFunc fft_dim1(ComplexFunc x,
// pass. Since the pure stage is undef, we explicitly generate the
// arg list (because we can't use placeholders in an undef
// definition).
exchange(A({n0, n1}, args)) = undef_z(V.output_types()[0]);
exchange(A({n0, n1}, args)) = undef_z(V.types()[0]);

RDom rs(0, R, 0, N / R);
r_ = rs.x;
Expand Down Expand Up @@ -444,7 +444,7 @@ std::pair<FuncType, FuncType> tiled_transpose(FuncType f, int max_tile_size,
}

const int tile_size =
std::min(max_tile_size, target.natural_vector_size(f.output_types()[0]));
std::min(max_tile_size, target.natural_vector_size(f.types()[0]));

vector<Var> args = f.args();
Var x(args[0]), y(args[1]);
Expand Down Expand Up @@ -685,7 +685,7 @@ ComplexFunc fft2d_r2c(Func r,
int N0 = product(R0);
int N1 = product(R1);

const int natural_vector_size = target.natural_vector_size(r.output_types()[0]);
const int natural_vector_size = target.natural_vector_size(r.types()[0]);

// If this FFT is small, the logic related to zipping and unzipping
// the FFT may be expensive compared to just brute forcing with a complex
Expand All @@ -705,7 +705,7 @@ ComplexFunc fft2d_r2c(Func r,
result(A({n0, n1}, args)) = dft(A({n0, n1}, args));
result.bound(n0, 0, N0);
result.bound(n1, 0, (N1 + 1) / 2 + 1);
result.vectorize(n0, std::min(N0, target.natural_vector_size(result.output_types()[0])));
result.vectorize(n0, std::min(N0, target.natural_vector_size(result.types()[0])));
dft.compute_at(result, outer);
return result;
}
Expand All @@ -731,7 +731,7 @@ ComplexFunc fft2d_r2c(Func r,
ComplexFunc zipped(prefix + "zipped");
int zip_width = desc.vector_width;
if (zip_width <= 0) {
zip_width = target.natural_vector_size(r.output_types()[0]);
zip_width = target.natural_vector_size(r.types()[0]);
}
// Ensure the zip width divides the zipped extent.
zip_width = gcd(zip_width, N0 / 2);
Expand Down Expand Up @@ -911,7 +911,7 @@ Func fft2d_c2r(ComplexFunc c,
// If this FFT is small, the logic related to zipping and unzipping
// the FFT may be expensive compared to just brute forcing with a complex
// FFT.
const int natural_vector_size = target.natural_vector_size(c.output_types()[0]);
const int natural_vector_size = target.natural_vector_size(c.types()[0]);

bool skip_zip = N0 < natural_vector_size * 2;

Expand Down Expand Up @@ -967,7 +967,7 @@ Func fft2d_c2r(ComplexFunc c,
// The vector width of the zipping performed below.
int zip_width = desc.vector_width;
if (zip_width <= 0) {
zip_width = gcd(target.natural_vector_size(dft0T.output_types()[0]), N1 / 2);
zip_width = gcd(target.natural_vector_size(dft0T.types()[0]), N1 / 2);
}

// transpose so we can take the DFT of the columns again.
Expand Down
10 changes: 5 additions & 5 deletions python_bindings/correctness/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def test_typed_funcs():
f = hl.Func('f')
assert not f.defined()
try:
assert f.output_type() == Int(32)
assert f.type() == Int(32)
except hl.HalideError as e:
assert 'it is undefined' in str(e)
else:
Expand All @@ -339,21 +339,21 @@ def test_typed_funcs():

f = hl.Func(hl.Int(32), 2, 'f')
assert not f.defined()
assert f.output_type() == hl.Int(32)
assert f.output_types() == [hl.Int(32)]
assert f.type() == hl.Int(32)
assert f.types() == [hl.Int(32)]
assert f.outputs() == 1
assert f.dimensions() == 2

f = hl.Func([hl.Int(32), hl.Float(64)], 3, 'f')
assert not f.defined()
try:
assert f.output_type() == hl.Int(32)
assert f.type() == hl.Int(32)
except hl.HalideError as e:
assert 'it returns a Tuple' in str(e)
else:
assert False, 'Did not see expected exception!'

assert f.output_types() == [hl.Int(32), hl.Float(64)]
assert f.types() == [hl.Int(32), hl.Float(64)]
assert f.outputs() == 2
assert f.dimensions() == 3

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Complex : public Halide::Generator<Complex> {
// assert-fail, because there is no type constraint set: the type
// will end up as whatever we infer from the values put into it. We'll use an
// explicit GeneratorParam to allow us to set it.
untyped_buffer_output(x, y, c) = cast(untyped_buffer_output.output_type(), untyped_buffer_input(x, y, c));
untyped_buffer_output(x, y, c) = cast(untyped_buffer_output.type(), untyped_buffer_input(x, y, c));

// Gratuitous intermediate for the purpose of exercising
// GeneratorParam<LoopLevel>
Expand Down
21 changes: 19 additions & 2 deletions python_bindings/src/PyFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,25 @@ void define_func(py::module &m) {
})
.def("defined", &Func::defined)
.def("outputs", &Func::outputs)
.def("output_type", &Func::output_type)
.def("output_types", &Func::output_types)

.def("output_type", [](Func &f) {
// HALIDE_ATTRIBUTE_DEPRECATED("Func::output_type() is deprecated; call Func::type() instead.")
PyErr_WarnEx(PyExc_DeprecationWarning,
"Func.output_type() is deprecated; use Func.type() instead.",
1);
return f.type();
})

.def("output_types", [](Func &f) {
// HALIDE_ATTRIBUTE_DEPRECATED("Func::output_types() is deprecated; call Func::types() instead.")
PyErr_WarnEx(PyExc_DeprecationWarning,
"Func.output_types() is deprecated; use Func.types() instead.",
1);
return f.types();
})

.def("type", &Func::type)
.def("types", &Func::types)

.def("bound", &Func::bound, py::arg("var"), py::arg("min"), py::arg("extent"))

Expand Down
4 changes: 2 additions & 2 deletions python_bindings/tutorial/lesson_14_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ def main():
# You can also query any defined hl.Func for the types it produces.
f1 = hl.Func("f1")
f1[x] = hl.cast(hl.UInt(8), x)
assert f1.output_types()[0] == hl.UInt(8)
assert f1.types()[0] == hl.UInt(8)

f2 = hl.Func("f2")
f2[x] = (x, hl.sin(x))
assert f2.output_types()[0] == hl.Int(32) and f2.output_types()[1] == hl.Float(32)
assert f2.types()[0] == hl.Int(32) and f2.types()[1] == hl.Float(32)

# Type promotion rules.
if True:
Expand Down
12 changes: 6 additions & 6 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,22 +196,22 @@ void Func::define_extern(const std::string &function_name,
}

/** Get the types of the buffers returned by an extern definition. */
const Type &Func::output_type() const {
const Type &Func::type() const {
const auto &types = defined() ? func.output_types() : func.required_types();
if (types.empty()) {
user_error << "Can't call Func::output_type on Func \"" << name()
user_error << "Can't call Func::type on Func \"" << name()
<< "\" because it is undefined or has no type requirements.\n";
} else if (types.size() > 1) {
user_error << "Can't call Func::output_type on Func \"" << name()
user_error << "Can't call Func::type on Func \"" << name()
<< "\" because it returns a Tuple.\n";
}
return types[0];
}

const std::vector<Type> &Func::output_types() const {
const std::vector<Type> &Func::types() const {
const auto &types = defined() ? func.output_types() : func.required_types();
user_assert(!types.empty())
<< "Can't call Func::output_types on Func \"" << name()
<< "Can't call Func::types on Func \"" << name()
<< "\" because it is undefined or has no type requirements.\n";
return types;
}
Expand Down Expand Up @@ -981,7 +981,7 @@ Func Stage::rfactor(vector<pair<RVar, Var>> preserved) {
}

if (!prover_result.xs[i].var.empty()) {
Expr prev_val = Call::make(intm.output_types()[i], func_name,
Expr prev_val = Call::make(intm.types()[i], func_name,
f_store_args, Call::CallType::Halide,
FunctionPtr(), i);
replacements.emplace(prover_result.xs[i].var, prev_val);
Expand Down
21 changes: 18 additions & 3 deletions src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -1204,13 +1204,28 @@ class Func {
// @}

/** Get the type(s) of the outputs of this Func.
* If the Func isn't yet defined, but was specified with required types,
*
* It is not legal to call type() unless the Func has non-Tuple elements.
*
* If the Func isn't yet defined, and was not specified with required types,
* a runtime error will occur.
*
* If the Func isn't yet defined, but *was* specified with required types,
* the requirements will be returned. */
// @{
const Type &output_type() const;
const std::vector<Type> &output_types() const;
const Type &type() const;
const std::vector<Type> &types() const;
// @}

HALIDE_ATTRIBUTE_DEPRECATED("Func::output_type() is deprecated; use Func::type() instead.")
const Type &output_type() const {
return type();
}
HALIDE_ATTRIBUTE_DEPRECATED("Func::output_types() is deprecated; use Func::types() instead.")
const std::vector<Type> &output_types() const {
return types();
}

/** Get the number of outputs of this Func. Corresponds to the
* size of the Tuple this Func was defined to return.
* If the Func isn't yet defined, but was specified with required types,
Expand Down
Loading

0 comments on commit 10d5410

Please sign in to comment.