[topi] Add arm_cpu specific pooling schedules (apache#14855)
This commit:
* Adds specialized `arm_cpu` pooling schedules for both fixed-width and
  scalable vectors (see the usage sketch below).
* Enables topi testing of the new `arm_cpu` schedules.
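
As a usage sketch (not part of the commit): compiling a Relay NHWC max pool for an AArch64 `arm_cpu` target exercises the updated strategy and the new schedule. The shapes, pooling parameters, and target string below are illustrative assumptions.

import tvm
from tvm import relay

# Hypothetical shapes/parameters, chosen only for illustration.
data = relay.var("data", shape=(1, 28, 28, 32), dtype="float32")
out = relay.nn.max_pool2d(data, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), layout="NHWC")
mod = tvm.IRModule.from_expr(relay.Function([data], out))

# An AArch64 target without DSP should be routed to topi.arm_cpu.schedule_pool.
target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu")
lib = relay.build(mod, target=target)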
FranklandJack authored Jul 10, 2023
1 parent dc7125b commit 9f8fe3c
Showing 5 changed files with 116 additions and 25 deletions.
17 changes: 5 additions & 12 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -26,6 +26,7 @@
 from ....auto_scheduler import is_auto_scheduler_enabled
 from ....meta_schedule import is_meta_schedule_enabled
 from ....topi.generic import conv2d as conv2d_generic
+from ....topi.arm_cpu.mprofile import dsp
 from .. import op as _op
 from .generic import *

@@ -63,19 +64,11 @@ def concatenate_strategy_arm_cpu(attrs, inputs, out_type, target):
 def schedule_pool_arm_cpu(attrs, outs, target):
     """schedule pooling ops arm cpu"""
     layout = attrs.layout
-    avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
     with target:
-        if (
-            avg_pool
-            and target.features.has_dsp
-            and layout in ("NCW", "NCHW")
-            or not avg_pool
-            and target.features.has_dsp
-            and layout in ("NWC", "NHWC")
-        ):
-            return topi.arm_cpu.schedule_pool(outs, layout)
-        logger.warning("pool is not optimized for arm cpu.")
-        return topi.generic.schedule_pool(outs, layout)
+        if target.features.has_dsp:
+            is_avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
+            return dsp.pool.schedule_pool(outs, layout, is_avg_pool)
+        return topi.arm_cpu.schedule_pool(outs, layout)


 def _get_padding_width(padding):
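For reference, a rough sketch of the feature flag the strategy now branches on; the target strings are assumptions and the reported features depend on how TVM parses them:

from tvm.target import Target

# M-profile cores with the DSP extension are expected to report has_dsp=True
# and therefore take the mprofile DSP schedule.
print(Target("c -mcpu=cortex-m55").features.has_dsp)

# A plain AArch64 target is expected to report has_dsp=False and fall through
# to the new topi.arm_cpu.schedule_pool.
print(Target("llvm -mtriple=aarch64-linux-gnu").features.has_dsp)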
2 changes: 2 additions & 0 deletions python/tvm/topi/arm_cpu/mprofile/dsp/__init__.py
@@ -14,3 +14,5 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+"""Schedule for arm_cpu targets supporting DSP"""
+from .pool import schedule_pool
30 changes: 20 additions & 10 deletions python/tvm/topi/arm_cpu/mprofile/dsp/pool.py
@@ -20,18 +20,12 @@

 import tvm

-from tvm import te
+from tvm import te, topi
 from tvm.topi.utils import traverse_inline

-from .micro_kernel.max_pool import (
-    intrin_max,
-    max_impl,
-)
+from .micro_kernel.max_pool import intrin_max, max_impl

-from .micro_kernel.avg_pool import (
-    intrin_sum,
-    sum_impl,
-)
+from .micro_kernel.avg_pool import intrin_sum, sum_impl

 logger = logging.getLogger("topi")

@@ -100,8 +94,24 @@ def schedule_avgpool_2d_nchw(s, op):
     s[output].pragma(n, "import_c", sum_impl(pool_w, uniq_id))


-def pool_dsp_schedule(outs, layout):
+def schedule_pool(outs, layout, is_avg_pool):
     """Schedule function for v7e-m DSP instructions of pooling."""
+
+    if is_avg_pool and layout not in ["NCW", "NCHW"]:
+        logger.warning(
+            "avg pool is only supported for NCW and NCHW layouts on "
+            "DSP-enabled targets, falling back on the generic pool "
+            "implementation"
+        )
+        return topi.generic.schedule_pool(outs, layout)
+    elif not is_avg_pool and layout not in ["NWC", "NHWC"]:
+        logger.warning(
+            "max pool is only supported for NWC and NHWC layouts on "
+            "DSP-enabled targets, falling back on the generic pool "
+            "implementation"
+        )
+        return topi.generic.schedule_pool(outs, layout)
+
     s = te.create_schedule([x.op for x in outs])

     def _callback(op):
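A small sketch of the new layout guard, with assumed shapes, parameters, and target string: an average pool in NHWC is outside the supported NCW/NCHW layouts on a DSP target, so the call is expected to warn and return the generic schedule rather than the micro-kernel one.

import tvm
from tvm import te, topi
from tvm.topi.arm_cpu.mprofile.dsp import pool as dsp_pool

data = te.placeholder((1, 16, 16, 8), name="data", dtype="float32")
out = topi.nn.pool2d(
    data, kernel=(2, 2), stride=(2, 2), dilation=(1, 1),
    padding=(0, 0, 0, 0), pool_type="avg", layout="NHWC",
)
with tvm.target.Target("c -mcpu=cortex-m55"):
    # Average pool + NHWC is not on the DSP fast path, so this is expected to
    # log a warning and fall back to the generic schedule.
    s = dsp_pool.schedule_pool([out], "NHWC", is_avg_pool=True)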
91 changes: 88 additions & 3 deletions python/tvm/topi/arm_cpu/pooling.py
@@ -17,9 +17,94 @@
 # pylint: disable=invalid-name, unused-variable
 """Schedule for pooling operators"""

-from .mprofile.dsp.pool import pool_dsp_schedule
+import logging
+from tvm import topi, te
+from tvm.target import Target
+from .. import tag


 def schedule_pool(outs, layout):
-    """Create schedule for avgpool/maxpool with dsp"""
-    return pool_dsp_schedule(outs, layout)
+    """Create schedule for avgpool/maxpool"""
+
+    if layout != "NHWC":
+        logger = logging.getLogger("topi")
+        logger.warning(
+            """We currently only support NHWC target-specific pools on arm_cpu,
+            falling back on generic pool scheduling"""
+        )
+        return topi.generic.schedule_pool(outs, layout)
+
+    return schedule_pool_2d(outs)
+
+
+def schedule_pool_2d(outs):
+    """Create arm_cpu specific 2D schedule for avgpool/maxpool"""
+
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    schedule_ops = [x.op for x in outs]
+    schedule = te.create_schedule(schedule_ops)
+    scheduled_ops = []
+
+    def traverse(op):
+        # Recursively inline any injective operation that isn't the pooling
+        # operation or hasn't already been scheduled.
+        if tag.is_injective(op.tag):
+            if op not in schedule.outputs:
+                schedule[op].compute_inline()
+            for tensor in op.input_tensors:
+                if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
+                    traverse(tensor.op)
+        # Schedule the actual pooling operation.
+        elif op.tag.startswith("pool"):
+            n, height, width, channel = schedule[op].op.axis
+            # Average pool consists of two parts: a sum, then a division.
+            # We can schedule the division loop to parallelize across height and
+            # vectorize across width.
+            enable_explicit_vectorization = not Target.current(allow_none=False).features.has_sve
+            if op != outs[0].op:
+                output = outs[0]
+                output_fused = schedule[output].fuse(output.op.axis[1], output.op.axis[2])
+                schedule[output].parallel(output_fused)
+                vectorization_factor = (
+                    8 if enable_explicit_vectorization else output.op.axis[3].dom.extent
+                )
+                _, inner = schedule[output].split(output.op.axis[3], vectorization_factor)
+                schedule[output].vectorize(inner)
+
+            padded_input = op.input_tensors[0]
+            if isinstance(padded_input.op, te.tensor.ComputeOp):
+                schedule[padded_input].compute_inline()
+
+            # For targets without SVE, try explicitly vectorizing the channel
+            # loop. For SVE targets, leave the loop in place for LLVM to convert
+            # into a scalable vector loop.
+            vectorization_factor = 8 if enable_explicit_vectorization else channel.dom.extent
+            channel_outer, channel_inner = schedule[op].split(channel, vectorization_factor)
+            schedule[op].vectorize(channel_inner)
+            schedule[op].parallel(height)
+            if len(schedule[op].op.reduce_axis) > 0:
+                filter_height, filter_width = schedule[op].op.reduce_axis
+                # We consider any filter of area < 10 to be small enough to
+                # unroll; 3x3 filters have shown better performance when
+                # unrolled.
+                if filter_height.dom.extent * filter_width.dom.extent <= 9:
+                    # For small filters, unrolling the filter loops allows us to
+                    # vectorize over channels without reordering anything.
+                    schedule[op].unroll(filter_width)
+                    schedule[op].unroll(filter_height)
+                else:
+                    # Reordering so that channels is the fastest-moving axis allows
+                    # LLVM to vectorize across contiguous memory in the NHWC
+                    # ordering.
+                    schedule[op].reorder(
+                        n, height, width, filter_height, filter_width, channel_outer, channel_inner
+                    )
+            else:
+                schedule[op].reorder(n, height, width, channel_outer, channel_inner)
+        else:
+            raise RuntimeError("Unsupported operator: %s" % op.tag)
+
+        scheduled_ops.append(op)
+
+    traverse(outs[0].op)
+    return schedule
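
A sketch of inspecting what the new NHWC schedule produces, with assumed shapes and target string; on a non-SVE target the lowered TIR is expected to show the channel loop split by 8 with the inner part vectorized:

import tvm
from tvm import te, topi

data = te.placeholder((1, 56, 56, 64), name="data", dtype="float32")
out = topi.nn.pool2d(
    data, kernel=(3, 3), stride=(1, 1), dilation=(1, 1),
    padding=(1, 1, 1, 1), pool_type="max", layout="NHWC",
)
# schedule_pool_2d reads the current target to decide on explicit vectorization,
# so the schedule must be created inside a target context.
with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu"):
    s = topi.arm_cpu.schedule_pool([out], "NHWC")
print(tvm.lower(s, [data, out], simple_mode=True))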
1 change: 1 addition & 0 deletions tests/python/topi/python/test_topi_pooling.py
@@ -28,6 +28,7 @@

 _pool_schedule = {
     "generic": topi.generic.schedule_pool,
+    "arm_cpu": topi.arm_cpu.schedule_pool,
     "cpu": topi.x86.schedule_pool,
     "gpu": topi.cuda.schedule_pool,
     "hls": topi.hls.schedule_pool,
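For reference, a sketch of how the test-side dispatch is expected to resolve the new entry, mirroring the pattern already used in test_topi_pooling.py; the target string and the trimmed dispatch table are assumptions:

import tvm
import tvm.topi.testing
from tvm import topi

# Local copy of the dispatch table from the test, trimmed for illustration.
pool_schedules = {
    "generic": topi.generic.schedule_pool,
    "arm_cpu": topi.arm_cpu.schedule_pool,
    "cpu": topi.x86.schedule_pool,
}

# AArch64 llvm targets carry the "arm_cpu" key, so this is expected to return
# topi.arm_cpu.schedule_pool.
schedule_fn = tvm.topi.testing.dispatch("llvm -mtriple=aarch64-linux-gnu", pool_schedules)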
