@@ -1493,43 +1493,19 @@ def check_nearest_upsampling_with_shape(shapes, scale, root_scale):
1493
1493
1494
1494
1495
1495
def check_bilinear_upsampling_with_shape (data_shape , weight_shape , scale , root_scale , num_filter ):
1496
- def py_bilinear_resize (x , outputHeight , outputWidth ):
1497
- batch , channel , inputHeight , inputWidth = x .shape
1498
- if outputHeight == inputHeight and outputWidth == inputWidth :
1499
- return x
1500
- y = np .empty ([batch , channel , outputHeight , outputWidth ])
1501
- rheight = 1.0 * (inputHeight - 1 ) / (outputHeight - 1 ) if outputHeight > 1 else 0.0
1502
- rwidth = 1.0 * (inputWidth - 1 ) / (outputWidth - 1 ) if outputWidth > 1 else 0.0
1503
- for h2 in range (outputHeight ):
1504
- h1r = 1.0 * h2 * rheight
1505
- h1 = int (np .floor (h1r ))
1506
- h1lambda = h1r - h1
1507
- h1p = 1 if h1 < (inputHeight - 1 ) else 0
1508
- for w2 in range (outputWidth ):
1509
- w1r = 1.0 * w2 * rwidth
1510
- w1 = int (np .floor (w1r ))
1511
- w1lambda = w1r - w1
1512
- w1p = 1 if w1 < (inputHeight - 1 ) else 0
1513
- for b in range (batch ):
1514
- for c in range (channel ):
1515
- y [b ][c ][h2 ][w2 ] = (1 - h1lambda )* ((1 - w1lambda )* x [b ][c ][h1 ][w1 ] + \
1516
- w1lambda * x [b ][c ][h1 ][w1 + w1p ]) + \
1517
- h1lambda * ((1 - w1lambda )* x [b ][c ][h1 + h1p ][w1 ] + \
1518
- w1lambda * x [b ][c ][h1 + h1p ][w1 + w1p ])
1519
- return y
1520
- def _init_bilinear (arr ):
1496
+ def _init_bilinear (arr , f ):
1521
1497
weight = np .zeros (np .prod (arr .shape ), dtype = 'float32' )
1522
1498
shape = arr .shape
1523
- f = np .ceil (shape [3 ] / 2. )
1524
1499
c = (2 * f - 1 - f % 2 ) / (2. * f )
1525
1500
for i in range (np .prod (shape )):
1526
1501
x = i % shape [3 ]
1527
1502
y = (i // shape [3 ]) % shape [2 ]
1528
1503
weight [i ] = (1 - abs (x / f - c )) * (1 - abs (y / f - c ))
1529
1504
arr [:] = weight .reshape (shape )
1530
1505
return arr
1506
+
1531
1507
arr = {'data' : mx .random .uniform (- 10.0 , 10.0 , data_shape , ctx = mx .cpu ()).copyto (default_context ()),
1532
- 'weight' : mx .nd .array (_init_bilinear (mx .ndarray .empty (weight_shape ).asnumpy ()))}
1508
+ 'weight' : mx .nd .array (_init_bilinear (mx .ndarray .empty (weight_shape ).asnumpy (), root_scale ))}
1533
1509
1534
1510
up = mx .sym .UpSampling (mx .sym .Variable ('data' ),
1535
1511
mx .sym .Variable ('weight' ), sample_type = 'bilinear' , scale = root_scale ,
@@ -1540,7 +1516,6 @@ def _init_bilinear(arr):
1540
1516
exe .forward (is_train = True )
1541
1517
out = exe .outputs [0 ].asnumpy ()
1542
1518
exe .backward (exe .outputs )
1543
- assert_allclose (out , py_bilinear_resize (arr ['data' ].asnumpy (), data_shape [2 ]* root_scale , data_shape [3 ]* root_scale ), rtol = 1e-4 )
1544
1519
1545
1520
1546
1521
@with_seed ()
0 commit comments