Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Test target shape
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Mar 16, 2019
1 parent ae0e081 commit c332006
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1503,20 +1503,21 @@ def _init_bilinear(arr, f):
arr[:] = weight.reshape(shape)
return arr

arr = {'data': mx.random.uniform(-10.0, 10.0, data_shape, ctx=mx.cpu()).copyto(default_context()),
'weight': mx.nd.array(_init_bilinear(mx.ndarray.empty(weight_shape).asnumpy(), root_scale))}

out_grad = arr['data']
data = mx.sym.Variable(name="data")
up = mx.sym.UpSampling(data,
up = mx.sym.UpSampling(mx.sym.Variable("data"),
mx.sym.Variable('weight'), sample_type='bilinear', scale=root_scale,
num_filter=num_filter, num_args=2)
arg_shapes, out_shapes, _ = up.infer_shape(data=data_shape)
arr = {'data': mx.random.uniform(-5, 5, data_shape, ctx=mx.cpu()).copyto(default_context()),
'weight': mx.nd.array(_init_bilinear(mx.ndarray.empty(arg_shapes[1]).asnumpy(), root_scale))}

arr_grad = [mx.nd.empty(s) for s in arg_shapes]
exe = up.bind(default_context(), args=arr, args_grad=arr_grad)
exe.forward(is_train=True)
out = exe.outputs[0].asnumpy()
exe.backward(out_grad)
exe.backward(exe.outputs)
target_shape = (data_shape[2] * root_scale, data_shape[3] * root_scale)
assert out.shape == data_shape[:2] + target_shape


@with_seed()
def test_nearest_upsampling():
Expand Down

0 comments on commit c332006

Please sign in to comment.