From 3d2e54dd1ee711604e71fcc8367799eadaed4965 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Tue, 16 Jul 2019 14:13:19 -0700 Subject: [PATCH] fix --- src/relay/pass/gradient.cc | 2 +- tests/python/relay/test_pass_gradient.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index 1abe7a94b621..4eb08f8779e6 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -247,7 +247,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) { return FunctionNode::make(f->params, body, GradRetType(GetRef(f)), {}); } -TVM_REGISTER_API("relay._analysis.first_order_gradient") +TVM_REGISTER_API("relay._transform.first_order_gradient") .set_body_typed(FirstOrderGradient); struct ReverseADType : TypeMutator { diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py index 3fc1d74de876..3da2436aaa75 100644 --- a/tests/python/relay/test_pass_gradient.py +++ b/tests/python/relay/test_pass_gradient.py @@ -35,7 +35,7 @@ def test_id(): t = relay.TensorType(shape, dtype) x = relay.var("x", t) func = relay.Function([x], x) - back_func = run_infer_type(gradient(func)) + back_func = run_infer_type(gradient(func, mode="first_order")) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) ex = create_executor() x = rand(dtype, *shape)