From 8f111b0c121d9e0383ad8afcf559f2fc754a8c99 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 24 Oct 2019 14:40:49 -0700 Subject: [PATCH] Fix test --- python/tvm/relay/expr.py | 10 ++++++++-- src/relay/ir/expr.cc | 4 ++-- tests/python/relay/test_ir_nodes.py | 14 ++++++++++---- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 3237ddfb1dcd..8d59e99d8388 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -27,6 +27,7 @@ from .._ffi import base as _base from .. import nd as _nd from .. import convert +from ..ndarray import NDArray # will be registered afterwards _op_make = None @@ -305,10 +306,15 @@ def __call__(self, *args): """ return Call(self, args, None, None) - def get_params(self, params): - return _expr.FunctionGet(self, params) + def get_params(self): + return _expr.FunctionGetParams(self) def set_params(self, params): + for key in params: + value = params[key] + if isinstance(value, NDArray): + params[key] = Constant(value) + return _expr.FunctionSetParams(self, params) diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index dfa20709c7fe..c36b4c8566b8 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -163,7 +163,7 @@ Function FunctionNode::SetParams(const tvm::Map& parameters) cons return FunctionSetAttr(GetRef(this), "__params__", parameters); } -TVM_REGISTER_API("relay._expr.FunctionSetParms") +TVM_REGISTER_API("relay._expr.FunctionSetParams") .set_body_typed&)>( [](const Function& func, const tvm::Map& parameters) { return func->SetParams(parameters); @@ -174,7 +174,7 @@ tvm::Map FunctionNode::GetParams() const { return Downcast>(node_ref); } -TVM_REGISTER_API("relay._expr.FunctionGetParms") +TVM_REGISTER_API("relay._expr.FunctionGetParams") .set_body_typed(const Function&)>([](const Function& func) { return func->GetParams(); }); diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index 69f087608b97..dec840a214a0 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -160,7 +160,6 @@ def test_global_var(): str(gv) check_json_roundtrip(gv) - def test_function(): param_names = ['a', 'b', 'c', 'd'] params = tvm.convert([relay.Var(n) for n in param_names]) @@ -184,7 +183,8 @@ def test_function_attrs(): fn = relay.Function(params, body, ret_type, type_params) model_params = {} for param in params[:1]: - tensor = np.random.rand(*param.shape).astype(param.dtype) + cty = param.type_annotation + tensor = np.random.rand(*[int(sh) for sh in cty.shape]).astype(cty.dtype) model_params[param] = tvm.nd.array(tensor) fn = fn.set_params(model_params) assert fn.params == params @@ -196,8 +196,12 @@ def test_function_attrs(): json_str = tvm.save_json(fn) fn_after = tvm.load_json(json_str) model_params_after = fn_after.get_params() - for p1, p2 in zip(model_params, model_params_after): - assert p1.asnumpy() == p2.asnumpy() + after_keys = [item[0] for item in model_params_after.items()] + for key1, key2 in zip(model_params, after_keys): + assert key1.name_hint == key2.name_hint + p1 = model_params[key1] + p2 = model_params_after[key2] + np.testing.assert_allclose(p1.data.asnumpy(), p2.data.asnumpy()) def test_call(): op = relay.Var('f') @@ -280,9 +284,11 @@ def test_conv2d_attrs(): test_local_var() test_global_var() test_function() + test_function_attrs() test_call() test_let() test_if() test_tuple_get_item() test_op() test_conv2d_attrs() +