Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Func::type()/types(), deprecate Func::output_type()/output_types() #6772

Merged
merged 5 commits into from
May 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions apps/fft/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ ComplexExpr mul(ComplexExpr a, float re_b, float im_b) {
// Specializations for some small DFTs of the first dimension of a
// Func f.
ComplexFunc dft2(ComplexFunc f, const string &prefix) {
Type type = f.output_types()[0];
Type type = f.types()[0];

ComplexFunc F(prefix + "X2");
F(f.args()) = undef_z(type);
Expand All @@ -122,7 +122,7 @@ ComplexFunc dft2(ComplexFunc f, const string &prefix) {
}

ComplexFunc dft4(ComplexFunc f, int sign, const string &prefix) {
Type type = f.output_types()[0];
Type type = f.types()[0];

ComplexFunc F(prefix + "X4");
F(f.args()) = undef_z(type);
Expand Down Expand Up @@ -156,7 +156,7 @@ ComplexFunc dft6(ComplexFunc f, int sign, const string &prefix) {
ComplexExpr W2_3(re_W1_3, -im_W1_3);
ComplexExpr W4_3 = W1_3;

Type type = f.output_types()[0];
Type type = f.types()[0];

ComplexFunc F(prefix + "X8");
F(f.args()) = undef_z(type);
Expand Down Expand Up @@ -187,7 +187,7 @@ ComplexFunc dft6(ComplexFunc f, int sign, const string &prefix) {
ComplexFunc dft8(ComplexFunc f, int sign, const string &prefix) {
const float sqrt2_2 = 0.70710678f;

Type type = f.output_types()[0];
Type type = f.types()[0];

ComplexFunc F(prefix + "X8");
F(f.args()) = undef_z(type);
Expand Down Expand Up @@ -346,7 +346,7 @@ ComplexFunc fft_dim1(ComplexFunc x,

// The vector width is the least common multiple of the previous vector
// width and the natural vector size for this stage.
vector_width = lcm(vector_width, target.natural_vector_size(v.output_types()[0]));
vector_width = lcm(vector_width, target.natural_vector_size(v.types()[0]));

// Compute the R point DFT of the subtransform.
ComplexFunc V = dft1d_c2c(v, R, sign, prefix);
Expand All @@ -355,7 +355,7 @@ ComplexFunc fft_dim1(ComplexFunc x,
// pass. Since the pure stage is undef, we explicitly generate the
// arg list (because we can't use placeholders in an undef
// definition).
exchange(A({n0, n1}, args)) = undef_z(V.output_types()[0]);
exchange(A({n0, n1}, args)) = undef_z(V.types()[0]);

RDom rs(0, R, 0, N / R);
r_ = rs.x;
Expand Down Expand Up @@ -444,7 +444,7 @@ std::pair<FuncType, FuncType> tiled_transpose(FuncType f, int max_tile_size,
}

const int tile_size =
std::min(max_tile_size, target.natural_vector_size(f.output_types()[0]));
std::min(max_tile_size, target.natural_vector_size(f.types()[0]));

vector<Var> args = f.args();
Var x(args[0]), y(args[1]);
Expand Down Expand Up @@ -685,7 +685,7 @@ ComplexFunc fft2d_r2c(Func r,
int N0 = product(R0);
int N1 = product(R1);

const int natural_vector_size = target.natural_vector_size(r.output_types()[0]);
const int natural_vector_size = target.natural_vector_size(r.types()[0]);

// If this FFT is small, the logic related to zipping and unzipping
// the FFT may be expensive compared to just brute forcing with a complex
Expand All @@ -705,7 +705,7 @@ ComplexFunc fft2d_r2c(Func r,
result(A({n0, n1}, args)) = dft(A({n0, n1}, args));
result.bound(n0, 0, N0);
result.bound(n1, 0, (N1 + 1) / 2 + 1);
result.vectorize(n0, std::min(N0, target.natural_vector_size(result.output_types()[0])));
result.vectorize(n0, std::min(N0, target.natural_vector_size(result.types()[0])));
dft.compute_at(result, outer);
return result;
}
Expand All @@ -731,7 +731,7 @@ ComplexFunc fft2d_r2c(Func r,
ComplexFunc zipped(prefix + "zipped");
int zip_width = desc.vector_width;
if (zip_width <= 0) {
zip_width = target.natural_vector_size(r.output_types()[0]);
zip_width = target.natural_vector_size(r.types()[0]);
}
// Ensure the zip width divides the zipped extent.
zip_width = gcd(zip_width, N0 / 2);
Expand Down Expand Up @@ -911,7 +911,7 @@ Func fft2d_c2r(ComplexFunc c,
// If this FFT is small, the logic related to zipping and unzipping
// the FFT may be expensive compared to just brute forcing with a complex
// FFT.
const int natural_vector_size = target.natural_vector_size(c.output_types()[0]);
const int natural_vector_size = target.natural_vector_size(c.types()[0]);

bool skip_zip = N0 < natural_vector_size * 2;

Expand Down Expand Up @@ -967,7 +967,7 @@ Func fft2d_c2r(ComplexFunc c,
// The vector width of the zipping performed below.
int zip_width = desc.vector_width;
if (zip_width <= 0) {
zip_width = gcd(target.natural_vector_size(dft0T.output_types()[0]), N1 / 2);
zip_width = gcd(target.natural_vector_size(dft0T.types()[0]), N1 / 2);
}

// transpose so we can take the DFT of the columns again.
Expand Down
10 changes: 5 additions & 5 deletions python_bindings/correctness/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def test_typed_funcs():
f = hl.Func('f')
assert not f.defined()
try:
assert f.output_type() == Int(32)
assert f.type() == Int(32)
except hl.HalideError as e:
assert 'it is undefined' in str(e)
else:
Expand All @@ -339,21 +339,21 @@ def test_typed_funcs():

f = hl.Func(hl.Int(32), 2, 'f')
assert not f.defined()
assert f.output_type() == hl.Int(32)
assert f.output_types() == [hl.Int(32)]
assert f.type() == hl.Int(32)
assert f.types() == [hl.Int(32)]
assert f.outputs() == 1
assert f.dimensions() == 2

f = hl.Func([hl.Int(32), hl.Float(64)], 3, 'f')
assert not f.defined()
try:
assert f.output_type() == hl.Int(32)
assert f.type() == hl.Int(32)
except hl.HalideError as e:
assert 'it returns a Tuple' in str(e)
else:
assert False, 'Did not see expected exception!'

assert f.output_types() == [hl.Int(32), hl.Float(64)]
assert f.types() == [hl.Int(32), hl.Float(64)]
assert f.outputs() == 2
assert f.dimensions() == 3

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Complex : public Halide::Generator<Complex> {
// assert-fail, because there is no type constraint set: the type
// will end up as whatever we infer from the values put into it. We'll use an
// explicit GeneratorParam to allow us to set it.
untyped_buffer_output(x, y, c) = cast(untyped_buffer_output.output_type(), untyped_buffer_input(x, y, c));
untyped_buffer_output(x, y, c) = cast(untyped_buffer_output.type(), untyped_buffer_input(x, y, c));

// Gratuitous intermediate for the purpose of exercising
// GeneratorParam<LoopLevel>
Expand Down
21 changes: 19 additions & 2 deletions python_bindings/src/PyFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,25 @@ void define_func(py::module &m) {
})
.def("defined", &Func::defined)
.def("outputs", &Func::outputs)
.def("output_type", &Func::output_type)
.def("output_types", &Func::output_types)

.def("output_type", [](Func &f) {
// HALIDE_ATTRIBUTE_DEPRECATED("Func::output_type() is deprecated; call Func::type() instead.")
PyErr_WarnEx(PyExc_DeprecationWarning,
"Func.output_type() is deprecated; use Func.type() instead.",
1);
return f.type();
})

.def("output_types", [](Func &f) {
// HALIDE_ATTRIBUTE_DEPRECATED("Func::output_types() is deprecated; call Func::types() instead.")
PyErr_WarnEx(PyExc_DeprecationWarning,
"Func.output_types() is deprecated; use Func.types() instead.",
1);
return f.types();
})

.def("type", &Func::type)
.def("types", &Func::types)

.def("bound", &Func::bound, py::arg("var"), py::arg("min"), py::arg("extent"))

Expand Down
4 changes: 2 additions & 2 deletions python_bindings/tutorial/lesson_14_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ def main():
# You can also query any defined hl.Func for the types it produces.
f1 = hl.Func("f1")
f1[x] = hl.cast(hl.UInt(8), x)
assert f1.output_types()[0] == hl.UInt(8)
assert f1.types()[0] == hl.UInt(8)

f2 = hl.Func("f2")
f2[x] = (x, hl.sin(x))
assert f2.output_types()[0] == hl.Int(32) and f2.output_types()[1] == hl.Float(32)
assert f2.types()[0] == hl.Int(32) and f2.types()[1] == hl.Float(32)

# Type promotion rules.
if True:
Expand Down
12 changes: 6 additions & 6 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,22 +196,22 @@ void Func::define_extern(const std::string &function_name,
}

/** Get the types of the buffers returned by an extern definition. */
const Type &Func::output_type() const {
const Type &Func::type() const {
const auto &types = defined() ? func.output_types() : func.required_types();
if (types.empty()) {
user_error << "Can't call Func::output_type on Func \"" << name()
user_error << "Can't call Func::type on Func \"" << name()
<< "\" because it is undefined or has no type requirements.\n";
} else if (types.size() > 1) {
user_error << "Can't call Func::output_type on Func \"" << name()
user_error << "Can't call Func::type on Func \"" << name()
<< "\" because it returns a Tuple.\n";
}
return types[0];
}

const std::vector<Type> &Func::output_types() const {
const std::vector<Type> &Func::types() const {
const auto &types = defined() ? func.output_types() : func.required_types();
user_assert(!types.empty())
<< "Can't call Func::output_types on Func \"" << name()
<< "Can't call Func::types on Func \"" << name()
<< "\" because it is undefined or has no type requirements.\n";
return types;
}
Expand Down Expand Up @@ -981,7 +981,7 @@ Func Stage::rfactor(vector<pair<RVar, Var>> preserved) {
}

if (!prover_result.xs[i].var.empty()) {
Expr prev_val = Call::make(intm.output_types()[i], func_name,
Expr prev_val = Call::make(intm.types()[i], func_name,
f_store_args, Call::CallType::Halide,
FunctionPtr(), i);
replacements.emplace(prover_result.xs[i].var, prev_val);
Expand Down
21 changes: 18 additions & 3 deletions src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -1204,13 +1204,28 @@ class Func {
// @}

/** Get the type(s) of the outputs of this Func.
* If the Func isn't yet defined, but was specified with required types,
*
* It is not legal to call type() unless the Func has non-Tuple elements.
*
* If the Func isn't yet defined, and was not specified with required types,
* a runtime error will occur.
*
* If the Func isn't yet defined, but *was* specified with required types,
* the requirements will be returned. */
// @{
const Type &output_type() const;
const std::vector<Type> &output_types() const;
const Type &type() const;
const std::vector<Type> &types() const;
// @}

HALIDE_ATTRIBUTE_DEPRECATED("Func::output_type() is deprecated; use Func::type() instead.")
const Type &output_type() const {
return type();
}
HALIDE_ATTRIBUTE_DEPRECATED("Func::output_types() is deprecated; use Func::types() instead.")
const std::vector<Type> &output_types() const {
return types();
}

/** Get the number of outputs of this Func. Corresponds to the
* size of the Tuple this Func was defined to return.
* If the Func isn't yet defined, but was specified with required types,
Expand Down
Loading