diff --git a/src/relay/ir/type_functor.cc b/src/relay/ir/type_functor.cc index 9fca2e0326859..516f4c875b20c 100644 --- a/src/relay/ir/type_functor.cc +++ b/src/relay/ir/type_functor.cc @@ -221,7 +221,7 @@ class TypeBinder : public TypeMutator { }; Type Bind(const Type& type, const tvm::Map& args_map) { - return TypeBinder(args_map).VisitType(type); + return type.defined() ? TypeBinder(args_map).VisitType(type) : type; } } // namespace relay diff --git a/tests/python/relay/test_pass_to_cps.py b/tests/python/relay/test_pass_to_cps.py index 1dfeb607e6c0d..0078fe0a2105f 100644 --- a/tests/python/relay/test_pass_to_cps.py +++ b/tests/python/relay/test_pass_to_cps.py @@ -69,6 +69,7 @@ def test_ad_cps_pe(): print(F) relay.ir_pass.infer_type(F) F = relay.ir_pass.partial_evaluate(F) + F = relay.ir_pass.dead_code_elimination(F) print(F) if __name__ == '__main__':