Skip to content

Commit

Permalink
[Relay] Ensure nested higher-order functions are treated correctly (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky authored and wweic committed Mar 12, 2019
1 parent d85e780 commit 8332af8
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 4 deletions.
4 changes: 1 addition & 3 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,7 @@ def define_iterate(self):
f = Var("f", FuncType([a], a))
x = Var("x", self.nat())
y = Var("y", self.nat())
z = Var("z")
z_case = Clause(PatternConstructor(self.z), Function([z], z))
# todo: fix typechecker so Function([z], z) can be replaced by self.id
z_case = Clause(PatternConstructor(self.z), self.id)
s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
self.compose(f, self.iterate(f, y)))
self.mod[self.iterate] = Function([f, x],
Expand Down
29 changes: 28 additions & 1 deletion src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
Type Unify(const Type& t1, const Type& t2, const NodeRef& expr) {
// TODO(tqchen, jroesch): propagate span to solver
try {
return solver_.Unify(t1, t2, expr);
// instantiate higher-order func types when unifying because
// we only allow polymorphism at the top level
Type first = t1;
Type second = t2;
if (auto* ft1 = t1.as<FuncTypeNode>()) {
first = InstantiateFuncType(ft1);
}
if (auto* ft2 = t2.as<FuncTypeNode>()) {
second = InstantiateFuncType(ft2);
}
return solver_.Unify(first, second, expr);
} catch (const dmlc::Error &e) {
this->ReportFatalError(
expr,
Expand Down Expand Up @@ -351,6 +361,20 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
return Downcast<FuncType>(inst_ty);
}

// instantiates starting from incompletes
FuncType InstantiateFuncType(const FuncTypeNode* fn_ty) {
if (fn_ty->type_params.size() == 0) {
return GetRef<FuncType>(fn_ty);
}

Array<Type> type_args;
for (size_t i = 0; i < fn_ty->type_params.size(); i++) {
type_args.push_back(IncompleteTypeNode::make(Kind::kType));
}
return InstantiateFuncType(fn_ty, type_args);
}


void AddTypeArgs(const Expr& expr, Array<Type> type_args) {
auto type_info = type_map_.find(expr);
if (type_info == type_map_.end()) {
Expand Down Expand Up @@ -464,6 +488,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
arg_types.push_back(GetType(param));
}
Type rtype = GetType(f->body);
if (auto* ft = rtype.as<FuncTypeNode>()) {
rtype = InstantiateFuncType(ft);
}
if (f->ret_type.defined()) {
rtype = this->Unify(f->ret_type, rtype, GetRef<Function>(f));
}
Expand Down
52 changes: 52 additions & 0 deletions tests/python/relay/test_type_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,58 @@ def test_incomplete_call():
assert ft.checked_type == relay.FuncType([tt, f_type], tt)


def test_higher_order_argument():
a = relay.TypeVar('a')
x = relay.Var('x', a)
id_func = relay.Function([x], x, a, [a])

b = relay.TypeVar('b')
f = relay.Var('f', relay.FuncType([b], b))
y = relay.Var('y', b)
ho_func = relay.Function([f, y], f(y), b, [b])

# id func should be an acceptable argument to the higher-order
# function even though id_func takes a type parameter
ho_call = ho_func(id_func, relay.const(0, 'int32'))

hc = relay.ir_pass.infer_type(ho_call)
expected = relay.scalar_type('int32')
assert hc.checked_type == expected


def test_higher_order_return():
a = relay.TypeVar('a')
x = relay.Var('x', a)
id_func = relay.Function([x], x, a, [a])

b = relay.TypeVar('b')
nested_id = relay.Function([], id_func, relay.FuncType([b], b), [b])

ft = relay.ir_pass.infer_type(nested_id)
assert ft.checked_type == relay.FuncType([], relay.FuncType([b], b), [b])


def test_higher_order_nested():
a = relay.TypeVar('a')
x = relay.Var('x', a)
id_func = relay.Function([x], x, a, [a])

choice_t = relay.FuncType([], relay.scalar_type('bool'))
f = relay.Var('f', choice_t)

b = relay.TypeVar('b')
z = relay.Var('z')
top = relay.Function(
[f],
relay.If(f(), id_func, relay.Function([z], z)),
relay.FuncType([b], b),
[b])

expected = relay.FuncType([choice_t], relay.FuncType([b], b), [b])
ft = relay.ir_pass.infer_type(top)
assert ft.checked_type == expected


def test_tuple():
tp = relay.TensorType((10,))
x = relay.var("x", tp)
Expand Down

0 comments on commit 8332af8

Please sign in to comment.