diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 499d98ab5ded..2579749a351f 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1503,20 +1503,21 @@ def _init_bilinear(arr, f):
         arr[:] = weight.reshape(shape)
         return arr

-    arr = {'data': mx.random.uniform(-10.0, 10.0, data_shape, ctx=mx.cpu()).copyto(default_context()),
-           'weight': mx.nd.array(_init_bilinear(mx.ndarray.empty(weight_shape).asnumpy(), root_scale))}
-
-    out_grad = arr['data']
-    data = mx.sym.Variable(name="data")
-    up = mx.sym.UpSampling(data,
+    up = mx.sym.UpSampling(mx.sym.Variable("data"),
                            mx.sym.Variable('weight'), sample_type='bilinear', scale=root_scale,
                            num_filter=num_filter, num_args=2)
     arg_shapes, out_shapes, _ = up.infer_shape(data=data_shape)
+    arr = {'data': mx.random.uniform(-5, 5, data_shape, ctx=mx.cpu()).copyto(default_context()),
+           'weight': mx.nd.array(_init_bilinear(mx.ndarray.empty(arg_shapes[1]).asnumpy(), root_scale))}
+
     arr_grad = [mx.nd.empty(s) for s in arg_shapes]
     exe = up.bind(default_context(), args=arr, args_grad=arr_grad)
     exe.forward(is_train=True)
     out = exe.outputs[0].asnumpy()
-    exe.backward(out_grad)
+    exe.backward(exe.outputs)
+    target_shape = (data_shape[2] * root_scale, data_shape[3] * root_scale)
+    assert out.shape == data_shape[:2] + target_shape
+

 @with_seed()
 def test_nearest_upsampling():