diff --git a/python/tvm/relay/op/strategy/__init__.py b/python/tvm/relay/op/strategy/__init__.py index 1be5425e702c..29b6b89a6131 100644 --- a/python/tvm/relay/op/strategy/__init__.py +++ b/python/tvm/relay/op/strategy/__init__.py @@ -30,3 +30,4 @@ from . import intel_graphics from . import hexagon from . import adreno +from . import riscv_cpu diff --git a/python/tvm/relay/op/strategy/riscv_cpu.py b/python/tvm/relay/op/strategy/riscv_cpu.py new file mode 100644 index 000000000000..7853d193757a --- /dev/null +++ b/python/tvm/relay/op/strategy/riscv_cpu.py @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import +"""Definition of RISCV CPU operator strategy.""" + +import logging + +from tvm import topi +from .. import op as _op +from .generic import * +from .x86 import conv2d_strategy_cpu + +logger = logging.getLogger("strategy") + + +@schedule_injective.register("riscv_cpu") +def schedule_injective_riscv_cpu(_, outs, target): + """schedule injective ops for riscv_cpu""" + with target: + return topi.riscv_cpu.schedule_injective(outs) + + +@schedule_reduce.register("riscv_cpu") +def schedule_reduce_riscv_cpu(_, outs, target): + """schedule reduction ops for riscv_cpu""" + with target: + return topi.x86.schedule_reduce(outs) + + +@conv2d_strategy.register("riscv_cpu") +def conv2d_strategy_riscv_cpu(attrs, inputs, out_type, target): + """conv2d riscv_cpu strategy""" + strategy = _op.OpStrategy() + data, kernel = inputs + dilation_h, dilation_w = attrs.get_int_tuple("dilation") + groups = attrs.groups + layout = attrs.data_layout + kernel_layout = attrs.kernel_layout + if dilation_h < 1 or dilation_w < 1: + raise ValueError("dilation should be positive value") + + if groups == 1: + if layout == "NCHW": + assert kernel_layout == "OIHW" + is_int8 = topi.riscv_cpu.is_int8_hw_support(data.dtype, kernel.dtype) + # Vector instructions with int8 show more performance at a larger size. + if is_int8 and kernel.shape[1] >= 128: + strategy.add_implementation( + wrap_compute_conv2d(topi.riscv_cpu.conv2d_nchw_int8), + wrap_topi_schedule(topi.riscv_cpu.schedule_conv2d_nchw_int8), + name="conv2d_nchw_int8.riscv", + plevel=15, + ) + return strategy + + return conv2d_strategy_cpu(attrs, inputs, out_type, target) + + +@conv2d_NCHWc_strategy.register("riscv_cpu") +def conv2d_NCHWc_strategy_riscv_cpu(attrs, inputs, out_type, target): + """conv2d_NCHWc adopted from x86""" + strategy = _op.OpStrategy() + data, kernel = inputs + is_int8 = topi.riscv_cpu.is_int8_hw_support(data.dtype, kernel.dtype) + # Vector instructions with int8 show more performance at a larger size. 
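+    # kernel.shape[1] is the input-channel (or channel-chunk) dimension; 128 is
+    # an empirical threshold from this patch below which the int8 packing
+    # overhead is not expected to pay off, so the x86 schedules below are used.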
+ if is_int8 and kernel.shape[1] >= 128: + strategy.add_implementation( + wrap_compute_conv2d( + topi.riscv_cpu.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True + ), + wrap_topi_schedule(topi.riscv_cpu.schedule_conv2d_NCHWc_int8), + name="conv2d_NCHWc_int8.riscv_cpu", + ) + else: + strategy.add_implementation( + wrap_compute_conv2d(topi.x86.conv2d_NCHWc, need_data_layout=True, need_out_layout=True), + wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc), + name="conv2d_NCHWc.x86", + ) + return strategy diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 70cd7a02dab0..268e5429c071 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -1002,6 +1002,20 @@ def _corstone300_compile_time_check(): requires_vitis_ai = Feature("vitis_ai", "Vitis AI", cmake_flag="USE_VITIS_AI") +def _riscv_spike_run_time_check(): + if shutil.which("spike") is None: + return "Spike RISC-V ISA Simulator unavailable" + return True + + +# Mark a test as requiring Spike to run +requires_riscv_spike = Feature( + "spike", + "Spike RISC-V ISA Simulator", + run_time_check=_riscv_spike_run_time_check, +) + + def _arm_dot_supported(): arch = platform.machine() diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index fc316fd19307..e97cebd8c32f 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -65,6 +65,7 @@ from . import random from . import hexagon from . import adreno +from . import riscv_cpu # error reporting from .utils import InvalidShapeError diff --git a/python/tvm/topi/riscv_cpu/__init__.py b/python/tvm/topi/riscv_cpu/__init__.py new file mode 100644 index 000000000000..0248ae1d0bde --- /dev/null +++ b/python/tvm/topi/riscv_cpu/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=redefined-builtin, wildcard-import +"""RISCV specific declaration and schedules.""" + +from .injective import * +from .conv2d_int8 import * diff --git a/python/tvm/topi/riscv_cpu/conv2d_int8.py b/python/tvm/topi/riscv_cpu/conv2d_int8.py new file mode 100644 index 000000000000..31c7ce1f8581 --- /dev/null +++ b/python/tvm/topi/riscv_cpu/conv2d_int8.py @@ -0,0 +1,362 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument,no-member +"""Conv2D int8 schedule on RISCV""" +from tvm import te, target, autotvm +from ..utils import traverse_inline, get_const_tuple +from ..generic import conv2d as conv2d_generic +from .. import nn +from ..nn.conv2d import _get_workload as _get_conv2d_workload, unpack_NCHWc_to_nchw +from ..x86.conv2d_int8 import _pack_data +from ..nn.utils import get_pad_tuple +from .tensor_intrin import dot_int8_int8_int32, int8_conv2d_impl + + +def _get_default_config(cfg, data, kernel, strides, padding, dilation, out_dtype): + """ + Get default int8 schedule config for the workload + """ + wkl = _get_conv2d_workload(data, kernel, strides, padding, dilation, out_dtype) + is_kernel_1x1 = wkl.kernel_h == 1 and wkl.kernel_w == 1 + if is_kernel_1x1: + conv2d_generic.fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes=4, num_int8_elements=4) + else: + conv2d_generic.fallback_schedule_cpu_common_int8( + cfg, wkl, int32_lanes=4, num_int8_elements=4 + ) + + +def is_int8_hw_support(data_dtype, kernel_dtype): + """ + Checks to ensure that we can use int8 on riscv_cpu. + 1) The datatypes are correct. + 2) The vector extension "V" is used. + """ + # 1) Check datatypes. + is_dtype_support = data_dtype == "uint8" and kernel_dtype == "int8" + + # 2) Check target. + current_target = target.Target.current(allow_none=False) + has_attr = "+v" in current_target.mattr + is_arch_support = "v" in current_target.arch[2:] + if not is_arch_support and "march" in current_target.attrs: + is_arch_support = "v" in current_target.attrs["march"] + is_target_support = has_attr or is_arch_support + + return is_dtype_support and is_target_support + + +@autotvm.register_topi_compute("conv2d_NCHWc_int8.riscv_cpu") +def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out_layout, out_dtype): + """Compute conv2d int8 with NCHWc layout""" + # layout and out_layout are not used here, + # we keep them for debug convenience when dumping autotvm workload + + if len(data.shape) == 5: # data is in nchwc + n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) + in_channel = ic_chunk * ic_bn + + oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn, _ = get_const_tuple(kernel.shape) + num_filter = oc_chunk * oc_bn + else: + # data is nchw, implicitly treat it as nchw1c + n, in_channel, ih, iw = get_const_tuple(data.shape) + num_filter, _, kh, kw = get_const_tuple(kernel.shape) + + # Define autotvm tuning space + is_kernel_1x1 = kh == 1 and kw == 1 + pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw)) + sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides) + dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) + dilated_kernel_h = (kh - 1) * dh + 1 + dilated_kernel_w = (kw - 1) * dw + 1 + oh = (ih - dilated_kernel_h + pt + pb) // sh + 1 + ow = (iw - dilated_kernel_w + pl + pr) // sw + 1 + + # input and output should be a multiple of 8 (intrinsics are 8 lanes) + cfg.define_split( + "tile_ic", in_channel, num_outputs=2, filter=lambda y: y.size[-1] % min(8, in_channel) == 0 + ) + cfg.define_split( + "tile_oc", 
num_filter, num_outputs=2, filter=lambda y: y.size[-1] % min(8, num_filter) == 0 + ) + cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64) + if is_kernel_1x1: + cfg.define_knob("tile_oh", [1, 2] if oh > 1 else [1]) + else: + cfg.define_knob("unroll_kw", [True, False]) + + # If no config was set, we can fallback to NCHW config. + if cfg.is_fallback: + _get_default_config( + cfg, + te.placeholder((n, in_channel, ih, iw), dtype=data.dtype), + te.placeholder((num_filter, in_channel, kh, kw), dtype=kernel.dtype), + strides, + padding, + dilation, + out_dtype, + ) + # Pack data if raw 4-D data is provided. + # This can only happen when autotuning. + if len(data.shape) == 4: + data, kernel = _pack_data(cfg, data, kernel) + + n_elems = int(kernel.shape[-1]) + + return nn.conv2d_NCHWc_int8( + data, kernel, strides, padding, dilation, layout, out_layout, out_dtype, n_elems=n_elems + ) + + +@autotvm.register_topi_schedule("conv2d_NCHWc_int8.riscv_cpu") +def schedule_conv2d_NCHWc_int8(cfg, outs): + """Create schedule for tensors""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if "conv2d_NCHWc_int8" in op.tag: + inline_fused = False + conv_out = op.output(0) + kernel_vec = conv_out.op.input_tensors[1] + data_vec = conv_out.op.input_tensors[0] + data = ( + data_vec.op.input_tensors[0] + if isinstance(data_vec.op, te.tensor.ComputeOp) and "pad" not in data_vec.op.tag + else data_vec + ) + if isinstance(data.op, te.tensor.ComputeOp) and "pad" in data.op.tag: + data_pad = data + data = data_pad.op.input_tensors[0] + + # int8 conv kernel is 7-dim + _, _, kh, kw, _, _, n_elems = get_const_tuple(kernel_vec.shape) + + assert n_elems == 4 + + _, _, _, _, ic_bn = get_const_tuple(data_vec.shape) + _, _, _, _, oc_bn = get_const_tuple(conv_out.shape) + + # schedule pad + if isinstance(s[data_vec].op, te.tensor.ComputeOp) and "pad" in data_vec.op.tag: + batch, ic_chunk, ih, iw, ic_block = s[data_vec].op.axis + parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih) + s[data_vec].parallel(parallel_axis) + data_vec = data_vec.op.input_tensors[0] + + if autotvm.GLOBAL_SCOPE.in_tuning: + # only in autotuning, input data of conv2d_NCHWc will be 4-D. + # skip this part during tuning to make records accurate. + # this part will be folded during Relay fold_constant pass. + if isinstance(data_vec.op, te.tensor.ComputeOp): + s[data_vec].pragma(s[data_vec].op.axis[0], "debug_skip_region") + if isinstance(kernel_vec.op, te.tensor.ComputeOp): + s[kernel_vec].pragma(s[kernel_vec].op.axis[0], "debug_skip_region") + elif isinstance(kernel_vec.op, te.tensor.ComputeOp) and kernel_vec.name == "kernel_vec": + # data and kernel are not pre-computed, schedule layout transform here. + # this should only be used by x86 conv2d_nchw, which is for + # testing purpose. 
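+                # The code below parallelizes the NCHW -> NCHWc data repack and
+                # vectorizes/parallelizes the 7-D kernel repack; in a normal
+                # Relay pipeline both transforms are constant-folded away.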
+ batch, ic_chunk, ih, _, ic_block = s[data_vec].op.axis + parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih) + s[data_vec].parallel(parallel_axis) + + # conv2d_nchwc_int8 has 7D kernel + oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis + s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block) + oc_bn = cfg["tile_oc"].size[-1] + if oc_bn > 1: + s[kernel_vec].vectorize(oc_block) + parallel_axis = s[kernel_vec].fuse(oc_chunk, oh) + s[kernel_vec].parallel(parallel_axis) + + # schedule 5-D NCHW[x]c conv + C, O = conv_out, outs[0] + CC = s.cache_write(C, "global") + + batch, oc_chunk, oh, ow, oc_block = s[C].op.axis + + if kh == 1 and kw == 1: + oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1] + oh_outer, oh_inner = s[C].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[C].split(ow, factor=ow_factor) + s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) + s[C].vectorize(oc_block) + + parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer) + if C == O: + s[C].parallel(parallel_axis) + s[CC].compute_at(s[C], parallel_axis) + + _, oc_chunk, oh, ow, oc_block = s[CC].op.axis + kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis + + oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=4) + + oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor) + + s[CC].reorder( + oc_chunk, + oh_outer, + ow_outer, + kh, + kw, + ic_outer, + ic_f_inner, + oh_inner, + ow_inner, + oc_f_inner, + oc_s_inner, + ic_s_inner, + ) + s[CC].fuse(oc_chunk, oh_outer) + + s[CC].tensorize(oc_s_inner, dot_int8_int8_int32()) + s[CC].pragma(oc_f_inner, "import_c", int8_conv2d_impl()) + + s[CC].unroll(ow_inner) + s[CC].unroll(oh_inner) + + if C != O: + out_ndim = len(s[O].op.axis) + if out_ndim == 5: + batch, oc_chunk, oh, ow, oc_block = s[O].op.axis + oh_outer, oh_inner = s[O].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) + elif out_ndim == 4: + batch, oc, oh, ow = s[O].op.axis + oc_chunk, oc_block = s[O].split(oc, factor=oc_bn) + oh_outer, oh_inner = s[O].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) + else: + raise ValueError("Unsupported output ndim: %s" % out_ndim) + + s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) + parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer) + if inline_fused: + s[C].compute_at(s[O], ow_inner) + else: + s[C].compute_at(s[O], parallel_axis) + + s[O].vectorize(oc_block) + s[O].parallel(parallel_axis) + else: + if isinstance(cfg["tile_ow"], int): + reg_n = cfg["tile_ow"] + else: + reg_n = cfg["tile_ow"].size[-1] + + if isinstance(cfg["unroll_kw"], (int, bool)): + unroll_kw = cfg["unroll_kw"] + else: + unroll_kw = cfg["unroll_kw"].val + + ow_chunk, ow_block = s[C].split(ow, factor=reg_n) + s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) + parallel_axis = s[C].fuse(batch, oc_chunk, oh) + s[C].vectorize(oc_block) + if C == O: + s[C].parallel(parallel_axis) + + s[CC].compute_at(s[C], parallel_axis) + _, oc_chunk, oh, ow, oc_block = s[CC].op.axis + kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis + + ow_chunk, ow_block = s[CC].split(ow, factor=reg_n) + + assert oc_bn % 4 == 0, f"oc_bn={oc_bn} % int32_lanes={4} != 0" + assert ( + ic_bn % 4 == 0 + ), f"ic_bn={ic_bn} % int8_elems={4} != 0" # (u)int8 elements in (u)int32 + + oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=4) + + if unroll_kw: + s[CC].reorder( + oc_chunk, + oh, + 
ow_chunk,
+                        ic_outer,
+                        kh,
+                        ic_f_inner,
+                        kw,
+                        ow_block,
+                        oc_f_inner,
+                        oc_s_inner,
+                        ic_s_inner,
+                    )
+                    s[CC].unroll(kw)
+                else:
+                    s[CC].reorder(
+                        oc_chunk,
+                        oh,
+                        ow_chunk,
+                        ic_outer,
+                        kh,
+                        kw,
+                        ic_f_inner,
+                        ow_block,
+                        oc_f_inner,
+                        oc_s_inner,
+                        ic_s_inner,
+                    )
+
+                s[CC].tensorize(oc_s_inner, dot_int8_int8_int32())
+                s[CC].pragma(oc_f_inner, "import_c", int8_conv2d_impl())
+
+                s[CC].unroll(ow_block)
+                s[CC].unroll(oc_f_inner)
+
+                if C != O:
+                    out_ndim = len(s[O].op.axis)
+                    if out_ndim == 5:
+                        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+                        ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
+                        s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+                    elif out_ndim == 4:
+                        batch, oc, oh, ow = s[O].op.axis
+                        ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
+                        oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
+                        s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+                    else:
+                        raise ValueError("Unsupported output ndim: %s" % out_ndim)
+                    parallel_axis = s[O].fuse(batch, oc_chunk, oh)
+                    if inline_fused:
+                        s[C].compute_at(s[O], ow_block)
+                    else:
+                        s[C].compute_at(s[O], parallel_axis)
+                    s[O].vectorize(oc_block)
+                    s[O].parallel(parallel_axis)
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+def conv2d_nchw_int8(data, kernel, strides, padding, dilation, out_dtype):
+    """Compute conv2d with NCHW layout and int8 dtype"""
+    layout = "NCHW"
+    # pylint: disable=no-value-for-parameter
+    packed_out = conv2d_NCHWc_int8(
+        data, kernel, strides, padding, dilation, layout, layout, out_dtype
+    )
+    return unpack_NCHWc_to_nchw(packed_out, out_dtype)
+
+
+def schedule_conv2d_nchw_int8(outs):
+    """Create the schedule for conv2d_nchw_int8"""
+    # pylint: disable=no-value-for-parameter
+    return schedule_conv2d_NCHWc_int8(outs)
diff --git a/python/tvm/topi/riscv_cpu/injective.py b/python/tvm/topi/riscv_cpu/injective.py
new file mode 100644
index 000000000000..4a2581040b8c
--- /dev/null
+++ b/python/tvm/topi/riscv_cpu/injective.py
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable
+"""Schedule for injective operators"""
+from tvm import te
+from ..utils import is_empty_shape
+
+
+def schedule_injective_from_existing(sch, out):
+    """Schedule for injective op from existing schedule.
+    Parameters
+    ----------
+    sch: Schedule
+        The schedule to update.
+    out: Tensor
+        The tensor representing the injective op.
+    Returns
+    -------
+    sch: Schedule
+        The updated schedule.
+    """
+    if len(sch[out].op.axis) >= 5:
+        fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1], sch[out].op.axis[2])
+        sch[out].parallel(fused)
+    elif len(sch[out].op.axis) >= 3:
+        fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1])
+        sch[out].parallel(fused)
+    elif len(sch[out].op.axis) >= 1:
+        sch[out].parallel(sch[out].op.axis[0])
+
+    # Vectorize the innermost loop. Split it first so the inner loop has a constant extent.
+    if len(sch[out].op.axis) >= 1:
+        l = sch[out].op.axis[-1]
+        lo, li = sch[out].split(l, factor=16)
+        sch[out].vectorize(li)
+
+        # For a 1-D loop, the split above breaks the parallel axis,
+        # so make the outer loop parallel again.
+        if len(sch[out].op.axis) == 1:
+            sch[out].parallel(lo)
+
+    return sch
+
+
+def schedule_injective(outs):
+    """RISC-V schedule for injective op.
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of injective in the format
+        of an array of tensors.
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    te.schedule.AutoInlineInjective(s)
+    for x in outs:
+        # Vectorize int32 "add" operations whose innermost extent is even.
+        if "add" in x.name:
+            is_int32 = (
+                x.op.input_tensors[0].dtype == "int32" and x.op.input_tensors[1].dtype == "int32"
+            )
+            is_even = x.shape[-1] % 2 == 0
+            if is_int32 and is_even:
+                outer, inner = s[x].split(x.op.axis[-1], 2)
+                s[x].vectorize(inner)
+        elif not is_empty_shape(x.shape):
+            schedule_injective_from_existing(s, x)
+
+    return s
diff --git a/python/tvm/topi/riscv_cpu/tensor_intrin.py b/python/tvm/topi/riscv_cpu/tensor_intrin.py
new file mode 100644
index 000000000000..8f5578d690b7
--- /dev/null
+++ b/python/tvm/topi/riscv_cpu/tensor_intrin.py
@@ -0,0 +1,217 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+"""Core kernel of dot product of 4 Int8 operations"""
+
+import tvm
+from tvm import te
+
+
+def dot_int8_int8_int32():
+    """Int8 dot product by every 4 elements using RISC-V V-extension instructions."""
+    int32_lanes = 4  # 4 int32 lanes = 128 bits.
+    num_int8_elements = 4  # 4 int8 elements in an int32.
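+    # The intrinsic computes C[i] = sum_k data[k] * kernel[i, k] for i, k in
+    # [0, 4): a 4x4 uint8 x int8 dot product yielding four int32 lanes. The
+    # conv2d schedule tensorizes its (oc_s_inner, ic_s_inner) loop nest into
+    # this pattern.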
+    data = te.placeholder((num_int8_elements,), dtype="uint8", name="data")
+    kernel = te.placeholder((int32_lanes, num_int8_elements), dtype="int8", name="kernel")
+    k = te.reduce_axis((0, num_int8_elements), name="k")
+    C = te.compute(
+        (int32_lanes,),
+        lambda i: te.sum(data[k].astype("int32") * kernel[i, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    a_buffer = tvm.tir.decl_buffer(
+        data.shape, dtype="uint8", name="a_buffer", offset_factor=1, strides=[1]
+    )
+    b_buffer = tvm.tir.decl_buffer(
+        kernel.shape, dtype="int8", name="b_buffer", offset_factor=1, strides=[te.var("ldw"), 1]
+    )
+
+    def _intrin_func(ins, outs):
+        def _body():
+            ib = tvm.tir.ir_builder.create()
+            ib.emit(
+                tvm.tir.call_extern(
+                    outs[0].dtype,
+                    "dot_uint8_int8_int32_body",
+                    ins[0].access_ptr("r"),
+                    ins[1].access_ptr("r"),
+                    outs[0].access_ptr("w"),
+                )
+            )
+            return ib.get()
+
+        def _reset():
+            ib = tvm.tir.ir_builder.create()
+            ib.emit(
+                tvm.tir.call_extern(
+                    outs[0].dtype, "dot_uint8_int8_int32_reset", outs[0].access_ptr("w")
+                )
+            )
+            return ib.get()
+
+        def _reduce_update():
+            ib = tvm.tir.ir_builder.create()
+            ib.emit(
+                tvm.tir.call_extern(
+                    outs[0].dtype,
+                    "dot_uint8_int8_int32_update",
+                    ins[0].access_ptr("r"),
+                    ins[1].access_ptr("r"),
+                    outs[0].access_ptr("w"),
+                )
+            )
+            return ib.get()
+
+        return _body(), _reset(), _reduce_update()
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={data: a_buffer, kernel: b_buffer},
+        default_buffer_params=buffer_params,
+    )
+
+
+def int8_conv2d_impl():
+    """Emit C code for dot product impl."""
+    cc_code = """
+#ifndef TVM_RISCV_CONV2D_INT8
+#define TVM_RISCV_CONV2D_INT8
+#include <riscv_vector.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C"
+#endif
+int32_t dot_uint8_int8_int32_reset(int32_t* res) {
+  // Set all values in the output array to 0.
+  for (uint32_t i = 0; i < 4; ++i)
+    res[i] = 0;
+  return 0;
+}
+
+// Multiply the data vector with each kernel row, reduce the product vector to a
+// scalar sum, and store it in the output array.
+//
+// Example without vectorization:
+// for (unsigned i = 0; i < 4; ++i) {
+//   output[i] = 0;
+//   for (unsigned j = 0; j < 4; ++j) {
+//     output[i] += data[j] * kernel[i][j];
+//   }
+// }
+
+#ifdef __cplusplus
+extern "C"
+#endif
+int32_t dot_uint8_int8_int32_body(uint8_t* data, int8_t* kernel, int32_t* output) {
+  // Load values from data into a vector.
+  vuint8mf2_t v_data = vle8_v_u8mf2(data, -1);
+
+  // Dummy vector for operations.
+  vint32m1_t empty;
+
+  // Load values from kernel[i][*],
+  // then multiply the two vectors with type extension:
+  //   v_mul_[i] = v_kernel_[i] * v_data,
+  // then compute the sum over the resulting vector:
+  //   v_sum_[i] = sum(v_mul_[i])
+
+  vint8mf2_t v_kernel_1 = vle8_v_i8mf2(&kernel[0], -1);
+  vint16m1_t v_mul_1 = vwmulsu_vv_i16m1(v_kernel_1, v_data, -1);
+  vint32m1_t v_sum_1 = vwredsum(empty, v_mul_1, empty, 4);
+
+  vint8mf2_t v_kernel_2 = vle8_v_i8mf2(&kernel[4], -1);
+  vint16m1_t v_mul_2 = vwmulsu_vv_i16m1(v_kernel_2, v_data, -1);
+  vint32m1_t v_sum_2 = vwredsum(empty, v_mul_2, empty, 4);
+
+  vint8mf2_t v_kernel_3 = vle8_v_i8mf2(&kernel[8], -1);
+  vint16m1_t v_mul_3 = vwmulsu_vv_i16m1(v_kernel_3, v_data, -1);
+  vint32m1_t v_sum_3 = vwredsum(empty, v_mul_3, empty, 4);
+
+  vint8mf2_t v_kernel_4 = vle8_v_i8mf2(&kernel[12], -1);
+  vint16m1_t v_mul_4 = vwmulsu_vv_i16m1(v_kernel_4, v_data, -1);
+  vint32m1_t v_sum_4 = vwredsum(empty, v_mul_4, empty, 4);
+
+  // Save new values to output.
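+  // vwredsum leaves the accumulated int32 sum in element 0 of its result, and
+  // vmv_x_s extracts that element into a scalar.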
+  output[0] = vmv_x_s_i32m1_i32(v_sum_1);
+  output[1] = vmv_x_s_i32m1_i32(v_sum_2);
+  output[2] = vmv_x_s_i32m1_i32(v_sum_3);
+  output[3] = vmv_x_s_i32m1_i32(v_sum_4);
+
+  return 0;
+}
+
+// Multiply the data vector with each kernel row, reduce the product vector to a
+// scalar sum, and add it to the value already in the output array.
+//
+// Example without vectorization:
+// for (unsigned i = 0; i < 4; ++i)
+//   for (unsigned j = 0; j < 4; ++j)
+//     output[i] += data[j] * kernel[i][j];
+
+#ifdef __cplusplus
+extern "C"
+#endif
+int32_t dot_uint8_int8_int32_update(uint8_t* data, int8_t* kernel, int32_t* output) {
+  // Load values from data into a vector.
+  vuint8mf2_t v_data = vle8_v_u8mf2(data, -1);
+
+  // Dummy vector for operations.
+  vint32m1_t empty;
+
+  // Load values from output into vectors; only element 0 of each is used, as the
+  // scalar accumulator of the reduction below.
+  vint32m1_t v_output_1 = vle32_v_i32m1(&output[0], -1);
+  vint32m1_t v_output_2 = vle32_v_i32m1(&output[1], -1);
+  vint32m1_t v_output_3 = vle32_v_i32m1(&output[2], -1);
+  vint32m1_t v_output_4 = vle32_v_i32m1(&output[3], -1);
+
+  // Load values from kernel[i][*],
+  // then multiply the two vectors with type extension:
+  //   v_mul_[i] = v_kernel_[i] * v_data,
+  // then compute the sum over the resulting vector and add the value from output:
+  //   v_sum_[i] = v_output_[i] + sum(v_mul_[i])
+
+  vint8mf2_t v_kernel_1 = vle8_v_i8mf2(&kernel[0], -1);
+  vint16m1_t v_mul_1 = vwmulsu_vv_i16m1(v_kernel_1, v_data, -1);
+  vint32m1_t v_sum_1 = vwredsum(empty, v_mul_1, v_output_1, 4);
+
+  vint8mf2_t v_kernel_2 = vle8_v_i8mf2(&kernel[4], -1);
+  vint16m1_t v_mul_2 = vwmulsu_vv_i16m1(v_kernel_2, v_data, -1);
+  vint32m1_t v_sum_2 = vwredsum(empty, v_mul_2, v_output_2, 4);
+
+  vint8mf2_t v_kernel_3 = vle8_v_i8mf2(&kernel[8], -1);
+  vint16m1_t v_mul_3 = vwmulsu_vv_i16m1(v_kernel_3, v_data, -1);
+  vint32m1_t v_sum_3 = vwredsum(empty, v_mul_3, v_output_3, 4);
+
+  vint8mf2_t v_kernel_4 = vle8_v_i8mf2(&kernel[12], -1);
+  vint16m1_t v_mul_4 = vwmulsu_vv_i16m1(v_kernel_4, v_data, -1);
+  vint32m1_t v_sum_4 = vwredsum(empty, v_mul_4, v_output_4, 4);
+
+  // Save updated values to output.
+  output[0] = vmv_x_s_i32m1_i32(v_sum_1);
+  output[1] = vmv_x_s_i32m1_i32(v_sum_2);
+  output[2] = vmv_x_s_i32m1_i32(v_sum_3);
+  output[3] = vmv_x_s_i32m1_i32(v_sum_4);
+
+  return 0;
+}
+#endif
+    """
+    return cc_code
diff --git a/tests/python/relay/strategy/riscv_cpu/test_conv2d_int8_nchw.py b/tests/python/relay/strategy/riscv_cpu/test_conv2d_int8_nchw.py
new file mode 100644
index 000000000000..acfa89281271
--- /dev/null
+++ b/tests/python/relay/strategy/riscv_cpu/test_conv2d_int8_nchw.py
@@ -0,0 +1,143 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
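+
+"""Tests for the RISC-V int8 conv2d strategy, executed on the Spike ISA simulator."""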
+
+import pathlib
+
+import numpy as np
+
+import tvm
+import tvm.contrib.utils
+import tvm.micro
+import tvm.testing
+from tvm import relay
+from tvm.relay.backend import Executor, Runtime
+from tvm.testing.aot import generate_ref_data
+
+
+def _make_session(temp_dir, mod):
+    template_project_dir = pathlib.Path(tvm.micro.get_microtvm_template_projects("riscv"))
+    options = {
+        "toolchain_path": "/opt/riscv",
+        "target": "riscv64-unknown-linux-gnu",
+        "march": "rv64gcv",
+        "verbose": 1,
+    }
+    project = tvm.micro.generate_project(template_project_dir, mod, temp_dir / "project", options)
+    project.build()
+    project.flash()
+    return tvm.micro.Session(project.transport())
+
+
+class RISCVConv2dInt8:
+    @tvm.testing.requires_riscv_spike
+    def test_conv2d_int8(
+        self,
+        data_shape,
+        kernel_size,
+        data_layout,
+        kernel_layout,
+        num_filter,
+        padding,
+        dtype,
+        wtype,
+    ):
+        """Build a quantized conv2d, run it on Spike, and compare with the host reference."""
+        weight_shape = (num_filter, data_shape[1], *kernel_size)
+
+        data = relay.var("input", shape=data_shape, dtype=dtype)
+
+        if "int" in wtype:
+            min_w_value = np.iinfo(wtype).min
+            max_w_value = np.iinfo(wtype).max
+        else:
+            min_w_value = np.finfo(wtype).min
+            max_w_value = np.finfo(wtype).max
+
+        weight_data = np.random.randint(
+            low=min_w_value, high=max_w_value, size=weight_shape, dtype=wtype
+        )
+        weight = relay.const(weight_data)
+
+        func = relay.qnn.op.conv2d(
+            data,
+            weight,
+            relay.const(1, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=weight_shape[0],
+            kernel_size=kernel_size,
+            padding=padding,
+            data_layout=data_layout,
+            kernel_layout=kernel_layout,
+        )
+
+        ref_mod = relay.Function(relay.analysis.free_vars(func), func)
+        ref_mod = tvm.IRModule.from_expr(ref_mod)
+
+        if "int" in dtype:
+            min_d_value = np.iinfo(dtype).min
+            max_d_value = np.iinfo(dtype).max
+        else:
+            min_d_value = np.finfo(dtype).min
+            max_d_value = np.finfo(dtype).max
+
+        inputs = {
+            "input": np.random.randint(
+                low=min_d_value, high=max_d_value, size=data_shape, dtype=dtype
+            )
+        }
+
+        output_list = generate_ref_data(ref_mod, inputs)
+
+        mod = relay.Function(relay.analysis.free_vars(func), func)
+        mod = tvm.IRModule.from_expr(mod)
+
+        temp_dir = tvm.contrib.utils.tempdir()
+        target = "c -keys=riscv_cpu -march=rv64gcv"
+        target = tvm.target.Target(target, host="c")
+        runtime = Runtime("crt", {"system-lib": True})
+        executor = Executor("aot")
+        with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
+            factory = tvm.relay.build(mod, target=target, runtime=runtime, executor=executor)
+
+        def do_test():
+            aot_executor = tvm.micro.create_local_aot_executor(sess)
+            aot_executor.get_input("input").copyfrom(inputs["input"])
+            aot_executor.run()
+
+            out = aot_executor.get_output(0)
+            assert (out.numpy() == output_list["output"]).all()
+
+        with _make_session(temp_dir, factory) as sess:
+            do_test()
+
+
+class TestConv2d_NCHW(RISCVConv2dInt8):
+    (data_shape, kernel_size, num_filter,) = tvm.testing.parameters(
+        ((1, 128, 14, 14), (3, 3), 128),
+        ((1, 128, 14, 14), (1, 1), 256),
+        ((1, 256, 7, 7), (1, 1), 512),
+        ((1, 256, 7, 7), (3, 3), 512),
+        ((1, 512, 3, 3), (3, 3), 512),
+    )
+    padding = tvm.testing.parameter((1, 1))
+    data_layout = tvm.testing.parameter("NCHW")
+    kernel_layout = tvm.testing.parameter("OIHW")
+    dtype = tvm.testing.parameter("uint8")
+    wtype = tvm.testing.parameter("int8")
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relay/strategy/test_select_implementation.py b/tests/python/relay/strategy/test_select_implementation.py
index 2bf1548d41d8..239cc88e8430 
100644 --- a/tests/python/relay/strategy/test_select_implementation.py +++ b/tests/python/relay/strategy/test_select_implementation.py @@ -104,6 +104,53 @@ def test_int8_conv2d(target, expected_impl): assert impl.name == expected_impl +@pytest.mark.parametrize( + "target,expected_impl", + [ + ( + "c -keys=riscv_cpu -march=rv32gcv", + "conv2d_nchw_int8.riscv", + ), + ( + "c -keys=riscv_cpu -march=rv64gcv", + "conv2d_nchw_int8.riscv", + ), + ], +) +def test_riscv_conv2d(target, expected_impl): + target = tvm.target.Target(target) + + data_type = "uint8" + weight_type = "int8" + data_shape = (1, 128, 1, 1) + weight_shape = (128, 128, 1, 1) + data_layout = "NCHW" + kernel_layout = "OIHW" + channels = 128 + kernel_size = (1, 1) + + out = relay.nn.conv2d( + relay.var("data", shape=data_shape, dtype=data_type), + relay.var("weight", shape=weight_shape, dtype=weight_type), + kernel_size=kernel_size, + channels=channels, + data_layout=data_layout, + kernel_layout=kernel_layout, + ) + out = run_infer_type(out) + + with target: + impl, _ = relay.backend.te_compiler.select_implementation( + out.op, + out.attrs, + [te.placeholder(data_shape, data_type), te.placeholder(weight_shape, weight_type)], + out.checked_type, + target, + ) + + assert impl.name == expected_impl + + @pytest.mark.parametrize( "target,expected_impl", [ diff --git a/tests/scripts/task_riscv_microtvm.sh b/tests/scripts/task_riscv_microtvm.sh index c597506dfede..db36d295df0a 100755 --- a/tests/scripts/task_riscv_microtvm.sh +++ b/tests/scripts/task_riscv_microtvm.sh @@ -25,3 +25,5 @@ make cython3 # NOTE: this exists to ensure some tests run on RISC-V image. Without it, Jenkins reports a configuration error. # This line can be removed when RISC-V tests are added. run_pytest ctypes riscv-platform-minimal-test-0 tests/python/all-platform-minimal-test + +run_pytest ctypes python-relay-strategy-riscv_cpu tests/python/relay/strategy/riscv_cpu