Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Oct 24, 2019
1 parent b179697 commit 6f81479
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
10 changes: 8 additions & 2 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .._ffi import base as _base
from .. import nd as _nd
from .. import convert
from ..ndarray import NDArray

# will be registered afterwards
_op_make = None
Expand Down Expand Up @@ -305,10 +306,15 @@ def __call__(self, *args):
"""
return Call(self, args, None, None)

def get_params(self, params):
return _expr.FunctionGet(self, params)
def get_params(self):
return _expr.FunctionGetParams(self)

def set_params(self, params):
for key in params:
value = params[key]
if isinstance(value, NDArray):
params[key] = Constant(value)

return _expr.FunctionSetParams(self, params)


Expand Down
4 changes: 2 additions & 2 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ Function FunctionNode::SetParams(const tvm::Map<Var, Constant>& parameters) cons
return FunctionSetAttr(GetRef<Function>(this), "__params__", parameters);
}

TVM_REGISTER_API("relay._expr.FunctionSetParms")
TVM_REGISTER_API("relay._expr.FunctionSetParams")
.set_body_typed<Function(const Function&, const tvm::Map<Var, Constant>&)>(
[](const Function& func, const tvm::Map<Var, Constant>& parameters) {
return func->SetParams(parameters);
Expand All @@ -174,7 +174,7 @@ tvm::Map<Var, Constant> FunctionNode::GetParams() const {
return Downcast<tvm::Map<Var, Constant>>(node_ref);
}

TVM_REGISTER_API("relay._expr.FunctionGetParms")
TVM_REGISTER_API("relay._expr.FunctionGetParams")
.set_body_typed<tvm::Map<Var, Constant>(const Function&)>([](const Function& func) {
return func->GetParams();
});
Expand Down
14 changes: 10 additions & 4 deletions tests/python/relay/test_ir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ def test_global_var():
str(gv)
check_json_roundtrip(gv)


def test_function():
param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.Var(n) for n in param_names])
Expand All @@ -184,7 +183,8 @@ def test_function_attrs():
fn = relay.Function(params, body, ret_type, type_params)
model_params = {}
for param in params[:1]:
tensor = np.random.rand(*param.shape).astype(param.dtype)
cty = param.type_annotation
tensor = np.random.rand(*[int(sh) for sh in cty.shape]).astype(cty.dtype)
model_params[param] = tvm.nd.array(tensor)
fn = fn.set_params(model_params)
assert fn.params == params
Expand All @@ -196,8 +196,12 @@ def test_function_attrs():
json_str = tvm.save_json(fn)
fn_after = tvm.load_json(json_str)
model_params_after = fn_after.get_params()
for p1, p2 in zip(model_params, model_params_after):
assert p1.asnumpy() == p2.asnumpy()
after_keys = [item[0] for item in model_params_after.items()]
for key1, key2 in zip(model_params, after_keys):
assert key1.name_hint == key2.name_hint
p1 = model_params[key1]
p2 = model_params_after[key2]
np.testing.assert_allclose(p1.data.asnumpy(), p2.data.asnumpy())

def test_call():
op = relay.Var('f')
Expand Down Expand Up @@ -280,9 +284,11 @@ def test_conv2d_attrs():
test_local_var()
test_global_var()
test_function()
test_function_attrs()
test_call()
test_let()
test_if()
test_tuple_get_item()
test_op()
test_conv2d_attrs()

0 comments on commit 6f81479

Please sign in to comment.