Skip to content

Commit

Permalink
speed up reference resize kernel (#8592)
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart authored Jul 30, 2021
1 parent df96cba commit 22c7d61
Showing 1 changed file with 41 additions and 59 deletions.
100 changes: 41 additions & 59 deletions python/tvm/topi/testing/resize_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,51 +66,52 @@ def resize3d_nearest(arr, scale, coordinate_transformation_mode):

def resize3d_linear(data_in, scale, coordinate_transformation_mode):
"""Trilinear 3d scaling using python"""
dtype = data_in.dtype
d, h, w = data_in.shape
new_d, new_h, new_w = [int(round(i * s)) for i, s in zip(data_in.shape, scale)]
data_out = np.ones((new_d, new_h, new_w))

def _lerp(A, B, t):
return A * (1.0 - t) + B * t
indexes = np.mgrid[0:2, 0:2, 0:2]

def _in_coord(new_coord, in_shape, out_shape):
in_coord = get_inx(new_coord, in_shape, out_shape, coordinate_transformation_mode)
coord0 = int(math.floor(in_coord))
coord1 = max(min(coord0 + 1, in_shape - 1), 0)
coord0 = max(coord0, 0)
coord_lerp = in_coord - math.floor(in_coord)
return coord0, coord1, coord_lerp
def _get_patch(zint, yint, xint):
# Get the surrounding values
indices = indexes.copy()
indices[0] = np.maximum(np.minimum(indexes[0] + zint, d - 1), 0)
indices[1] = np.maximum(np.minimum(indexes[1] + yint, h - 1), 0)
indices[2] = np.maximum(np.minimum(indexes[2] + xint, w - 1), 0)
p = data_in[indices[0], indices[1], indices[2]]
return p

for m in range(new_d):
for j in range(new_h):
for k in range(new_w):
z0, z1, z_lerp = _in_coord(m, d, new_d)
y0, y1, y_lerp = _in_coord(j, h, new_h)
x0, x1, x_lerp = _in_coord(k, w, new_w)

A0 = data_in[z0][y0][x0]
B0 = data_in[z0][y0][x1]
C0 = data_in[z0][y1][x0]
D0 = data_in[z0][y1][x1]
A1 = data_in[z1][y0][x0]
B1 = data_in[z1][y0][x1]
C1 = data_in[z1][y1][x0]
D1 = data_in[z1][y1][x1]

A = _lerp(A0, A1, z_lerp)
B = _lerp(B0, B1, z_lerp)
C = _lerp(C0, C1, z_lerp)
D = _lerp(D0, D1, z_lerp)
top = _lerp(A, B, x_lerp)
bottom = _lerp(C, D, x_lerp)

data_out[m][j][k] = np.float32(_lerp(top, bottom, y_lerp))
in_z = get_inx(m, d, new_d, coordinate_transformation_mode)
in_y = get_inx(j, h, new_h, coordinate_transformation_mode)
in_x = get_inx(k, w, new_w, coordinate_transformation_mode)
zint = math.floor(in_z)
zfract = in_z - math.floor(in_z)

yint = math.floor(in_y)
yfract = in_y - math.floor(in_y)

xint = math.floor(in_x)
xfract = in_x - math.floor(in_x)

wz = np.array([1.0 - zfract, zfract], dtype=dtype)
wy = np.array([1.0 - yfract, yfract], dtype=dtype)
wx = np.array([1.0 - xfract, xfract], dtype=dtype)

p = _get_patch(zint, yint, xint)
l = np.sum(p * wx, axis=-1)
col = np.sum(l * wy, axis=-1)
data_out[m, j, k] = np.sum(col * wz)

return data_out


def resize3d_cubic(data_in, scale, coordinate_transformation_mode):
"""Tricubic 3d scaling using python"""
dtype = data_in.dtype
d, h, w = data_in.shape
new_d, new_h, new_w = [int(round(i * s)) for i, s in zip(data_in.shape, scale)]
data_out = np.ones((new_d, new_h, new_w))
Expand All @@ -123,29 +124,17 @@ def _cubic_spline_weights(t, alpha=-0.5):
w2 = (alpha + 2) * t3 - (3 + alpha) * t2 + 1
w3 = -(alpha + 2) * t3 + (3 + 2 * alpha) * t2 - alpha * t
w4 = -alpha * t3 + alpha * t2
return [w1, w2, w3, w4]
return np.array([w1, w2, w3, w4])

def _cubic_kernel(inputs, w):
"""perform cubic interpolation in 1D"""
return sum([a_i * w_i for a_i, w_i in zip(inputs, w)])

def _get_input_value(z, y, x):
z = max(min(z, d - 1), 0)
y = max(min(y, h - 1), 0)
x = max(min(x, w - 1), 0)
return data_in[z][y][x]
indexes = np.mgrid[-1:3, -1:3, -1:3]

def _get_patch(zint, yint, xint):
# Get the surrounding values
p = [[[0 for i in range(4)] for j in range(4)] for k in range(4)]
for kk in range(4):
for jj in range(4):
for ii in range(4):
p[kk][jj][ii] = _get_input_value(
zint + kk - 1,
yint + jj - 1,
xint + ii - 1,
)
indices = indexes.copy()
indices[0] = np.maximum(np.minimum(indexes[0] + zint, d - 1), 0)
indices[1] = np.maximum(np.minimum(indexes[1] + yint, h - 1), 0)
indices[2] = np.maximum(np.minimum(indexes[2] + xint, w - 1), 0)
p = data_in[indices[0], indices[1], indices[2]]
return p

for m in range(new_d):
Expand All @@ -169,16 +158,9 @@ def _get_patch(zint, yint, xint):

p = _get_patch(zint, yint, xint)

l = [[0 for i in range(4)] for j in range(4)]
for jj in range(4):
for ii in range(4):
l[jj][ii] = _cubic_kernel(p[jj][ii], wx)

col0 = _cubic_kernel(l[0], wy)
col1 = _cubic_kernel(l[1], wy)
col2 = _cubic_kernel(l[2], wy)
col3 = _cubic_kernel(l[3], wy)
data_out[m][j][k] = _cubic_kernel([col0, col1, col2, col3], wz)
l = np.sum(p * wx, axis=-1)
col = np.sum(l * wy, axis=-1)
data_out[m, j, k] = np.sum(col * wz)

return data_out

Expand Down

0 comments on commit 22c7d61

Please sign in to comment.