Skip to content

Commit

Permalink
[Relay][ConvertLayout] Support for qnn.conv2d_transpose
Browse files Browse the repository at this point in the history
  • Loading branch information
abraham-arun committed Sep 28, 2021
1 parent 9df2ae8 commit 89deb54
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 0 deletions.
48 changes: 48 additions & 0 deletions python/tvm/relay/qnn/op/layout_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,51 @@ def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layouts):
return relay.qnn.op.conv2d(*inputs, **new_attrs)

raise ValueError("Layout %s is not yet supported" % desired_data_layout)


@reg.register_convert_op_layout("qnn.conv2d_transpose")
def convert_qnn_conv2d_transpose(attrs, inputs, tinfos, desired_layouts):
"""Convert Layout pass registration for QNN conv2d_transpose op.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
tinfos : list of types
List of input and output types
desired_layouts : list of layout strings
List of layouts defining our desired
layout for the data and kernel inputs respectively.
Returns
-------
result : tvm.relay.Expr
The transformed expr
"""
# pylint: disable=import-outside-toplevel
from tvm import relay

assert (
len(desired_layouts) == 2
), "A desired layout is expected for both of qnn.conv2d_transpose's inputs"
desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
assert desired_data_layout != "default", "Data layout cannot be default"

new_attrs = dict(attrs)
new_attrs["data_layout"] = desired_data_layout

if desired_kernel_layout != "default":
new_attrs["kernel_layout"] = desired_kernel_layout
return relay.qnn.op.conv2d_transpose(*inputs, **new_attrs)

# Handle default kernel layouts
if desired_data_layout == "NCHW":
new_attrs["kernel_layout"] = "OIHW"
return relay.qnn.op.conv2d_transpose(*inputs, **new_attrs)
if desired_data_layout == "NHWC":
new_attrs["kernel_layout"] = "HWIO"
return relay.qnn.op.conv2d_transpose(*inputs, **new_attrs)

raise ValueError("Layout %s is not yet supported" % desired_data_layout)
68 changes: 68 additions & 0 deletions tests/python/relay/test_pass_convert_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,74 @@ def expected():
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


def test_qnn_conv_transpose_requantize_convert_layout():
def before():
x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8")
y = relay.qnn.op.conv2d_transpose(
x,
weight,
relay.const(1, "int32"),
relay.const(1, "int32"),
relay.const(1, "float32"),
relay.const(1, "float32"),
channels=64,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
out_dtype="int32",
)
y = relay.qnn.op.requantize(
y,
relay.const(1, "float32"),
relay.const(1, "int32"),
relay.const(1, "float32"),
relay.const(1, "int32"),
out_dtype="int32",
)
y = relay.nn.relu(y)
y = relay.Function([x, weight], y)
return y

def expected():
x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8")
x = relay.layout_transform(x, "NHWC", "NCHW")
weight = relay.layout_transform(weight, "HWIO", "OIHW")
y = relay.qnn.op.conv2d_transpose(
x,
weight,
relay.const(1, "int32"),
relay.const(1, "int32"),
relay.const(1, "float32"),
relay.const(1, "float32"),
channels=64,
kernel_size=(3, 3),
padding=(1, 1),
out_dtype="int32",
)
y = relay.qnn.op.requantize(
y,
relay.const(1, "float32"),
relay.const(1, "int32"),
relay.const(1, "float32"),
relay.const(1, "int32"),
axis=1,
out_dtype="int32",
)
y = relay.nn.relu(y)
y = relay.layout_transform(y, "NCHW", "NHWC")
y = relay.Function(relay.analysis.free_vars(y), y)
return y

a = before()
a = run_opt_pass(a, transform.ConvertLayout({"qnn.conv2d_transpose": ["NCHW", "default"]}))
b = run_opt_pass(expected(), transform.InferType())

assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


def test_conv_convert_kernel_layout():
"""Check that convolution kernel layout is correctly transformed."""

Expand Down

0 comments on commit 89deb54

Please sign in to comment.