From cc02471a788bf3788c423f878763130acf434e0e Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Fri, 28 Jun 2019 08:34:41 -0700 Subject: [PATCH] me find type checker problem --- src/relay/pass/type_infer.cc | 12 +----------- tests/python/relay/test_pass_to_cps.py | 2 -- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 4b126e5299cfd..4cf8baa93fbd5 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -141,17 +141,7 @@ class TypeInferencer : private ExprFunctor, Type Unify(const Type& t1, const Type& t2, const NodeRef& expr) { // TODO(tqchen, jroesch): propagate span to solver try { - // instantiate higher-order func types when unifying because - // we only allow polymorphism at the top level - Type first = t1; - Type second = t2; - if (auto* ft1 = t1.as()) { - first = InstantiateFuncType(ft1); - } - if (auto* ft2 = t2.as()) { - second = InstantiateFuncType(ft2); - } - return solver_.Unify(first, second, expr); + return solver_.Unify(t1, t2, expr); } catch (const dmlc::Error &e) { this->ReportFatalError( expr, diff --git a/tests/python/relay/test_pass_to_cps.py b/tests/python/relay/test_pass_to_cps.py index 5acf52dfaac0a..1dfeb607e6c0d 100644 --- a/tests/python/relay/test_pass_to_cps.py +++ b/tests/python/relay/test_pass_to_cps.py @@ -58,7 +58,6 @@ def test_ad_cps_pe(): stmt = relay.Let(f_ref, relay.RefCreate(unit), relay.Let(relay.Var("x"), if_expr, relay.Call(relay.RefRead(f_ref), []))) - stmt = unit# relay.RefCreate(unit) F = relay.Function([cond], stmt) print(F) @@ -69,7 +68,6 @@ def test_ad_cps_pe(): F = relay.ir_pass.to_cps(F) print(F) relay.ir_pass.infer_type(F) - raise F = relay.ir_pass.partial_evaluate(F) print(F)