Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Augment Halide::Func to allow for constraining Type and Dimensionality #6734

Merged
merged 1 commit into from
Apr 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions python_bindings/correctness/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,11 +309,44 @@ def test_bool_conversion():
# Verify that this doesn't fail with 'Argument passed to specialize must be of type bool'
f.compute_root().specialize(True)

def test_typed_funcs():
x = hl.Var('x')
y = hl.Var('y')

f = hl.Func(hl.Int(32), 1, 'f')
try:
f[x, y] = hl.i32(0);
f.realize([10, 10])
except RuntimeError as e:
assert 'is constrained to have exactly 1 dimensions, but is defined with 2 dimensions' in str(e)
else:
assert False, 'Did not see expected exception!'

f = hl.Func(hl.Int(32), 2, 'f')
try:
f[x, y] = hl.i16(0);
f.realize([10, 10])
except RuntimeError as e:
assert 'is constrained to only hold values of type int32 but is defined with values of type int16' in str(e)
else:
assert False, 'Did not see expected exception!'

f = hl.Func((hl.Int(32), hl.Float(32)), 2, 'f')
try:
f[x, y] = (hl.i16(0), hl.f64(0))
f.realize([10, 10])
except RuntimeError as e:
assert 'is constrained to only hold values of type (int32, float32) but is defined with values of type (int16, float64)' in str(e)
else:
assert False, 'Did not see expected exception!'


if __name__ == "__main__":
test_compiletime_error()
test_runtime_error()
test_misused_and()
test_misused_or()
test_typed_funcs()
test_float_or_int()
test_operator_order()
test_int_promotion()
Expand Down
2 changes: 2 additions & 0 deletions python_bindings/src/PyFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ void define_func(py::module &m) {
py::class_<Func>(m, "Func")
.def(py::init<>())
.def(py::init<std::string>())
.def(py::init<Type, int, std::string>(), py::arg("required_type"), py::arg("required_dimensions"), py::arg("name"))
.def(py::init<std::vector<Type>, int, std::string>(), py::arg("required_types"), py::arg("required_dimensions"), py::arg("name"))
.def(py::init<Expr>())
.def(py::init([](Buffer<> &b) -> Func { return Func(b); }))

Expand Down
6 changes: 4 additions & 2 deletions src/Buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

namespace Halide {

template<typename T = void, int Dims = Halide::Runtime::AnyDims>
constexpr int AnyDims = Halide::Runtime::AnyDims; // -1

template<typename T = void, int Dims = AnyDims>
class Buffer;

struct JITUserContext;
Expand Down Expand Up @@ -153,7 +155,7 @@ class Buffer {
}

public:
static constexpr int AnyDims = Halide::Runtime::AnyDims;
static constexpr int AnyDims = Halide::AnyDims;
static_assert(Dims == AnyDims || Dims >= 0);

typedef T ElemType;
Expand Down
15 changes: 15 additions & 0 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ Func::Func(const string &name)
: func(unique_name(name)) {
}

Func::Func(const Type &required_type, int required_dims, const string &name)
: func({required_type}, required_dims, unique_name(name)) {
}

Func::Func(const std::vector<Type> &required_types, int required_dims, const string &name)
: func(required_types, required_dims, unique_name(name)) {
}

Func::Func()
: func(make_entity_name(this, "Halide:.*:Func", 'f')) {
}
Expand Down Expand Up @@ -2926,6 +2934,8 @@ Stage FuncRef::operator=(const FuncRef &e) {
}
}

namespace {

// Inject a suitable base-case definition given an update
// definition. This is a helper for FuncRef::operator+= and co.
Func define_base_case(const Internal::Function &func, const vector<Expr> &a, const Tuple &e) {
Expand Down Expand Up @@ -2955,8 +2965,12 @@ Func define_base_case(const Internal::Function &func, const vector<Expr> &a, con
return define_base_case(func, a, Tuple(e));
}

} // namespace

template<typename BinaryOp>
Stage FuncRef::func_ref_update(const Tuple &e, int init_val) {
func.check_types(e);

internal_assert(e.size() > 1);

vector<Expr> init_values(e.size());
Expand All @@ -2975,6 +2989,7 @@ Stage FuncRef::func_ref_update(const Tuple &e, int init_val) {

template<typename BinaryOp>
Stage FuncRef::func_ref_update(Expr e, int init_val) {
func.check_types(e);
vector<Expr> expanded_args = args_with_implicit_vars({e});
FuncRef self_ref = define_base_case(func, expanded_args, cast(e.type(), init_val))(expanded_args);
return self_ref = BinaryOp()(Expr(self_ref), e);
Expand Down
13 changes: 13 additions & 0 deletions src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,19 @@ class Func {
/** Declare a new undefined function with the given name */
explicit Func(const std::string &name);

/** Declare a new undefined function with the given name.
* The function will be constrained to represent Exprs of required_type.
* If required_dims is not AnyDims, the function will be constrained to exactly
* that many dimensions. */
explicit Func(const Type &required_type, int required_dims, const std::string &name);

/** Declare a new undefined function with the given name.
* If required_types is not empty, the function will be constrained to represent
* Tuples of the same arity and types. (If required_types is empty, there is no constraint.)
* If required_dims is not AnyDims, the function will be constrained to exactly
* that many dimensions. */
explicit Func(const std::vector<Type> &required_types, int required_dims, const std::string &name);

/** Declare a new undefined function with an
* automatically-generated unique name */
Func();
Expand Down
124 changes: 122 additions & 2 deletions src/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ typedef map<FunctionPtr, FunctionPtr> DeepCopyMap;
struct FunctionContents;

namespace {

// Weaken all the references to a particular Function to break
// reference cycles. Also count the number of references found.
class WeakenFunctionPtrs : public IRMutator {
Expand Down Expand Up @@ -58,13 +59,30 @@ class WeakenFunctionPtrs : public IRMutator {
: func(f) {
}
};

} // namespace

struct FunctionContents {
std::string name;
std::string origin_name;
std::vector<Type> output_types;

/** Optional type constraints on the Function:
* - If empty, there are no constraints.
* - If size == 1, the Func is only allowed to have values of Expr with that type
* - If size > 1, the Func is only allowed to have values of Tuple with those types
*
* Note that when this is nonempty, then output_types should match
* required_types for all defined Functions.
*/
std::vector<Type> required_types;

/** Optional dimension constraints on the Function:
* - If required_dims == AnyDims, there are no constraints.
* - Otherwise, the Function's dimensionality must exactly match required_dims.
*/
int required_dims = AnyDims;

// The names of the dimensions of the Function. Corresponds to the
// LHS of the pure definition if there is one. Is also the initial
// stage of the dims and storage_dims. Used to identify dimensions
Expand Down Expand Up @@ -306,9 +324,100 @@ Function::Function(const std::string &n) {
contents->origin_name = n;
}

Function::Function(const std::vector<Type> &required_types, int required_dims, const std::string &n)
: Function(n) {
user_assert(required_dims >= AnyDims);
contents->required_types = required_types;
contents->required_dims = required_dims;
}

namespace {

template<typename T>
struct PrintTypeList {
const std::vector<T> &list_;

explicit PrintTypeList(const std::vector<T> &list)
: list_(list) {
}

friend std::ostream &operator<<(std::ostream &s, const PrintTypeList &self) {
const size_t n = self.list_.size();
if (n != 1) {
s << "(";
}
const char *comma = "";
for (const auto &t : self.list_) {
if constexpr (std::is_same<Type, T>::value) {
s << comma << t;
} else {
s << comma << t.type();
}
comma = ", ";
}
if (n != 1) {
s << ")";
}
return s;
}
};

bool types_match(const std::vector<Type> &types, const std::vector<Expr> &exprs) {
size_t n = types.size();
if (n != exprs.size()) {
return false;
}
for (size_t i = 0; i < n; i++) {
if (types[i] != exprs[i].type()) {
return false;
}
}
return true;
}

} // namespace

void Function::check_types(const Expr &e) const {
check_types(std::vector<Expr>{e});
}

void Function::check_types(const Tuple &t) const {
check_types(t.as_vector());
}

void Function::check_types(const Type &t) const {
check_types(std::vector<Type>{t});
}

void Function::check_types(const std::vector<Expr> &exprs) const {
if (!contents->required_types.empty()) {
user_assert(types_match(contents->required_types, exprs))
<< "Func \"" << name() << "\" is constrained to only hold values of type " << PrintTypeList(contents->required_types)
<< " but is defined with values of type " << PrintTypeList(exprs) << ".\n";
}
}

void Function::check_types(const std::vector<Type> &types) const {
if (!contents->required_types.empty()) {
user_assert(contents->required_types == types)
<< "Func \"" << name() << "\" is constrained to only hold values of type " << PrintTypeList(contents->required_types)
<< " but is defined with values of type " << PrintTypeList(types) << ".\n";
}
}

void Function::check_dims(int dims) const {
if (contents->required_dims != AnyDims) {
user_assert(contents->required_dims == dims)
<< "Func \"" << name() << "\" is constrained to have exactly " << contents->required_dims
<< " dimensions, but is defined with " << dims << " dimensions.\n";
}
}

namespace {

// Return deep-copy of ExternFuncArgument 'src'
ExternFuncArgument deep_copy_extern_func_argument_helper(
const ExternFuncArgument &src, DeepCopyMap &copied_map) {
ExternFuncArgument deep_copy_extern_func_argument_helper(const ExternFuncArgument &src,
DeepCopyMap &copied_map) {
ExternFuncArgument copy;
copy.arg_type = src.arg_type;
copy.buffer = src.buffer;
Expand All @@ -330,6 +439,8 @@ ExternFuncArgument deep_copy_extern_func_argument_helper(
return copy;
}

} // namespace

void Function::deep_copy(const FunctionPtr &copy, DeepCopyMap &copied_map) const {
internal_assert(copy.defined() && contents.defined())
<< "Cannot deep-copy undefined Function\n";
Expand Down Expand Up @@ -456,6 +567,8 @@ void Function::define(const vector<string> &args, vector<Expr> values) {
<< "In pure definition of Func \"" << name() << "\":\n"
<< "Func is already defined.\n";

check_types(values);
check_dims((int)args.size());
contents->args = args;

std::vector<Expr> init_def_args;
Expand Down Expand Up @@ -485,6 +598,11 @@ void Function::define(const vector<string> &args, vector<Expr> values) {
contents->output_types[i] = values[i].type();
}

if (!contents->required_types.empty()) {
// Just a reality check; mismatches here really should have been caught earlier
internal_assert(contents->required_types == contents->output_types);
}

for (size_t i = 0; i < values.size(); i++) {
string buffer_name = name();
if (values.size() > 1) {
Expand Down Expand Up @@ -703,6 +821,8 @@ void Function::define_extern(const std::string &function_name,
const std::vector<Var> &args,
NameMangling mangling,
DeviceAPI device_api) {
check_types(types);
check_dims((int)args.size());

user_assert(!has_pure_definition() && !has_update_definition())
<< "In extern definition for Func \"" << name() << "\":\n"
Expand Down
20 changes: 20 additions & 0 deletions src/Function.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
namespace Halide {

struct ExternFuncArgument;
class Tuple;

class Var;

Expand Down Expand Up @@ -57,6 +58,13 @@ class Function {
/** Construct a new function with the given name */
explicit Function(const std::string &n);

/** Construct a new function with the given name,
* with a requirement that it can only represent Expr(s) of the given type(s),
* and must have exactly the give nnumber of dimensions.
* required_types.empty() means there are no constraints on the type(s).
* required_dims == AnyDims means there are no constraints on the dimensions. */
explicit Function(const std::vector<Type> &required_types, int required_dims, const std::string &n);

/** Construct a Function from an existing FunctionContents pointer. Must be non-null */
explicit Function(const FunctionPtr &);

Expand Down Expand Up @@ -292,6 +300,18 @@ class Function {

/** Return true iff the name matches one of the Function's pure args. */
bool is_pure_arg(const std::string &name) const;

/** If the Function has type requirements, check that the given argument
* is compatible with them. If not, assert-fail. (If there are no type requirements, do nothing.) */
void check_types(const Expr &e) const;
void check_types(const Tuple &t) const;
void check_types(const Type &t) const;
void check_types(const std::vector<Expr> &exprs) const;
void check_types(const std::vector<Type> &types) const;

/** If the Function has dimension requirements, check that the given argument
* is compatible with them. If not, assert-fail. (If there are no dimension requirements, do nothing.) */
void check_dims(int dims) const;
};

/** Deep copy an entire Function DAG. */
Expand Down
2 changes: 2 additions & 0 deletions src/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,8 @@ Realization Pipeline::realize(JITUserContext *context,
user_assert(defined()) << "Pipeline is undefined\n";
vector<Buffer<>> bufs;
for (auto &out : contents->outputs) {
user_assert((int)sizes.size() == out.dimensions())
<< "Func " << out.name() << " is defined with " << out.dimensions() << " dimensions, but realize() is requesting a realization with " << sizes.size() << " dimensions.\n";
user_assert(out.has_pure_definition() || out.has_extern_definition()) << "Can't realize Pipeline with undefined output Func: " << out.name() << ".\n";
for (Type t : out.output_types()) {
bufs.emplace_back(t, nullptr, sizes);
Expand Down
8 changes: 8 additions & 0 deletions test/error/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ tests(GROUPS error
five_d_gpu_buffer.cpp
float_arg.cpp
forward_on_undefined_buffer.cpp
func_expr_dim_mismatch.cpp
func_expr_type_mismatch.cpp
func_expr_update_type_mismatch.cpp
func_extern_dim_mismatch.cpp
func_extern_type_mismatch.cpp
func_tuple_dim_mismatch.cpp
func_tuple_types_mismatch.cpp
func_tuple_update_types_mismatch.cpp
implicit_args.cpp
impossible_constraints.cpp
init_def_should_be_all_vars.cpp
Expand Down
Loading