From 97cfa8892cd66a50d695a4934c02e81f980d0afc Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Tue, 19 Nov 2019 13:56:51 -0800 Subject: [PATCH] [PERF] Parallelize reduction for CPU (#4158) * [PERF] parallel reduction in cpu * fix * x * update * lint * fix --- topi/python/topi/x86/__init__.py | 1 + topi/python/topi/x86/reduction.py | 120 ++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+) create mode 100644 topi/python/topi/x86/reduction.py diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py index 2e5e512c97af..6e41a1709f9c 100644 --- a/topi/python/topi/x86/__init__.py +++ b/topi/python/topi/x86/__init__.py @@ -25,6 +25,7 @@ from .nn import * from .conv2d_int8 import * from .injective import * +from .reduction import * from .pooling import schedule_pool, schedule_adaptive_pool from .bitserial_conv2d import schedule_bitserial_conv2d from .bitserial_dense import schedule_bitserial_dense diff --git a/topi/python/topi/x86/reduction.py b/topi/python/topi/x86/reduction.py new file mode 100644 index 000000000000..41e8e9ecb56b --- /dev/null +++ b/topi/python/topi/x86/reduction.py @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""x86 declaration and schedules.""" +from __future__ import absolute_import as _abs +import tvm +from .. import tag +from .. import generic +from ..util import get_const_tuple + +def _schedule_reduce(sch, op, is_idx_reduce=False): + if is_idx_reduce: + real_out = op.output(0) + fused = sch[real_out].fuse(*sch[real_out].op.axis) + out = op.input_tensors[0] + else: + out = op.output(0) + + const_shape = True + out_shape = get_const_tuple(out.shape) + for d in out_shape: + if not isinstance(d, int): + const_shape = False + break + + if const_shape: + naxes = len(sch[out].op.axis) + parallelism = 1 + fuse_axes = [] + # We choose a heuristic number 128 to limit the maximum parallelism + while len(fuse_axes) < naxes and parallelism < 128: + ivar = sch[out].op.axis[len(fuse_axes)] + parallelism *= int(ivar.dom.extent) + fuse_axes.append(ivar) + fused = sch[out].fuse(*fuse_axes) + sch[out].parallel(fused) + else: + if len(sch[out].op.axis) >= 5: + # avoid too many parallelism + fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1], sch[out].op.axis[2]) + sch[out].parallel(fused) + else: + fused = sch[out].fuse(*sch[out].op.axis) + sch[out].parallel(fused) + + +@generic.schedule_reduce.register(["cpu"]) +def schedule_reduce(outs): + """X86 schedule for reduction op. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of injective in the format + of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + sch = tvm.create_schedule([x.op for x in outs]) + scheduled_ops = [] + + def traverse_before_reduce(operator): + """Internal travserse function""" + if isinstance(operator, tvm.tensor.PlaceholderOp): + return + if tag.is_injective(operator.tag): + sch[operator].compute_inline() + for tensor in operator.input_tensors: + if tensor.op not in scheduled_ops: + traverse_before_reduce(tensor.op) + else: + raise RuntimeError("Unsupported operator: %s" % operator.tag) + + scheduled_ops.append(operator) + + def traverse_after_reduce(operator): + """Internal travserse function""" + if tag.is_broadcast(operator.tag): + if operator not in scheduled_ops: + generic.schedule_injective_from_existing(sch, operator) + for tensor in operator.input_tensors: + traverse_after_reduce(tensor.op) + elif operator.tag == 'comm_reduce': + _schedule_reduce(sch, operator, is_idx_reduce=False) + for tensor in operator.input_tensors: + if tensor.op not in scheduled_ops: + traverse_before_reduce(tensor.op) + elif operator.tag == 'comm_reduce_idx': + _schedule_reduce(sch, operator, is_idx_reduce=True) + input_tensors = operator.input_tensors[0].op.input_tensors + for tensor in input_tensors: + if tensor.op not in scheduled_ops: + traverse_before_reduce(tensor.op) + elif isinstance(operator, tvm.tensor.PlaceholderOp): + pass + else: + raise RuntimeError("Unsupported operator: %s (tag: %s)" % (operator, operator.tag)) + + scheduled_ops.append(operator) + + traverse_after_reduce(outs[0].op) + return sch