@@ -1526,17 +1526,32 @@ def check_nearest_upsampling_with_shape(shapes, scale, root_scale):
1526
1526
assert_allclose (arr [name ].asnumpy ()* root_scale ** 2 * scale ** (2 * k ), arr_grad [name ].asnumpy (), rtol = 1e-4 )
1527
1527
1528
1528
1529
def check_bilinear_upsampling_with_shape(data_shape, weight_shape, scale, root_scale, num_filter):
    """Run forward/backward on a bilinear UpSampling symbol and check the output shape.

    Parameters
    ----------
    data_shape : tuple
        (batch, channel, height, width) of the input data.
    weight_shape : tuple
        Expected shape of the bilinear kernel. NOTE(review): unused here —
        the actual weight shape is re-derived via ``infer_shape``; kept only
        for interface compatibility with existing callers.
    scale : int
        Extra size multiplier the caller folded into ``data_shape``
        (not passed to the operator itself).
    root_scale : int
        UpSampling factor; output H/W must equal input H/W * root_scale.
    num_filter : int
        Number of channels to upsample.
    """
    def _init_bilinear(arr, f):
        # Fill `arr` with the standard bilinear-interpolation kernel for
        # factor `f`. Vectorized equivalent of the per-element loop
        #   weight[i] = (1 - |x/f - c|) * (1 - |y/f - c|)
        # where x, y are the column/row indices of element i (true division,
        # Python 3 semantics). Computed in float64 and broadcast over the
        # leading (batch, filter) axes, then rounded once on assignment —
        # numerically identical to the scalar loop.
        c = (2 * f - 1 - f % 2) / (2. * f)
        x = np.arange(arr.shape[3]) / f - c
        y = np.arange(arr.shape[2]) / f - c
        arr[:] = (1 - np.abs(y))[:, np.newaxis] * (1 - np.abs(x))[np.newaxis, :]
        return arr

    # Bilinear upsampling takes exactly one data and one weight argument.
    up = mx.sym.UpSampling(mx.sym.Variable("data"),
                           mx.sym.Variable('weight'), sample_type='bilinear', scale=root_scale,
                           num_filter=num_filter, num_args=2)
    arg_shapes, out_shapes, _ = up.infer_shape(data=data_shape)
    arr = {'data': mx.random.uniform(-5, 5, data_shape, ctx=mx.cpu()).copyto(default_context()),
           'weight': mx.nd.array(_init_bilinear(mx.ndarray.empty(arg_shapes[1]).asnumpy(), root_scale))}

    arr_grad = [mx.nd.empty(s) for s in arg_shapes]
    exe = up.bind(default_context(), args=arr, args_grad=arr_grad)
    exe.forward(is_train=True)
    out = exe.outputs[0].asnumpy()
    exe.backward(exe.outputs)
    # The spatial dims must be enlarged by exactly root_scale.
    target_shape = (data_shape[2] * root_scale, data_shape[3] * root_scale)
    assert out.shape == data_shape[:2] + target_shape
1540
1555
1541
1556
1542
1557
@with_seed ()
@@ -1549,6 +1564,22 @@ def test_nearest_upsampling():
1549
1564
check_nearest_upsampling_with_shape (shapes , scale , root_scale )
1550
1565
1551
1566
1567
@with_seed()
def test_bilinear_upsampling():
    """Exercise bilinear UpSampling over a grid of scales, filter counts and base sizes."""
    # Bilinear upsampling takes only 1 data and 1 weight input;
    # the multi-input mode of UpSampling is not applicable here.
    for root_scale, scale, num_filter, base in itertools.product([2, 3],
                                                                 [1, 2, 3],
                                                                 [1, 2, 3],
                                                                 [1, 2, 3]):
        side = base * root_scale * scale
        # Deconvolution-style kernel size for a bilinear factor of root_scale.
        kernel_side = 2 * root_scale - root_scale % 2
        data_shape = (1, num_filter, side, side)
        weight_shape = (1, num_filter, kernel_side, kernel_side)
        check_bilinear_upsampling_with_shape(data_shape, weight_shape,
                                             scale, root_scale, num_filter)
1552
1583
@with_seed ()
1553
1584
def test_batchnorm_training ():
1554
1585
def check_batchnorm_training (stype ):
0 commit comments