Skip to content

Commit

Permalink
add channel-last test
Browse files Browse the repository at this point in the history
  • Loading branch information
mbrookhart committed Feb 26, 2021
1 parent f84ea99 commit 6f62717
Showing 1 changed file with 38 additions and 10 deletions.
48 changes: 38 additions & 10 deletions tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,22 +127,40 @@ def after_right(x, elem_op, value):
def test_simplify_conv_pad():
convs = [relay.nn.conv1d, relay.nn.conv2d, relay.nn.conv3d]

def validate(ndim, pad_width, pad_value, pad_mode, orig_padding):
shape = [1, 3] + [10] * ndim
wshape = [8, 3] + [3] * ndim
def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout):
if layout[1] == "C":
shape = [1, 3] + [10] * ndim
wshape = [8, 3] + [3] * ndim
elif layout[-1] == "C":
shape = [1] + [10] * ndim + [3]
wshape = [8] + [3] * ndim + [3]
else:
raise ValueError("This test only supports NC* and N*C")

x = relay.var("x", shape=shape, dtype="float32")
w = relay.var("w", shape=wshape, dtype="float32")
pad = relay.nn.pad(x, pad_width, pad_value, pad_mode)
conv = convs[ndim - 1](pad, w, padding=orig_padding)
if layout[1] == "C":
conv = convs[ndim - 1](pad, w, padding=orig_padding)
else:
conv = convs[ndim - 1](
pad, w, padding=orig_padding, data_layout=layout, kernel_layout="DHWIO"[3 - ndim :]
)

if pad_mode == "constant" and pad_value == 0:
new_padding = []
for j in range(2):
for i in range(2, len(pad_width)):
new_padding.append(pad_width[i][j])
for i in range(len(pad_width)):
if layout[i] in ["D", "H", "W"]:
new_padding.append(pad_width[i][j])
for i in range(len(new_padding)):
new_padding[i] += orig_padding[i]
after = convs[ndim - 1](x, w, padding=new_padding)
if layout[1] == "C":
after = convs[ndim - 1](x, w, padding=new_padding)
else:
after = convs[ndim - 1](
x, w, padding=new_padding, data_layout=layout, kernel_layout="DHWIO"[3 - ndim :]
)
else:
after = conv

Expand All @@ -166,10 +184,20 @@ def validate(ndim, pad_width, pad_value, pad_mode, orig_padding):
for orig_pad in [[0, 0], [2, 0], [0, 2]]:
for i_pad in [[0, 0], [1, 1], [1, 0]]:
for ndim in [1, 2, 3]:
validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 0, "constant", orig_pad * ndim)
for channels_last in [0, 1]:
if channels_last:
layout = "NDHWC"
layout = layout[0:1] + layout[4 - ndim : 4] + layout[-1:]
padding = [[0, 0]] + [i_pad] * ndim + [[0, 0]]
else:
layout = "NCDHW"
layout = layout[0:2] + layout[5 - ndim :]
padding = [[0, 0]] * 2 + [i_pad] * ndim

validate(ndim, padding, 0, "constant", orig_pad * ndim, layout)
ndim = 2
validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 1, "constant", orig_pad * ndim)
validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 0, "edge", orig_pad * ndim)
validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 1, "constant", orig_pad * ndim, "NCHW")
validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 0, "edge", orig_pad * ndim, "NCHW")


if __name__ == "__main__":
Expand Down

0 comments on commit 6f62717

Please sign in to comment.