Skip to content

Commit

Permalink
[Relay][Legalize][ARM_CPU] Handling NHWC layout for arm_cpu.
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Aug 12, 2019
1 parent 3ac27fc commit e5f278e
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 2 deletions.
5 changes: 3 additions & 2 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,11 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
from ... import op
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)

# A placeholder to have at least one invocation of register legalize to register FTVMLegalize.
@reg.register_legalize("nn.conv2d")
def legalize_conv2d(attrs, inputs, arg_dtypes):
return None
"""Legalize conv2d"""
from ... import op
return topi.nn.conv2d_legalize(attrs, inputs, arg_dtypes, op)

reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)

Expand Down
46 changes: 46 additions & 0 deletions tests/python/relay/test_pass_legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
# specific language governing permissions and limitations
# under the License.
"""Test legalize pass"""
import numpy as np
import tvm

from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.op import register_legalize
from tvm.relay import transform, analysis

Expand Down Expand Up @@ -123,8 +125,52 @@ def expected():

assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)

def test_legalize_arm_layout_functional():
"""Test if the legalized conversion yields same result as original"""
def get_output(func, data_val, parameters):
with relay.build_config(opt_level=0):
graph, lib, params = relay.build(func, target='llvm', params=parameters)
m = graph_runtime.create(graph, lib, tvm.cpu())
m.set_input("data", data_val)
m.set_input(**params)
m.run()
out = m.get_output(0, tvm.nd.empty((1, 224, 224, 32), 'float32')).asnumpy()
return out

def before():
n, ic, ih, iw, oc, kh, kw = 1, 16, 224, 224, 32, 3, 3
data = relay.var("data", relay.TensorType((n, ih, iw, ic), 'float32'))
kernel = relay.var("kernel", relay.TensorType((kh, kw, ic, oc), 'float32'))
y = relay.nn.conv2d(data, kernel,
kernel_size=(kh, kw),
channels=oc,
padding=(1, 1),
dilation=(1, 1),
data_layout='NHWC',
kernel_layout='HWIO',
out_dtype='float32')
func = relay.Function([data, kernel], y)
return func

@register_legalize("nn.conv2d", level=101)
def legalize_conv2d(attrs, inputs, arg_types):
from topi.arm_cpu.conv2d import _conv2d_legalize
return _conv2d_legalize(attrs, inputs, arg_types, tvm.relay.op)

a = before()
b = run_opt_pass(a, transform.Legalize())
assert b.astext().count('transpose') == 3

wdata = np.random.rand(3, 3, 16, 32) * 10
parameters = {"kernel": tvm.nd.array(wdata.astype('float32'))}
data_val = np.random.rand(1, 224, 224, 16).astype('float32')
ref_out = get_output(a, data_val, parameters)
legalized_out = get_output(b, data_val, parameters)
np.testing.assert_allclose(ref_out, legalized_out, rtol=0.01)


if __name__ == "__main__":
test_legalize()
test_legalize_none()
test_legalize_multi_input()
test_legalize_arm_layout_functional()
29 changes: 29 additions & 0 deletions topi/python/topi/arm_cpu/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
conv2d_winograd_without_weight_transform, \
conv2d_winograd_nnpack_without_weight_transform, \
depthwise_conv2d_nchw
from ..nn import conv2d_legalize
from ..nn.util import get_const_int, get_pad_tuple
from ..nn.winograd_util import winograd_transform_matrices

Expand Down Expand Up @@ -783,3 +784,31 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
# currently we only have contrib_spatial_pack and direct template
# add more schedule templates.
return None

@conv2d_legalize.register("arm_cpu")
def _conv2d_legalize(attrs, inputs, arg_types, F):
if F.__name__ != 'tvm.relay.op':
return None
if attrs['data_layout'] == 'NHWC':
data, kernel = inputs
if attrs['kernel_layout'] == 'HWIO':
# Handle HWIO layout. This is common in TF graph.
kernel = F.transpose(kernel, axes=(3, 2, 0, 1))
elif attrs['kernel_layout'] == 'HWOI':
# Handle HWOI layout. This is common in TF depthwise conv2d graph.
kernel = F.transpose(kernel, axes=(2, 3, 0, 1))
elif attrs['kernel_layout'] != 'OIHW':
return None

# Set new attrs for the tranposed conv.
new_attrs = {k: attrs[k] for k in attrs.keys()}
new_attrs['data_layout'] = 'NCHW'
new_attrs['kernel_layout'] = 'OIHW'

# Convert from NHWC to NCHW.
data = F.transpose(data, axes=(0, 3, 1, 2))
conv = F.nn.conv2d(data, kernel, **new_attrs)
# Convert back to original NHWC layout.
out = F.transpose(conv, axes=(0, 2, 3, 1))
return out
return None
22 changes: 22 additions & 0 deletions topi/python/topi/nn/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,28 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
raise ValueError("not support this layout {} yet".format(layout))


@tvm.target.generic_func
def conv2d_legalize(attrs, inputs, arg_dtypes, F):
"""Legalizes Conv2D op.
Parameters
----------
attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized.
arg_dtypes : list of types
List of types of input arguments
F: symbol
The context, can be either nnvm.sym or relay.op
Note
----
Unlike other TOPI functions, this function operates on both graph level and operator level,
so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
"""
# not to change by default
return None


@tvm.target.generic_func
def conv2d_alter_layout(attrs, inputs, tinfos, F):
"""Change Conv2D layout.
Expand Down

0 comments on commit e5f278e

Please sign in to comment.