[Relay] Add support of conv2d with NHWC for Mali (#8422)
* [Relay] Add support of conv2d with NHWC for Mali

Added a template schedule for conv2d NHWC, reusing a strategy similar to the one used for the NCHW layout. The schedule is also registered in the corresponding test, which can be run to verify correctness.

* [Relay] Fix pylint issue in conv2d for Mali
AnastasiaStulova authored Jul 9, 2021
1 parent e3e03df commit e7c5349
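For context, here is a minimal sketch of what this change enables: building a Relay conv2d with NHWC layout for a Mali GPU without the auto-scheduler, which previously raised a RuntimeError (see the mali.py diff below). The shapes and target strings are illustrative assumptions, not part of the commit.

```python
import tvm
from tvm import relay

# NHWC data and HWIO kernel, matching the layout/kernel_layout pair
# asserted by the Mali strategy.
data = relay.var("data", shape=(1, 56, 56, 64), dtype="float32")
weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="float32")
conv = relay.nn.conv2d(
    data,
    weight,
    kernel_size=(3, 3),
    padding=(1, 1),
    data_layout="NHWC",
    kernel_layout="HWIO",
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

# Assumed target strings for a Mali board cross-compiled from an x86 host.
target = tvm.target.Target("opencl -device=mali", host="llvm -mtriple=aarch64-linux-gnu")
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)  # now selects conv2d_nhwc_spatial_pack.mali
```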
Showing 4 changed files with 130 additions and 70 deletions.
57 changes: 30 additions & 27 deletions python/tvm/relay/op/strategy/mali.py
```diff
@@ -73,36 +73,39 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target):
         elif layout == "NHWC":
             assert kernel_layout == "HWIO"
             if not is_auto_scheduler_enabled():
-                raise RuntimeError(
-                    "conv2d NHWC layout is not enabled for mali without auto_scheduler."
-                )
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True),
-                naive_schedule,
-                name="conv2d_nhwc.mali",
-            )
-            is_winograd_applicable = False
-            if len(kernel.shape) == 4:
-                kernel_h, kernel_w, _, _ = get_const_tuple(kernel.shape)
-                is_winograd_applicable = (
-                    "float" in data.dtype
-                    and "float" in kernel.dtype
-                    and kernel_h == 3
-                    and kernel_w == 3
-                    and stride_h == 1
-                    and stride_w == 1
-                    and dilation_h == 1
-                    and dilation_w == 1
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.mali.conv2d_nhwc_spatial_pack),
+                    wrap_topi_schedule(topi.mali.schedule_conv2d_nhwc_spatial_pack),
+                    name="conv2d_nhwc_spatial_pack.mali",
                 )
-            if is_winograd_applicable:
+            else:
                 strategy.add_implementation(
-                    wrap_compute_conv2d(
-                        topi.nn.conv2d_winograd_nhwc, need_auto_scheduler_layout=True
-                    ),
-                    naive_schedule,  # this implementation should never be picked by autotvm
-                    name="conv2d_nhwc.winograd",
-                    plevel=15,
+                    wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True),
+                    naive_schedule,
+                    name="conv2d_nhwc.mali",
                 )
+                is_winograd_applicable = False
+                if len(kernel.shape) == 4:
+                    kernel_h, kernel_w, _, _ = get_const_tuple(kernel.shape)
+                    is_winograd_applicable = (
+                        "float" in data.dtype
+                        and "float" in kernel.dtype
+                        and kernel_h == 3
+                        and kernel_w == 3
+                        and stride_h == 1
+                        and stride_w == 1
+                        and dilation_h == 1
+                        and dilation_w == 1
+                    )
+                if is_winograd_applicable:
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(
+                            topi.nn.conv2d_winograd_nhwc, need_auto_scheduler_layout=True
+                        ),
+                        naive_schedule,  # this implementation should never be picked by autotvm
+                        name="conv2d_nhwc.winograd",
+                        plevel=15,
+                    )
 
         else:
             raise RuntimeError("Unsupported conv2d layout {} for mali".format(layout))
```
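With the template registered as `conv2d_nhwc_spatial_pack.mali`, the standard AutoTVM flow can tune it. A hedged sketch, reusing the `mod` and `target` from the example above and assuming a running RPC tracker with a board registered under the device key `"mali"` (the key, ports, and log file name are assumptions):

```python
from tvm import autotvm

# Extract tunable tasks; NHWC conv2d workloads now yield
# conv2d_nhwc_spatial_pack.mali tasks on this target.
tasks = autotvm.task.extract_from_program(mod["main"], params={}, target=target)

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.RPCRunner("mali", host="127.0.0.1", port=9190),
)
for task in tasks:
    tuner = autotvm.tuner.XGBTuner(task)
    tuner.tune(
        n_trial=min(1000, len(task.config_space)),
        measure_option=measure_option,
        callbacks=[autotvm.callback.log_to_file("conv2d_nhwc_mali.log")],
    )
```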
15 changes: 11 additions & 4 deletions python/tvm/topi/arm_cpu/conv2d_spatial_pack.py
```diff
@@ -247,7 +247,7 @@ def schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec, conv, output
     return s
 
 
-def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
+def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2):
     """Spatial pack compute for Conv2d NHWC"""
     out_dtype = out_dtype or data.dtype
 
@@ -276,9 +276,16 @@ def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
     n, oc, oh, ow = cfg.axis(N), cfg.axis(OC), cfg.axis(OH), cfg.axis(OW)
     ic, kh, kw = cfg.reduce_axis(IC), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
 
-    oco, oci = cfg.define_split("tile_co", oc, num_outputs=2)
-    oho, ohi = cfg.define_split("tile_oh", oh, num_outputs=2)
-    owo, owi = cfg.define_split("tile_ow", ow, num_outputs=2)
+    if num_tile == 2:  # for arm cpu
+        oco, oci = cfg.define_split("tile_co", oc, num_outputs=2)
+        oho, ohi = cfg.define_split("tile_oh", oh, num_outputs=2)
+        owo, owi = cfg.define_split("tile_ow", ow, num_outputs=2)
+    elif num_tile == 3:  # for mali gpu
+        oco, _, oci = cfg.define_split("tile_co", oc, num_outputs=3)
+        oho, _, ohi = cfg.define_split("tile_oh", oh, num_outputs=3)
+        owo, _, owi = cfg.define_split("tile_ow", ow, num_outputs=3)
+    else:
+        raise RuntimeError("Invalid num_tile")
 
     cfg.define_reorder(
         "reorder_conv",
```
124 changes: 85 additions & 39 deletions python/tvm/topi/mali/conv2d.py
```diff
@@ -30,6 +30,7 @@
 
 # reuse some compute declarations from ARM CPU
 from ..arm_cpu.conv2d_spatial_pack import conv2d_spatial_pack_nchw
+from ..arm_cpu.conv2d_spatial_pack import conv2d_spatial_pack_nhwc
 
 logger = logging.getLogger("topi")
 
@@ -95,59 +96,89 @@ def schedule_conv2d_nchw_spatial_pack(cfg, outs):
     def _callback(op):
         # schedule conv2d
         if "spatial_conv2d_output" in op.tag:
-            output = op.output(0)
-            conv = op.input_tensors[0]
-
-            data_vec = conv.op.input_tensors[0]
-            data_pad = data_vec.op.input_tensors[0]
-            s[data_pad].compute_inline()
-
-            kernel_vec = conv.op.input_tensors[1]
-            if kernel_vec.op.name == "kernel_vec":
-                kernel = kernel_vec.op.input_tensors[0]
-            else:
-                kernel = kernel_vec
-            if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
-                s[kernel].compute_inline()
-
-            _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec)
+            _schedule_spatial_pack(cfg, s, op, layout="NCHW")
 
     traverse_inline(s, outs[0].op, _callback)
     return s
 
 
-def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):
+@autotvm.register_topi_compute("conv2d_nhwc_spatial_pack.mali")
+def conv2d_nhwc_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """Compute conv2d with NHWC layout"""
+    return conv2d_spatial_pack_nhwc(
+        cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=3
+    )
+
+
+@autotvm.register_topi_schedule("conv2d_nhwc_spatial_pack.mali")
+def schedule_conv2d_nhwc_spatial_pack(cfg, outs):
+    """Create schedule for conv2d_nhwc"""
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        # schedule conv2d
+        if "spatial_conv_output_NHWC" in op.tag:
+            _schedule_spatial_pack(cfg, s, op, layout="NHWC")
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+def _schedule_spatial_pack(cfg, s, op, layout):
     """schedule the spatial packing for conv2d"""
+
+    assert layout in ("NCHW", "NHWC")
+
+    output = op.output(0)
+    conv = op.input_tensors[0]
+    data_vec = conv.op.input_tensors[0]
+    data_pad = data_vec.op.input_tensors[0]
+    s[data_pad].compute_inline()
+    kernel_vec = conv.op.input_tensors[1]
+    if kernel_vec.op.name == "kernel_vec":
+        kernel = kernel_vec.op.input_tensors[0]
+    else:
+        kernel = kernel_vec
+    if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
+        s[kernel].compute_inline()
     data = s[data_vec].op.input_tensors[0]
 
     max_unroll = 16
     vec_size = [1, 2, 4, 8, 16]
     # get tunable parameters (they are defined in compute)
-    BC, TC, VC = cfg["tile_co"].size
-    BH, TH, VH = cfg["tile_oh"].size
-    BW, TW, VW = cfg["tile_ow"].size
+    _, TC, VC = cfg["tile_co"].size
+    _, TH, VH = cfg["tile_oh"].size
+    _, TW, VW = cfg["tile_ow"].size
 
     # schedule padding
     if isinstance(data.op, tvm.te.ComputeOp) and "pad" in data.op.tag:
         data_pad = data
         s[data_pad].compute_inline()
 
     # schedule data packing
-    if isinstance(data_vec.op, tvm.te.ComputeOp) and data_vec.op.name == "data_vec_undilated":
-        _, h, w, ci, _, _, vh, vw = s[data_vec].op.axis
+    if layout == "NCHW":
+        if isinstance(data_vec.op, tvm.te.ComputeOp) and data_vec.op.name == "data_vec_undilated":
+            _, h, w, ci, _, _, vh, vw = s[data_vec].op.axis
+        else:
+            _, h, w, ci, vh, vw = s[data_vec].op.axis
+        z, y, x, unroll1, unroll2 = h, w, ci, vh, vw
     else:
-        _, h, w, ci, vh, vw = s[data_vec].op.axis
-    tile_and_bind3d(s, data_vec, h, w, ci, 1)
-    if vh.dom.extent.value < max_unroll:
-        s[data_vec].unroll(vh)
-    if vw.dom.extent.value < max_unroll:
-        s[data_vec].unroll(vw)
+        if isinstance(data_vec.op, tvm.te.ComputeOp) and data_vec.op.name == "data_vec_undilated":
+            _, oho, owo, _, _, ic, ohi, owi = s[data_vec].op.axis
+        else:
+            _, oho, owo, ohi, owi, ic = s[data_vec].op.axis
+        z, y, x, unroll1, unroll2 = oho, owo, ohi, ic, owi
+    tile_and_bind3d(s, data_vec, z, y, x, 1)
+    if unroll1.dom.extent.value < max_unroll:
+        s[data_vec].unroll(unroll1)
+    if unroll2.dom.extent.value < max_unroll:
+        s[data_vec].unroll(unroll2)
 
     if isinstance(kernel_vec.op, tvm.te.ComputeOp) and kernel_vec.name == "kernel_vec":
         if not autotvm.GLOBAL_SCOPE.in_tuning:
             max_threads = tvm.target.Target.current(allow_none=False).max_num_threads
-            co, ci, kh, kw, vc = s[kernel_vec].op.axis
-            fused = s[kernel_vec].fuse(co, ci, kh, kw, vc)
+            ax1, ax2, ax3, ax4, ax5 = s[kernel_vec].op.axis
+            fused = s[kernel_vec].fuse(ax1, ax2, ax3, ax4, ax5)
             fused, vec = s[kernel_vec].split(fused, VC)
             bb, tt = s[kernel_vec].split(fused, max_threads)
             s[kernel_vec].bind(bb, te.thread_axis("blockIdx.x"))
@@ -156,25 +187,37 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):
             s[kernel_vec].vectorize(vec)
 
     # schedule convolution
-    n, c, h, w, vh, vw, vc = s[conv].op.axis
-    kc, kh, kw = s[conv].op.reduce_axis
-
-    cfg["reorder_0"].apply(s, conv, [n, c, h, w, kc, kh, kw, vh, vw, vc])
-    tile_and_bind3d(s, conv, c, h, w, TC, TH, TW)
-
+    ic, kh, kw = s[conv].op.reduce_axis
+    if layout == "NCHW":
+        kh_dim, kw_dim = kernel_vec.shape[2], kernel_vec.shape[3]
+    else:
+        kh_dim, kw_dim = kernel_vec.shape[0], kernel_vec.shape[1]
     cfg["ann_reduce"].apply(
         s,
         conv,
         [kh, kw],
-        axis_lens=[get_const_int(kernel_vec.shape[2]), get_const_int(kernel_vec.shape[3])],
+        axis_lens=[get_const_int(kh_dim), get_const_int(kw_dim)],
         max_unroll=max_unroll,
     )
 
+    if layout == "NCHW":
+        n, c, h, w, vh, vw, vc = s[conv].op.axis
+        cfg["reorder_0"].apply(s, conv, [n, c, h, w, ic, kh, kw, vh, vw, vc])
+        tile_and_bind3d(s, conv, c, h, w, TC, TH, TW)
+        unroll_vec_axes = [vh, vw, vc]
+        axis_lens = [VH, VW, VC]
+    else:
+        n, oho, owo, oco, ohi, owi, oci = s[conv].op.axis
+        cfg["reorder_conv"].apply(s, conv, [n, oho, owo, oco, kh, kw, ic, ohi, owi, oci])
+        tile_and_bind3d(s, conv, oho, owo, oco, TH, TW, TC)
+        unroll_vec_axes = [ohi, owi, oci]
+        axis_lens = [VH, VW, VC]
+
     cfg["ann_spatial"].apply(
         s,
         conv,
-        [vh, vw, vc],
-        axis_lens=[VH, VW, VC],
+        unroll_vec_axes,
+        axis_lens,
         max_unroll=max_unroll,
         vec_size=vec_size,
         cfg=cfg,
@@ -184,9 +227,12 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):
     if output.op not in s.outputs:  # has bias
         s[output].compute_inline()
         output = s.outputs[0]
 
-    _, co, oh, ow = s[output].op.axis
-    tile_and_bind3d(s, output, co, oh, ow, TC, TH, TW)
+    if layout == "NCHW":
+        _, co, oh, ow = s[output].op.axis
+        tile_and_bind3d(s, output, co, oh, ow, TC, TH, TW)
+    else:
+        _, oh, ow, co = s[output].op.axis
+        tile_and_bind3d(s, output, oh, ow, co, TH, TW, TC)
 
     return s
```
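Both layouts funnel into `tile_and_bind3d`, which binds the three chosen tile levels to the OpenCL grid. The helper is defined earlier in `mali/conv2d.py` and is not part of this diff; the sketch below is an approximation for reference, not the verbatim source:

```python
def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None):
    """Tile three axes and bind the halves to the OpenCL grid (sketch)."""
    y_factor = y_factor or z_factor
    x_factor = x_factor or y_factor
    zo, zi = s[tensor].split(z, z_factor)
    yo, yi = s[tensor].split(y, y_factor)
    xo, xi = s[tensor].split(x, x_factor)
    # Outer halves become workgroups, inner halves become threads.
    s[tensor].bind(zo, te.thread_axis("blockIdx.z"))
    s[tensor].bind(zi, te.thread_axis("threadIdx.z"))
    s[tensor].bind(yo, te.thread_axis("blockIdx.y"))
    s[tensor].bind(yi, te.thread_axis("threadIdx.y"))
    s[tensor].bind(xo, te.thread_axis("blockIdx.x"))
    s[tensor].bind(xi, te.thread_axis("threadIdx.x"))
    return zo, zi, yo, yi, xo, xi
```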
4 changes: 4 additions & 0 deletions tests/python/topi/python/test_topi_conv2d_nhwc.py
```diff
@@ -34,6 +34,10 @@
         topi.arm_cpu.conv2d_nhwc_spatial_pack,
         topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack,
     ),
+    "mali": (
+        topi.mali.conv2d_nhwc_spatial_pack,
+        topi.mali.schedule_conv2d_nhwc_spatial_pack,
+    ),
     "hls": (topi.nn.conv2d_nhwc, topi.hls.schedule_conv2d_nhwc),
 }
 
```
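The table is consumed by `tvm.topi.testing.dispatch`, which matches the keys of the active target against the dict, so an OpenCL Mali target now resolves to the new spatial-pack implementation instead of the generic fallback. A sketch (the target string is an assumption):

```python
import tvm
import tvm.topi.testing

# "opencl -device=mali" carries the target key "mali", so dispatch
# selects the newly added entry from the table above.
target = tvm.target.Target("opencl -device=mali")
fcompute, fschedule = tvm.topi.testing.dispatch(target, _conv2d_nhwc_implement)
```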
