From 4f8b27106725947a55645c0e3b3aeb4af0fa8b20 Mon Sep 17 00:00:00 2001 From: katebern-grovety <132359118+katebern-grovety@users.noreply.github.com> Date: Fri, 12 May 2023 13:17:43 +0400 Subject: [PATCH 01/10] Dot product tensorization with RISC-V V-extension intrinsics for conv2d_nchw_int8 --- python/tvm/relay/op/strategy/__init__.py | 1 + python/tvm/relay/op/strategy/riscv_cpu.py | 94 +++++ python/tvm/testing/utils.py | 13 + python/tvm/topi/__init__.py | 1 + python/tvm/topi/riscv_cpu/__init__.py | 22 ++ python/tvm/topi/riscv_cpu/conv2d_int8.py | 362 ++++++++++++++++++ python/tvm/topi/riscv_cpu/injective.py | 85 ++++ python/tvm/topi/riscv_cpu/tensor_intrin.py | 218 +++++++++++ tests/python/relay/aot/riscv.mk | 85 ++++ .../riscv_cpu/test_conv2d_int8_nchw.py | 134 +++++++ tests/scripts/task_riscv_microtvm.sh | 2 + 11 files changed, 1017 insertions(+) create mode 100644 python/tvm/relay/op/strategy/riscv_cpu.py create mode 100644 python/tvm/topi/riscv_cpu/__init__.py create mode 100644 python/tvm/topi/riscv_cpu/conv2d_int8.py create mode 100644 python/tvm/topi/riscv_cpu/injective.py create mode 100644 python/tvm/topi/riscv_cpu/tensor_intrin.py create mode 100644 tests/python/relay/aot/riscv.mk create mode 100644 tests/python/relay/strategy/riscv_cpu/test_conv2d_int8_nchw.py diff --git a/python/tvm/relay/op/strategy/__init__.py b/python/tvm/relay/op/strategy/__init__.py index 1be5425e702c..29b6b89a6131 100644 --- a/python/tvm/relay/op/strategy/__init__.py +++ b/python/tvm/relay/op/strategy/__init__.py @@ -30,3 +30,4 @@ from . import intel_graphics from . import hexagon from . import adreno +from . import riscv_cpu diff --git a/python/tvm/relay/op/strategy/riscv_cpu.py b/python/tvm/relay/op/strategy/riscv_cpu.py new file mode 100644 index 000000000000..6eb6329dea74 --- /dev/null +++ b/python/tvm/relay/op/strategy/riscv_cpu.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import +"""Definition of RISCV CPU operator strategy.""" + +from functools import reduce +import logging + +from tvm import topi +from .. 
import op as _op +from .generic import * +from .x86 import conv2d_strategy_cpu + +logger = logging.getLogger("strategy") + + +@schedule_injective.register("riscv_cpu") +def schedule_injective_riscv_cpu(_, outs, target): + """schedule injective ops for riscv_cpu""" + with target: + return topi.riscv_cpu.schedule_injective(outs) + + +@schedule_reduce.register("riscv_cpu") +def schedule_reduce_riscv_cpu(_, outs, target): + """schedule reduction ops for riscv_cpu""" + with target: + return topi.x86.schedule_reduce(outs) + + +@conv2d_strategy.register("riscv_cpu") +def conv2d_strategy_riscv_cpu(attrs, inputs, out_type, target): + """conv2d riscv_cpu strategy""" + strategy = _op.OpStrategy() + data, kernel = inputs + dilation_h, dilation_w = attrs.get_int_tuple("dilation") + groups = attrs.groups + layout = attrs.data_layout + kernel_layout = attrs.kernel_layout + if dilation_h < 1 or dilation_w < 1: + raise ValueError("dilation should be positive value") + + if groups == 1: + if layout == "NCHW": + assert kernel_layout == "OIHW" + is_int8 = topi.riscv_cpu.is_int8_hw_support(data.dtype, kernel.dtype) + # Vector instructions with int8 show more performance at a larger size. + if is_int8 and kernel.shape[1] >= 128: + strategy.add_implementation( + wrap_compute_conv2d(topi.riscv_cpu.conv2d_nchw_int8), + wrap_topi_schedule(topi.riscv_cpu.schedule_conv2d_nchw_int8), + name="conv2d_nchw_int8.riscv", + plevel=15 + ) + return strategy + + return conv2d_strategy_cpu(attrs, inputs, out_type, target) + + +@conv2d_NCHWc_strategy.register("riscv_cpu") +def conv2d_NCHWc_strategy_riscv_cpu(attrs, inputs, out_type, target): + """conv2d_NCHWc adopted from x86""" + strategy = _op.OpStrategy() + data, kernel = inputs + is_int8 = topi.riscv_cpu.is_int8_hw_support(data.dtype, kernel.dtype) + # Vector instructions with int8 show more performance at a larger size. + if is_int8 and kernel.shape[1] >= 128: + strategy.add_implementation( + wrap_compute_conv2d(topi.riscv_cpu.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True), + wrap_topi_schedule(topi.riscv_cpu.schedule_conv2d_NCHWc_int8), + name="conv2d_NCHWc_int8.riscv_cpu", + ) + else: + strategy.add_implementation( + wrap_compute_conv2d(topi.x86.conv2d_NCHWc, need_data_layout=True, need_out_layout=True), + wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc), + name="conv2d_NCHWc.x86", + ) + return strategy diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 70cd7a02dab0..bf7d45cfaf53 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -1002,6 +1002,19 @@ def _corstone300_compile_time_check(): requires_vitis_ai = Feature("vitis_ai", "Vitis AI", cmake_flag="USE_VITIS_AI") +def _riscv_spike_run_time_check(): + if shutil.which("spike") is None: + return "Spike RISC-V ISA Simulator unavailable" + return True + + +# Mark a test as requiring Spike to run +requires_riscv_spike = Feature( + "spike", + "Spike RISC-V ISA Simulator", + run_time_check=_riscv_spike_run_time_check, +) + def _arm_dot_supported(): arch = platform.machine() diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index fc316fd19307..e97cebd8c32f 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -65,6 +65,7 @@ from . import random from . import hexagon from . import adreno +from . 
import riscv_cpu # error reporting from .utils import InvalidShapeError diff --git a/python/tvm/topi/riscv_cpu/__init__.py b/python/tvm/topi/riscv_cpu/__init__.py new file mode 100644 index 000000000000..0248ae1d0bde --- /dev/null +++ b/python/tvm/topi/riscv_cpu/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=redefined-builtin, wildcard-import +"""RISCV specific declaration and schedules.""" + +from .injective import * +from .conv2d_int8 import * diff --git a/python/tvm/topi/riscv_cpu/conv2d_int8.py b/python/tvm/topi/riscv_cpu/conv2d_int8.py new file mode 100644 index 000000000000..31c7ce1f8581 --- /dev/null +++ b/python/tvm/topi/riscv_cpu/conv2d_int8.py @@ -0,0 +1,362 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument,no-member +"""Conv2D int8 schedule on RISCV""" +from tvm import te, target, autotvm +from ..utils import traverse_inline, get_const_tuple +from ..generic import conv2d as conv2d_generic +from .. import nn +from ..nn.conv2d import _get_workload as _get_conv2d_workload, unpack_NCHWc_to_nchw +from ..x86.conv2d_int8 import _pack_data +from ..nn.utils import get_pad_tuple +from .tensor_intrin import dot_int8_int8_int32, int8_conv2d_impl + + +def _get_default_config(cfg, data, kernel, strides, padding, dilation, out_dtype): + """ + Get default int8 schedule config for the workload + """ + wkl = _get_conv2d_workload(data, kernel, strides, padding, dilation, out_dtype) + is_kernel_1x1 = wkl.kernel_h == 1 and wkl.kernel_w == 1 + if is_kernel_1x1: + conv2d_generic.fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes=4, num_int8_elements=4) + else: + conv2d_generic.fallback_schedule_cpu_common_int8( + cfg, wkl, int32_lanes=4, num_int8_elements=4 + ) + + +def is_int8_hw_support(data_dtype, kernel_dtype): + """ + Checks to ensure that we can use int8 on riscv_cpu. + 1) The datatypes are correct. + 2) The vector extension "V" is used. + """ + # 1) Check datatypes. 
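+    # The tensorized kernel multiplies unsigned activations by signed weights
+    # (vwmulsu in tensor_intrin.py), so only uint8 data with int8 kernels can
+    # be lowered to the intrinsics.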
+ is_dtype_support = data_dtype == "uint8" and kernel_dtype == "int8" + + # 2) Check target. + current_target = target.Target.current(allow_none=False) + has_attr = "+v" in current_target.mattr + is_arch_support = "v" in current_target.arch[2:] + if not is_arch_support and "march" in current_target.attrs: + is_arch_support = "v" in current_target.attrs["march"] + is_target_support = has_attr or is_arch_support + + return is_dtype_support and is_target_support + + +@autotvm.register_topi_compute("conv2d_NCHWc_int8.riscv_cpu") +def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out_layout, out_dtype): + """Compute conv2d int8 with NCHWc layout""" + # layout and out_layout are not used here, + # we keep them for debug convenience when dumping autotvm workload + + if len(data.shape) == 5: # data is in nchwc + n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) + in_channel = ic_chunk * ic_bn + + oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn, _ = get_const_tuple(kernel.shape) + num_filter = oc_chunk * oc_bn + else: + # data is nchw, implicitly treat it as nchw1c + n, in_channel, ih, iw = get_const_tuple(data.shape) + num_filter, _, kh, kw = get_const_tuple(kernel.shape) + + # Define autotvm tuning space + is_kernel_1x1 = kh == 1 and kw == 1 + pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw)) + sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides) + dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) + dilated_kernel_h = (kh - 1) * dh + 1 + dilated_kernel_w = (kw - 1) * dw + 1 + oh = (ih - dilated_kernel_h + pt + pb) // sh + 1 + ow = (iw - dilated_kernel_w + pl + pr) // sw + 1 + + # input and output should be a multiple of 8 (intrinsics are 8 lanes) + cfg.define_split( + "tile_ic", in_channel, num_outputs=2, filter=lambda y: y.size[-1] % min(8, in_channel) == 0 + ) + cfg.define_split( + "tile_oc", num_filter, num_outputs=2, filter=lambda y: y.size[-1] % min(8, num_filter) == 0 + ) + cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64) + if is_kernel_1x1: + cfg.define_knob("tile_oh", [1, 2] if oh > 1 else [1]) + else: + cfg.define_knob("unroll_kw", [True, False]) + + # If no config was set, we can fallback to NCHW config. + if cfg.is_fallback: + _get_default_config( + cfg, + te.placeholder((n, in_channel, ih, iw), dtype=data.dtype), + te.placeholder((num_filter, in_channel, kh, kw), dtype=kernel.dtype), + strides, + padding, + dilation, + out_dtype, + ) + # Pack data if raw 4-D data is provided. + # This can only happen when autotuning. 
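+    # _pack_data blocks NCHW data into NCHW[x]c and reshapes the OIHW kernel
+    # into the 7-D layout whose innermost axis holds groups of 4 int8 values,
+    # matching the n_elems=4 contract of the dot-product intrinsic.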
+ if len(data.shape) == 4: + data, kernel = _pack_data(cfg, data, kernel) + + n_elems = int(kernel.shape[-1]) + + return nn.conv2d_NCHWc_int8( + data, kernel, strides, padding, dilation, layout, out_layout, out_dtype, n_elems=n_elems + ) + + +@autotvm.register_topi_schedule("conv2d_NCHWc_int8.riscv_cpu") +def schedule_conv2d_NCHWc_int8(cfg, outs): + """Create schedule for tensors""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if "conv2d_NCHWc_int8" in op.tag: + inline_fused = False + conv_out = op.output(0) + kernel_vec = conv_out.op.input_tensors[1] + data_vec = conv_out.op.input_tensors[0] + data = ( + data_vec.op.input_tensors[0] + if isinstance(data_vec.op, te.tensor.ComputeOp) and "pad" not in data_vec.op.tag + else data_vec + ) + if isinstance(data.op, te.tensor.ComputeOp) and "pad" in data.op.tag: + data_pad = data + data = data_pad.op.input_tensors[0] + + # int8 conv kernel is 7-dim + _, _, kh, kw, _, _, n_elems = get_const_tuple(kernel_vec.shape) + + assert n_elems == 4 + + _, _, _, _, ic_bn = get_const_tuple(data_vec.shape) + _, _, _, _, oc_bn = get_const_tuple(conv_out.shape) + + # schedule pad + if isinstance(s[data_vec].op, te.tensor.ComputeOp) and "pad" in data_vec.op.tag: + batch, ic_chunk, ih, iw, ic_block = s[data_vec].op.axis + parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih) + s[data_vec].parallel(parallel_axis) + data_vec = data_vec.op.input_tensors[0] + + if autotvm.GLOBAL_SCOPE.in_tuning: + # only in autotuning, input data of conv2d_NCHWc will be 4-D. + # skip this part during tuning to make records accurate. + # this part will be folded during Relay fold_constant pass. + if isinstance(data_vec.op, te.tensor.ComputeOp): + s[data_vec].pragma(s[data_vec].op.axis[0], "debug_skip_region") + if isinstance(kernel_vec.op, te.tensor.ComputeOp): + s[kernel_vec].pragma(s[kernel_vec].op.axis[0], "debug_skip_region") + elif isinstance(kernel_vec.op, te.tensor.ComputeOp) and kernel_vec.name == "kernel_vec": + # data and kernel are not pre-computed, schedule layout transform here. + # this should only be used by x86 conv2d_nchw, which is for + # testing purpose. 
+ batch, ic_chunk, ih, _, ic_block = s[data_vec].op.axis + parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih) + s[data_vec].parallel(parallel_axis) + + # conv2d_nchwc_int8 has 7D kernel + oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis + s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block) + oc_bn = cfg["tile_oc"].size[-1] + if oc_bn > 1: + s[kernel_vec].vectorize(oc_block) + parallel_axis = s[kernel_vec].fuse(oc_chunk, oh) + s[kernel_vec].parallel(parallel_axis) + + # schedule 5-D NCHW[x]c conv + C, O = conv_out, outs[0] + CC = s.cache_write(C, "global") + + batch, oc_chunk, oh, ow, oc_block = s[C].op.axis + + if kh == 1 and kw == 1: + oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1] + oh_outer, oh_inner = s[C].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[C].split(ow, factor=ow_factor) + s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) + s[C].vectorize(oc_block) + + parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer) + if C == O: + s[C].parallel(parallel_axis) + s[CC].compute_at(s[C], parallel_axis) + + _, oc_chunk, oh, ow, oc_block = s[CC].op.axis + kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis + + oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=4) + + oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor) + + s[CC].reorder( + oc_chunk, + oh_outer, + ow_outer, + kh, + kw, + ic_outer, + ic_f_inner, + oh_inner, + ow_inner, + oc_f_inner, + oc_s_inner, + ic_s_inner, + ) + s[CC].fuse(oc_chunk, oh_outer) + + s[CC].tensorize(oc_s_inner, dot_int8_int8_int32()) + s[CC].pragma(oc_f_inner, "import_c", int8_conv2d_impl()) + + s[CC].unroll(ow_inner) + s[CC].unroll(oh_inner) + + if C != O: + out_ndim = len(s[O].op.axis) + if out_ndim == 5: + batch, oc_chunk, oh, ow, oc_block = s[O].op.axis + oh_outer, oh_inner = s[O].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) + elif out_ndim == 4: + batch, oc, oh, ow = s[O].op.axis + oc_chunk, oc_block = s[O].split(oc, factor=oc_bn) + oh_outer, oh_inner = s[O].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) + else: + raise ValueError("Unsupported output ndim: %s" % out_ndim) + + s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) + parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer) + if inline_fused: + s[C].compute_at(s[O], ow_inner) + else: + s[C].compute_at(s[O], parallel_axis) + + s[O].vectorize(oc_block) + s[O].parallel(parallel_axis) + else: + if isinstance(cfg["tile_ow"], int): + reg_n = cfg["tile_ow"] + else: + reg_n = cfg["tile_ow"].size[-1] + + if isinstance(cfg["unroll_kw"], (int, bool)): + unroll_kw = cfg["unroll_kw"] + else: + unroll_kw = cfg["unroll_kw"].val + + ow_chunk, ow_block = s[C].split(ow, factor=reg_n) + s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) + parallel_axis = s[C].fuse(batch, oc_chunk, oh) + s[C].vectorize(oc_block) + if C == O: + s[C].parallel(parallel_axis) + + s[CC].compute_at(s[C], parallel_axis) + _, oc_chunk, oh, ow, oc_block = s[CC].op.axis + kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis + + ow_chunk, ow_block = s[CC].split(ow, factor=reg_n) + + assert oc_bn % 4 == 0, f"oc_bn={oc_bn} % int32_lanes={4} != 0" + assert ( + ic_bn % 4 == 0 + ), f"ic_bn={ic_bn} % int8_elems={4} != 0" # (u)int8 elements in (u)int32 + + oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=4) + + if unroll_kw: + s[CC].reorder( + oc_chunk, + oh, + 
                        ow_chunk,
+                        ic_outer,
+                        kh,
+                        ic_f_inner,
+                        kw,
+                        ow_block,
+                        oc_f_inner,
+                        oc_s_inner,
+                        ic_s_inner,
+                    )
+                    s[CC].unroll(kw)
+                else:
+                    s[CC].reorder(
+                        oc_chunk,
+                        oh,
+                        ow_chunk,
+                        ic_outer,
+                        kh,
+                        kw,
+                        ic_f_inner,
+                        ow_block,
+                        oc_f_inner,
+                        oc_s_inner,
+                        ic_s_inner,
+                    )
+
+                s[CC].tensorize(oc_s_inner, dot_int8_int8_int32())
+                s[CC].pragma(oc_f_inner, "import_c", int8_conv2d_impl())
+
+                s[CC].unroll(ow_block)
+                s[CC].unroll(oc_f_inner)
+
+                if C != O:
+                    out_ndim = len(s[O].op.axis)
+                    if out_ndim == 5:
+                        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+                        ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
+                        s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+                    elif out_ndim == 4:
+                        batch, oc, oh, ow = s[O].op.axis
+                        ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
+                        oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
+                        s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+                    else:
+                        raise ValueError("Unsupported output ndim: %s" % out_ndim)
+                    parallel_axis = s[O].fuse(batch, oc_chunk, oh)
+                    if inline_fused:
+                        s[C].compute_at(s[O], ow_block)
+                    else:
+                        s[C].compute_at(s[O], parallel_axis)
+                    s[O].vectorize(oc_block)
+                    s[O].parallel(parallel_axis)
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+def conv2d_nchw_int8(data, kernel, strides, padding, dilation, out_dtype):
+    """Compute conv2d with NCHW layout and int8 dtype"""
+    layout = "NCHW"
+    # pylint: disable=no-value-for-parameter
+    packed_out = conv2d_NCHWc_int8(
+        data, kernel, strides, padding, dilation, layout, layout, out_dtype
+    )
+    return unpack_NCHWc_to_nchw(packed_out, out_dtype)
+
+
+def schedule_conv2d_nchw_int8(outs):
+    """Create the schedule for conv2d_nchw_int8"""
+    # pylint: disable=no-value-for-parameter
+    return schedule_conv2d_NCHWc_int8(outs)
diff --git a/python/tvm/topi/riscv_cpu/injective.py b/python/tvm/topi/riscv_cpu/injective.py
new file mode 100644
index 000000000000..d9b82fea4663
--- /dev/null
+++ b/python/tvm/topi/riscv_cpu/injective.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable
+"""Schedule for injective operators"""
+from tvm import te
+from ..utils import is_empty_shape
+
+
+def schedule_injective_from_existing(sch, out):
+    """Schedule for injective op from existing schedule.
+    Parameters
+    ----------
+    sch: Schedule
+        The schedule to update.
+    out: Tensor
+        The tensor representing the injective op.
+    Returns
+    -------
+    sch: Schedule
+        The updated schedule.
+ """ + if len(sch[out].op.axis) >= 5: + fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1], sch[out].op.axis[2]) + sch[out].parallel(fused) + elif len(sch[out].op.axis) >= 3: + fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1]) + sch[out].parallel(fused) + elif len(sch[out].op.axis) >= 1: + sch[out].parallel(sch[out].op.axis[0]) + + # Vectorize the inner most for loop. Tiling first to get a const extent + if len(sch[out].op.axis) >= 1: + l = sch[out].op.axis[-1] + lo, li = sch[out].split(l, factor=16) + sch[out].vectorize(li) + + # for 1D loop, the above split will break the parallel axis + # Need to make the outer loop parallel again + if len(sch[out].op.axis) == 1: + sch[out].parallel(lo) + + return sch + + +def schedule_injective(outs): + """RISCV schedule for injective op. + Parameters + ---------- + outs: Array of Tensor + The computation graph description of injective in the format + of an array of tensors. + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + te.schedule.AutoInlineInjective(s) + for x in outs: + # Vectorize "ADD32" operation. + if "add" in x.name: + is_int32 = x.op.input_tensors[0].dtype == 'int32' and x.op.input_tensors[1].dtype == 'int32' + is_even = x.shape[-1] % 2 == 0 + if is_int32 and is_even: + outer, inner = s[x].split(x.op.axis[-1], 2) + s[x].vectorize(inner) + elif not is_empty_shape(x.shape): + schedule_injective_from_existing(s, x) + + return s diff --git a/python/tvm/topi/riscv_cpu/tensor_intrin.py b/python/tvm/topi/riscv_cpu/tensor_intrin.py new file mode 100644 index 000000000000..b2594bb6756c --- /dev/null +++ b/python/tvm/topi/riscv_cpu/tensor_intrin.py @@ -0,0 +1,218 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument,no-member +"""Conv2D int8 schedule on RISCV""" + +import tvm +from tvm import te +from tvm.contrib import clang + + +def dot_int8_int8_int32(): + int32_lanes = 4 # 4 int32 lanes = 128. + num_int8_elements = 4 # 4 int8 elements in int32. 
+    data = te.placeholder((num_int8_elements,), dtype="uint8", name="data")
+    kernel = te.placeholder((int32_lanes, num_int8_elements), dtype="int8", name="kernel")
+    k = te.reduce_axis((0, num_int8_elements), name="k")
+    C = te.compute(
+        (int32_lanes,),
+        lambda i: te.sum(data[k].astype("int32") * kernel[i, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    a_buffer = tvm.tir.decl_buffer(
+        data.shape, dtype="uint8", name="a_buffer", offset_factor=1, strides=[1]
+    )
+    b_buffer = tvm.tir.decl_buffer(
+        kernel.shape, dtype="int8", name="b_buffer", offset_factor=1, strides=[te.var("ldw"), 1]
+    )
+
+    def _intrin_func(ins, outs):
+        def _body():
+            ib = tvm.tir.ir_builder.create()
+            ib.emit(
+                tvm.tir.call_extern(
+                    outs[0].dtype,
+                    "dot_uint8_int8_int32_body",
+                    ins[0].access_ptr("r"),
+                    ins[1].access_ptr("r"),
+                    outs[0].access_ptr("w"),
+                )
+            )
+            return ib.get()
+
+        def _reset():
+            ib = tvm.tir.ir_builder.create()
+            ib.emit(
+                tvm.tir.call_extern(
+                    outs[0].dtype,
+                    "dot_uint8_int8_int32_reset",
+                    outs[0].access_ptr("w")
+                )
+            )
+            return ib.get()
+
+        def _reduce_update():
+            ib = tvm.tir.ir_builder.create()
+            ib.emit(
+                tvm.tir.call_extern(
+                    outs[0].dtype,
+                    "dot_uint8_int8_int32_update",
+                    ins[0].access_ptr("r"),
+                    ins[1].access_ptr("r"),
+                    outs[0].access_ptr("w"),
+                )
+            )
+            return ib.get()
+
+        return _body(), _reset(), _reduce_update()
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={data: a_buffer, kernel: b_buffer},
+        default_buffer_params=buffer_params,
+    )
+
+def int8_conv2d_impl():
+    """Emit C or IR code for conv2d impl."""
+    cc_code = f"""
+#ifndef TVM_RISCV_CONV2D_INT8
+#define TVM_RISCV_CONV2D_INT8
+#include <stdint.h>
+#include <riscv_vector.h>
+
+#ifdef __cplusplus
+extern "C"
+#endif
+int32_t dot_uint8_int8_int32_reset(int32_t* res) {{
+  // In this function, we set all values in the output array to 0.
+  for (uint32_t i = 0; i < 4; ++i)
+    res[i] = 0;
+  return 0;
+}}
+
+// In this function we need to multiply two vectors, then in the resulting vector calculate the sum of
+// all its elements and store it to output array.
+//
+// Example without vectorization:
+// for (unsigned i = 0; i < 4; ++i) {{
+//   output[i] = 0;
+//   for (unsigned j = 0; j < 4; ++j) {{
+//     output[i] += data[j] * kernel[i][j];
+//   }}
+// }}
+
+#ifdef __cplusplus
+extern "C"
+#endif
+int32_t dot_uint8_int8_int32_body(uint8_t* data, int8_t* kernel, int32_t* output) {{
+  // Load values from data into a vector.
+  vuint8mf2_t v_data = vle8_v_u8mf2(data, -1);
+
+  // Dummy vector for operations.
+  vint32m1_t empty;
+
+  // Load values from kernel[i][*],
+  // then we multiply two vectors with type extension:
+  // v_mul_[i] = v_kernel_[i] * v_data,
+  // then we count the sum in resulting vector:
+  // v_sum_[i] = sum(v_mul_[i])
+
+  vint8mf2_t v_kernel_1 = vle8_v_i8mf2(&kernel[0], -1);
+  vint16m1_t v_mul_1 = vwmulsu_vv_i16m1(v_kernel_1, v_data, -1);
+  vint32m1_t v_sum_1 = vwredsum(empty, v_mul_1, empty, 4);
+
+  vint8mf2_t v_kernel_2 = vle8_v_i8mf2(&kernel[4], -1);
+  vint16m1_t v_mul_2 = vwmulsu_vv_i16m1(v_kernel_2, v_data, -1);
+  vint32m1_t v_sum_2 = vwredsum(empty, v_mul_2, empty, 4);
+
+  vint8mf2_t v_kernel_3 = vle8_v_i8mf2(&kernel[8], -1);
+  vint16m1_t v_mul_3 = vwmulsu_vv_i16m1(v_kernel_3, v_data, -1);
+  vint32m1_t v_sum_3 = vwredsum(empty, v_mul_3, empty, 4);
+
+  vint8mf2_t v_kernel_4 = vle8_v_i8mf2(&kernel[12], -1);
+  vint16m1_t v_mul_4 = vwmulsu_vv_i16m1(v_kernel_4, v_data, -1);
+  vint32m1_t v_sum_4 = vwredsum(empty, v_mul_4, empty, 4);
+
+  // Save new values to output.
+  output[0] = vmv_x_s_i32m1_i32(v_sum_1);
+  output[1] = vmv_x_s_i32m1_i32(v_sum_2);
+  output[2] = vmv_x_s_i32m1_i32(v_sum_3);
+  output[3] = vmv_x_s_i32m1_i32(v_sum_4);
+
+  return 0;
+}}
+
+// In this function we need to multiply two vectors, then in the resulting vector calculate the sum of
+// all its elements and add it to the value from output.
+//
+// Example without vectorization:
+// for (unsigned i = 0; i < 4; ++i)
+//   for (unsigned j = 0; j < 4; ++j)
+//     output[i] += data[j] * kernel[i][j];
+
+#ifdef __cplusplus
+extern "C"
+#endif
+int32_t dot_uint8_int8_int32_update(uint8_t* data, int8_t* kernel, int32_t* output) {{
+  // Load values from data into a vector.
+  vuint8mf2_t v_data = vle8_v_u8mf2(data, -1);
+
+  // Dummy vector for operations.
+  vint32m1_t empty;
+
+  // Load values from output into vectors.
+  vint32m1_t v_output_1 = vle32_v_i32m1(&output[0], -1);
+  vint32m1_t v_output_2 = vle32_v_i32m1(&output[1], -1);
+  vint32m1_t v_output_3 = vle32_v_i32m1(&output[2], -1);
+  vint32m1_t v_output_4 = vle32_v_i32m1(&output[3], -1);
+
+  // Load values from kernel[i][*],
+  // then we multiply two vectors with type extension:
+  // v_mul_[i] = v_kernel_[i] * v_data,
+  // then we count the sum in resulting vector and add value from output:
+  // v_sum_[i] = v_output_[i] + sum(v_mul_[i])
+
+  vint8mf2_t v_kernel_1 = vle8_v_i8mf2(&kernel[0], -1);
+  vint16m1_t v_mul_1 = vwmulsu_vv_i16m1(v_kernel_1, v_data, -1);
+  vint32m1_t v_sum_1 = vwredsum(empty, v_mul_1, v_output_1, 4);
+
+  vint8mf2_t v_kernel_2 = vle8_v_i8mf2(&kernel[4], -1);
+  vint16m1_t v_mul_2 = vwmulsu_vv_i16m1(v_kernel_2, v_data, -1);
+  vint32m1_t v_sum_2 = vwredsum(empty, v_mul_2, v_output_2, 4);
+
+  vint8mf2_t v_kernel_3 = vle8_v_i8mf2(&kernel[8], -1);
+  vint16m1_t v_mul_3 = vwmulsu_vv_i16m1(v_kernel_3, v_data, -1);
+  vint32m1_t v_sum_3 = vwredsum(empty, v_mul_3, v_output_3, 4);
+
+  vint8mf2_t v_kernel_4 = vle8_v_i8mf2(&kernel[12], -1);
+  vint16m1_t v_mul_4 = vwmulsu_vv_i16m1(v_kernel_4, v_data, -1);
+  vint32m1_t v_sum_4 = vwredsum(empty, v_mul_4, v_output_4, 4);
+
+  // Save updated values to output.
+  output[0] = vmv_x_s_i32m1_i32(v_sum_1);
+  output[1] = vmv_x_s_i32m1_i32(v_sum_2);
+  output[2] = vmv_x_s_i32m1_i32(v_sum_3);
+  output[3] = vmv_x_s_i32m1_i32(v_sum_4);
+
+  return 0;
+}}
+#endif
+    """
+    return cc_code
diff --git a/tests/python/relay/aot/riscv.mk b/tests/python/relay/aot/riscv.mk
new file mode 100644
index 000000000000..de803d3bc6bf
--- /dev/null
+++ b/tests/python/relay/aot/riscv.mk
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+AOT_ROOT ?= $(CRT_ROOT)/aot
+
+ENABLE_TVM_PLATFORM_ABORT_BACKTRACE = 0
+DMLC_CORE=$(TVM_ROOT)/3rdparty/dmlc-core
+TOOLCHAIN_PATH=$(shell dirname $(shell which riscv64-unknown-linux-gnu-gcc))/..
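+# TOOLCHAIN_PATH resolves to the RISC-V GNU toolchain root by locating the
+# cross-gcc on PATH; clang reuses it below for --sysroot and --gcc-toolchain.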
+ +CC = clang +CXX = clang++ + +TARGET_CFLAGS = --target=riscv64-unknown-linux-gnu -march=rv64gcv -static +RUNNER = spike +RUNNER_OPT = --isa=rv64gcv $(shell which pk) + +PKG_CFLAGS = ${PKG_COMPILE_OPTS} ${TARGET_CFLAGS} -O2 \ + -I$(build_dir)/../include \ + -I$(CODEGEN_ROOT)/host/include \ + -include $(CODEGEN_ROOT)/host/include/tvmgen_default.h \ + -isystem$(STANDALONE_CRT_DIR)/include \ + --sysroot=$(TOOLCHAIN_PATH)/sysroot \ + --gcc-toolchain=$(TOOLCHAIN_PATH) + +$(ifeq VERBOSE,1) +QUIET ?= +$(else) +QUIET ?= @ +$(endif) + +aot_test_runner: $(build_dir)/aot_test_runner + +c_source_libs = $(wildcard $(build_dir)/../codegen/host/src/*.c) +cc_source_libs = $(wildcard $(build_dir)/../codegen/host/src/*.cc) +c_lib_objs = $(addprefix $(build_dir)/, $(notdir $(c_source_libs:.c=.o))) +cc_lib_objs = $(cc_source_libs:.cc=.o) + +$(build_dir)/aot_test_runner: $(build_dir)/test.c $(c_lib_objs) $(cc_lib_objs) $(build_dir)/stack_allocator.o $(build_dir)/crt_backend_api.o + $(QUIET)mkdir -p $(@D) + $(QUIET)$(CC) $(CFLAGS) $(PKG_CFLAGS) -o $@ $^ $(PKG_LDFLAGS) $(BACKTRACE_LDFLAGS) $(BACKTRACE_CFLAGS) -lm + +$(build_dir)/%.o: $(build_dir)/../codegen/host/src/%.c + $(QUIET)mkdir -p $(@D) + $(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $< $(BACKTRACE_CFLAGS) + +$(build_dir)/%.o: $(build_dir)/../codegen/host/src/%.cc + $(QUIET)mkdir -p $(@D) + $(QUIET)$(CXX) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $< $(BACKTRACE_CFLAGS) + +$(build_dir)/stack_allocator.o: $(STANDALONE_CRT_DIR)/src/runtime/crt/memory/stack_allocator.c + $(QUIET)mkdir -p $(@D) + $(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS) + +$(build_dir)/crt_backend_api.o: $(STANDALONE_CRT_DIR)/src/runtime/crt/common/crt_backend_api.c + $(QUIET)mkdir -p $(@D) + $(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS) + +clean: + $(QUIET)rm -rf $(build_dir)/crt +cleanall: + $(QUIET)rm -rf $(build_dir) + +run: $(build_dir)/aot_test_runner + $(RUNNER) $(RUNNER_OPT) $(build_dir)/aot_test_runner + +# # Don't define implicit rules; they tend to match on logical target names that aren't targets (i.e. bundle_static) +.SUFFIXES: + +.DEFAULT: aot_test_runner + +.PHONY: run diff --git a/tests/python/relay/strategy/riscv_cpu/test_conv2d_int8_nchw.py b/tests/python/relay/strategy/riscv_cpu/test_conv2d_int8_nchw.py new file mode 100644 index 000000000000..0980b9741577 --- /dev/null +++ b/tests/python/relay/strategy/riscv_cpu/test_conv2d_int8_nchw.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
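+"""AOT test that the RISC-V int8 conv2d strategy tensorizes to the V-extension
+dot-product intrinsics; the binary is built via riscv.mk and run on Spike."""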
+ +import numpy as np +import re +import tvm +import tvm.testing +from tvm import relay +from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data +from tvm.micro.testing.aot_test_utils import AOTTestRunner + + +class RISCVConv2dInt8: + @tvm.testing.requires_riscv_spike + def test_conv2d_int8( + self, + data_shape, + kernel_size, + data_layout, + kernel_layout, + num_filter, + padding, + dtype, + wtype, + schedule_name, + ): + weight_shape = (num_filter, data_shape[1], *kernel_size) + + data = relay.var("input", shape=data_shape, dtype=dtype) + + if "int" in wtype: + min_w_value = np.iinfo(wtype).min + max_w_value = np.iinfo(wtype).max + else: + min_w_value = np.finfo(wtype).min + max_w_value = np.finfo(wtype).max + + weight_data = np.random.randint(low=min_w_value, high=max_w_value, size=weight_shape, dtype=wtype) + weight = relay.const(weight_data) + + func = relay.qnn.op.conv2d( + data, + weight, + relay.const(1, "int32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "float32"), + channels=weight_shape[0], + kernel_size=kernel_size, + padding=padding, + data_layout=data_layout, + kernel_layout=kernel_layout, + ) + + ref_mod = relay.Function(relay.analysis.free_vars(func), func) + ref_mod = tvm.IRModule.from_expr(ref_mod) + + if "int" in dtype: + min_d_value = np.iinfo(dtype).min + max_d_value = np.iinfo(dtype).max + else: + min_d_value = np.finfo(dtype).min + max_d_value = np.finfo(dtype).max + + inputs = {"input": np.random.randint(low=min_d_value, high=max_d_value, size=data_shape, dtype=dtype)} + + output_list = generate_ref_data(ref_mod, inputs) + + mod = relay.Function(relay.analysis.free_vars(func), func) + mod = tvm.IRModule.from_expr(mod) + + target_opts = { + "-keys": "riscv_cpu", + "-march": "rv64gcv", + } + + def checker(base_path: str) -> bool: + def read_file(path): + with open(path) as f: + return f.read() + + default_lib1 = read_file(base_path + "/codegen/host/src/default_lib1.c") + regex = ( + r"(?s)dot_uint8_int8_int32_update(.*?)" + ) + return re.search(regex, default_lib1) is not None + + assert compile_and_run( + AOTTestModel(module=mod, inputs=inputs, outputs=output_list), + runner=AOTTestRunner(makefile="riscv"), + interface_api="c", + use_unpacked_api=True, + target_opts=target_opts, + schedule_name=schedule_name, + checker=checker + ) + + +class TestConv2d_NCHW(RISCVConv2dInt8): + ( + data_shape, + kernel_size, + num_filter, + ) = tvm.testing.parameters( + ((1, 128, 14, 14), (3, 3), 128), + ((1, 128, 14, 14), (1, 1), 256), + ((1, 256, 7, 7), (1, 1), 512), + ((1, 256, 7, 7), (3, 3), 512), + ((1, 512, 3, 3), (3, 3), 512), + ) + padding = tvm.testing.parameter((1, 1)) + data_layout = tvm.testing.parameter("NCHW") + kernel_layout = tvm.testing.parameter("OIHW") + dtype = tvm.testing.parameter("uint8") + wtype = tvm.testing.parameter("int8") + schedule_name = tvm.testing.parameter("conv2d_int8_NCHW.riscv_cpu") + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/scripts/task_riscv_microtvm.sh b/tests/scripts/task_riscv_microtvm.sh index c597506dfede..db36d295df0a 100755 --- a/tests/scripts/task_riscv_microtvm.sh +++ b/tests/scripts/task_riscv_microtvm.sh @@ -25,3 +25,5 @@ make cython3 # NOTE: this exists to ensure some tests run on RISC-V image. Without it, Jenkins reports a configuration error. # This line can be removed when RISC-V tests are added. 
run_pytest ctypes riscv-platform-minimal-test-0 tests/python/all-platform-minimal-test + +run_pytest ctypes python-relay-strategy-riscv_cpu tests/python/relay/strategy/riscv_cpu From 686a41ab54db60c343e084bfabf74715e510b607 Mon Sep 17 00:00:00 2001 From: katebern-grovety <132359118+katebern-grovety@users.noreply.github.com> Date: Fri, 12 May 2023 14:02:48 +0400 Subject: [PATCH 02/10] lint fix --- python/tvm/relay/op/strategy/riscv_cpu.py | 6 +++-- python/tvm/testing/utils.py | 1 + python/tvm/topi/riscv_cpu/injective.py | 4 +++- python/tvm/topi/riscv_cpu/tensor_intrin.py | 5 ++--- .../riscv_cpu/test_conv2d_int8_nchw.py | 22 +++++++++---------- 5 files changed, 21 insertions(+), 17 deletions(-) diff --git a/python/tvm/relay/op/strategy/riscv_cpu.py b/python/tvm/relay/op/strategy/riscv_cpu.py index 6eb6329dea74..8542dbe17adc 100644 --- a/python/tvm/relay/op/strategy/riscv_cpu.py +++ b/python/tvm/relay/op/strategy/riscv_cpu.py @@ -65,7 +65,7 @@ def conv2d_strategy_riscv_cpu(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.riscv_cpu.conv2d_nchw_int8), wrap_topi_schedule(topi.riscv_cpu.schedule_conv2d_nchw_int8), name="conv2d_nchw_int8.riscv", - plevel=15 + plevel=15, ) return strategy @@ -81,7 +81,9 @@ def conv2d_NCHWc_strategy_riscv_cpu(attrs, inputs, out_type, target): # Vector instructions with int8 show more performance at a larger size. if is_int8 and kernel.shape[1] >= 128: strategy.add_implementation( - wrap_compute_conv2d(topi.riscv_cpu.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True), + wrap_compute_conv2d( + topi.riscv_cpu.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True + ), wrap_topi_schedule(topi.riscv_cpu.schedule_conv2d_NCHWc_int8), name="conv2d_NCHWc_int8.riscv_cpu", ) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index bf7d45cfaf53..268e5429c071 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -1015,6 +1015,7 @@ def _riscv_spike_run_time_check(): run_time_check=_riscv_spike_run_time_check, ) + def _arm_dot_supported(): arch = platform.machine() diff --git a/python/tvm/topi/riscv_cpu/injective.py b/python/tvm/topi/riscv_cpu/injective.py index d9b82fea4663..4a2581040b8c 100644 --- a/python/tvm/topi/riscv_cpu/injective.py +++ b/python/tvm/topi/riscv_cpu/injective.py @@ -74,7 +74,9 @@ def schedule_injective(outs): for x in outs: # Vectorize "ADD32" operation. 
if "add" in x.name: - is_int32 = x.op.input_tensors[0].dtype == 'int32' and x.op.input_tensors[1].dtype == 'int32' + is_int32 = ( + x.op.input_tensors[0].dtype == "int32" and x.op.input_tensors[1].dtype == "int32" + ) is_even = x.shape[-1] % 2 == 0 if is_int32 and is_even: outer, inner = s[x].split(x.op.axis[-1], 2) diff --git a/python/tvm/topi/riscv_cpu/tensor_intrin.py b/python/tvm/topi/riscv_cpu/tensor_intrin.py index b2594bb6756c..a7c6d49a4516 100644 --- a/python/tvm/topi/riscv_cpu/tensor_intrin.py +++ b/python/tvm/topi/riscv_cpu/tensor_intrin.py @@ -59,9 +59,7 @@ def _reset(): ib = tvm.tir.ir_builder.create() ib.emit( tvm.tir.call_extern( - outs[0].dtype, - "dot_uint8_int8_int32_reset", - outs[0].access_ptr("w") + outs[0].dtype, "dot_uint8_int8_int32_reset", outs[0].access_ptr("w") ) ) return ib.get() @@ -89,6 +87,7 @@ def _reduce_update(): default_buffer_params=buffer_params, ) + def int8_conv2d_impl(): """Emit C or IR code for conv2d impl.""" cc_code = f""" diff --git a/tests/python/relay/strategy/riscv_cpu/test_conv2d_int8_nchw.py b/tests/python/relay/strategy/riscv_cpu/test_conv2d_int8_nchw.py index 0980b9741577..a5bce822084a 100644 --- a/tests/python/relay/strategy/riscv_cpu/test_conv2d_int8_nchw.py +++ b/tests/python/relay/strategy/riscv_cpu/test_conv2d_int8_nchw.py @@ -49,7 +49,9 @@ def test_conv2d_int8( min_w_value = np.finfo(wtype).min max_w_value = np.finfo(wtype).max - weight_data = np.random.randint(low=min_w_value, high=max_w_value, size=weight_shape, dtype=wtype) + weight_data = np.random.randint( + low=min_w_value, high=max_w_value, size=weight_shape, dtype=wtype + ) weight = relay.const(weight_data) func = relay.qnn.op.conv2d( @@ -76,7 +78,11 @@ def test_conv2d_int8( min_d_value = np.finfo(dtype).min max_d_value = np.finfo(dtype).max - inputs = {"input": np.random.randint(low=min_d_value, high=max_d_value, size=data_shape, dtype=dtype)} + inputs = { + "input": np.random.randint( + low=min_d_value, high=max_d_value, size=data_shape, dtype=dtype + ) + } output_list = generate_ref_data(ref_mod, inputs) @@ -94,9 +100,7 @@ def read_file(path): return f.read() default_lib1 = read_file(base_path + "/codegen/host/src/default_lib1.c") - regex = ( - r"(?s)dot_uint8_int8_int32_update(.*?)" - ) + regex = r"(?s)dot_uint8_int8_int32_update(.*?)" return re.search(regex, default_lib1) is not None assert compile_and_run( @@ -106,16 +110,12 @@ def read_file(path): use_unpacked_api=True, target_opts=target_opts, schedule_name=schedule_name, - checker=checker + checker=checker, ) class TestConv2d_NCHW(RISCVConv2dInt8): - ( - data_shape, - kernel_size, - num_filter, - ) = tvm.testing.parameters( + (data_shape, kernel_size, num_filter,) = tvm.testing.parameters( ((1, 128, 14, 14), (3, 3), 128), ((1, 128, 14, 14), (1, 1), 256), ((1, 256, 7, 7), (1, 1), 512), From 4127e9ff6be12b6ca41179395c3e817f56271fbd Mon Sep 17 00:00:00 2001 From: katebern-grovety <132359118+katebern-grovety@users.noreply.github.com> Date: Fri, 12 May 2023 15:28:21 +0400 Subject: [PATCH 03/10] lint fix --- tests/python/relay/aot/riscv.mk | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/aot/riscv.mk b/tests/python/relay/aot/riscv.mk index de803d3bc6bf..c8a91e0390b5 100644 --- a/tests/python/relay/aot/riscv.mk +++ b/tests/python/relay/aot/riscv.mk @@ -33,8 +33,8 @@ PKG_CFLAGS = ${PKG_COMPILE_OPTS} ${TARGET_CFLAGS} -O2 \ -I$(CODEGEN_ROOT)/host/include \ -include $(CODEGEN_ROOT)/host/include/tvmgen_default.h \ -isystem$(STANDALONE_CRT_DIR)/include \ - --sysroot=$(TOOLCHAIN_PATH)/sysroot 
\ - --gcc-toolchain=$(TOOLCHAIN_PATH) + --sysroot=$(TOOLCHAIN_PATH)/sysroot \ + --gcc-toolchain=$(TOOLCHAIN_PATH) $(ifeq VERBOSE,1) QUIET ?= From fe62428cc59e2df251b998b456bedbe16da31d0a Mon Sep 17 00:00:00 2001 From: katebern-grovety <132359118+katebern-grovety@users.noreply.github.com> Date: Fri, 12 May 2023 16:47:34 +0400 Subject: [PATCH 04/10] lint fix --- python/tvm/topi/riscv_cpu/tensor_intrin.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/topi/riscv_cpu/tensor_intrin.py b/python/tvm/topi/riscv_cpu/tensor_intrin.py index a7c6d49a4516..bd31dc8f8fec 100644 --- a/python/tvm/topi/riscv_cpu/tensor_intrin.py +++ b/python/tvm/topi/riscv_cpu/tensor_intrin.py @@ -127,7 +127,7 @@ def int8_conv2d_impl(): // Dummy vector for operations. vint32m1_t empty; - // Load values from kernel[i][*], + // Load values from kernel[i][*], // then we multiply two vectors with type extension: // v_mul_[i] = v_kernel_[i] * v_data, // then we count the sum in resulting vector: @@ -148,7 +148,7 @@ def int8_conv2d_impl(): vint8mf2_t v_kernel_4 = vle8_v_i8mf2(&kernel[12], -1); vint16m1_t v_mul_4 = vwmulsu_vv_i16m1(v_kernel_4, v_data, -1); vint32m1_t v_sum_4 = vwredsum(empty, v_mul_4, empty, 4); - + // Save new values to output. output[0] = vmv_x_s_i32m1_i32(v_sum_1); output[1] = vmv_x_s_i32m1_i32(v_sum_2); @@ -172,17 +172,17 @@ def int8_conv2d_impl(): int32_t dot_uint8_int8_int32_update(uint8_t* data, int8_t* kernel, int32_t* output) {{ // Load values from data into a vector. vuint8mf2_t v_data = vle8_v_u8mf2(data, -1); - + // Dummy vector for operations. vint32m1_t empty; - + // Load values from output into vectors. vint32m1_t v_output_1 = vle32_v_i32m1(&output[0], -1); vint32m1_t v_output_2 = vle32_v_i32m1(&output[1], -1); vint32m1_t v_output_3 = vle32_v_i32m1(&output[2], -1); vint32m1_t v_output_4 = vle32_v_i32m1(&output[3], -1); - // Load values from kernel[i][*], + // Load values from kernel[i][*], // then we multiply two vectors with type extension: // v_mul_[i] = v_kernel_[i] * v_data, // then we count the sum in resulting vector and add value from output: @@ -203,7 +203,7 @@ def int8_conv2d_impl(): vint8mf2_t v_kernel_4 = vle8_v_i8mf2(&kernel[12], -1); vint16m1_t v_mul_4 = vwmulsu_vv_i16m1(v_kernel_4, v_data, -1); vint32m1_t v_sum_4 = vwredsum(empty, v_mul_4, v_output_4, 4); - + // Save updated values to output. 
   output[0] = vmv_x_s_i32m1_i32(v_sum_1);
   output[1] = vmv_x_s_i32m1_i32(v_sum_2);
   output[2] = vmv_x_s_i32m1_i32(v_sum_3);

From 9bb118c87daea897ee1e6f21c2c59965cd8dc0b4 Mon Sep 17 00:00:00 2001
From: katebern-grovety <132359118+katebern-grovety@users.noreply.github.com>
Date: Fri, 12 May 2023 17:49:02 +0400
Subject: [PATCH 05/10] lint fix, update comments

---
 python/tvm/relay/op/strategy/riscv_cpu.py  | 1 -
 python/tvm/topi/riscv_cpu/tensor_intrin.py | 8 ++++----
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/op/strategy/riscv_cpu.py b/python/tvm/relay/op/strategy/riscv_cpu.py
index 8542dbe17adc..7853d193757a 100644
--- a/python/tvm/relay/op/strategy/riscv_cpu.py
+++ b/python/tvm/relay/op/strategy/riscv_cpu.py
@@ -18,7 +18,6 @@
 # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
 """Definition of RISCV CPU operator strategy."""
 
-from functools import reduce
 import logging
 
 from tvm import topi
diff --git a/python/tvm/topi/riscv_cpu/tensor_intrin.py b/python/tvm/topi/riscv_cpu/tensor_intrin.py
index bd31dc8f8fec..8f5578d690b7 100644
--- a/python/tvm/topi/riscv_cpu/tensor_intrin.py
+++ b/python/tvm/topi/riscv_cpu/tensor_intrin.py
@@ -15,14 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name,unused-variable,unused-argument,no-member
-"""Conv2D int8 schedule on RISCV"""
+"""Core kernel of dot product of 4 Int8 operations"""
 
 import tvm
 from tvm import te
-from tvm.contrib import clang
 
 
 def dot_int8_int8_int32():
+    """Int8 dot product by every 4 elements using RISC-V V-extension instructions."""
     int32_lanes = 4  # 4 int32 lanes = 128.
     num_int8_elements = 4  # 4 int8 elements in int32.
@@ -89,8 +89,8 @@ def _reduce_update():
 
 
 def int8_conv2d_impl():
-    """Emit C or IR code for conv2d impl."""
-    cc_code = f"""
+    """Emit C code for dot product impl."""
+    cc_code = """
 #ifndef TVM_RISCV_CONV2D_INT8
 #define TVM_RISCV_CONV2D_INT8
 #include <stdint.h>
 #include <riscv_vector.h>

From 61ee1c3f7fe5aa783eaa58a277da737527296f9e Mon Sep 17 00:00:00 2001
From: katebern-grovety <132359118+katebern-grovety@users.noreply.github.com>
Date: Mon, 15 May 2023 11:51:47 +0400
Subject: [PATCH 06/10] fix clang version

---
 tests/python/relay/aot/riscv.mk | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/python/relay/aot/riscv.mk b/tests/python/relay/aot/riscv.mk
index c8a91e0390b5..9282760df7a7 100644
--- a/tests/python/relay/aot/riscv.mk
+++ b/tests/python/relay/aot/riscv.mk
@@ -21,8 +21,8 @@ ENABLE_TVM_PLATFORM_ABORT_BACKTRACE = 0
 DMLC_CORE=$(TVM_ROOT)/3rdparty/dmlc-core
 TOOLCHAIN_PATH=$(shell dirname $(shell which riscv64-unknown-linux-gnu-gcc))/..
 
-CC = clang
-CXX = clang++
+CC = clang-16
+CXX = clang++-16
 
 TARGET_CFLAGS = --target=riscv64-unknown-linux-gnu -march=rv64gcv -static
 RUNNER = spike

From d2445fce0ebf179e33d436722cf75b931099f3b5 Mon Sep 17 00:00:00 2001
From: katebern-grovety <132359118+katebern-grovety@users.noreply.github.com>
Date: Wed, 16 Aug 2023 13:16:28 +0400
Subject: [PATCH 07/10] Add Project API for RISC-V with spike

---
 CMakeLists.txt                                |   1 +
 apps/microtvm/riscv/ReadMe.md                 |  18 ++
 .../template_project/CMakeLists.txt.template  |  64 ++++
 .../template_project/microtvm_api_server.py   | 285 ++++++++++++++++++
 .../riscv/template_project/src/main.cc        | 106 +++++++
 .../riscv/template_project/src/platform.cc    | 126 ++++++++
 cmake/modules/Micro.cmake                     |  10 +-
 python/tvm/micro/build.py                     |   1 +
 .../riscv_cpu/test_conv2d_int8_nchw.py        |  69 +++--
 9 files changed, 651 insertions(+), 29 deletions(-)
 create mode 100644 apps/microtvm/riscv/ReadMe.md
 create mode 100644 apps/microtvm/riscv/template_project/CMakeLists.txt.template
 create mode 100644 apps/microtvm/riscv/template_project/microtvm_api_server.py
 create mode 100644 apps/microtvm/riscv/template_project/src/main.cc
 create mode 100644 apps/microtvm/riscv/template_project/src/platform.cc

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 67e87d907141..5d3e9761076a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -590,6 +590,7 @@ if(USE_MICRO)
   add_dependencies(tvm_runtime arduino)
   add_dependencies(tvm_runtime crt)
   add_dependencies(tvm_runtime host_standalone_crt)
+  add_dependencies(tvm_runtime riscv)
   add_dependencies(tvm_runtime zephyr)
 endif()
 
diff --git a/apps/microtvm/riscv/ReadMe.md b/apps/microtvm/riscv/ReadMe.md
new file mode 100644
index 000000000000..5bffc9218a93
--- /dev/null
+++ b/apps/microtvm/riscv/ReadMe.md
@@ -0,0 +1,18 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+<!--- -->
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+<!--- -->
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+This directory contains code to interface microTVM with [RISC-V](https://riscv.org/).
diff --git a/apps/microtvm/riscv/template_project/CMakeLists.txt.template b/apps/microtvm/riscv/template_project/CMakeLists.txt.template
new file mode 100644
index 000000000000..6a25b4e9a5ee
--- /dev/null
+++ b/apps/microtvm/riscv/template_project/CMakeLists.txt.template
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
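+
+# This template is instantiated by microtvm_api_server.py: generate_project
+# appends the TVM_WORKSPACE_SIZE_BYTES compile definition, and build() passes
+# TOOLCHAIN_PATH, TARGET and MARCH through -D options.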
+
+# SPDX-License-Identifier: Apache-2.0
+
+cmake_minimum_required(VERSION 3.18)
+set(CMAKE_CXX_STANDARD 11)
+set(CMAKE_C_COMPILER "clang")
+set(CMAKE_CXX_COMPILER "clang++")
+set(FLAGS "--target=${TARGET} --sysroot=${TOOLCHAIN_PATH}/sysroot --gcc-toolchain=${TOOLCHAIN_PATH} -march=${MARCH} -static")
+set(CMAKE_C_FLAGS ${FLAGS})
+set(CMAKE_CXX_FLAGS ${FLAGS})
+
+project(crt_autogenerated_project C CXX)
+add_executable(main)
+
+set(CRT_LIB_BASE crt/src/runtime/crt)
+set(CRT_LIBS microtvm_rpc_server
+    microtvm_rpc_common
+    aot_executor_module
+    aot_executor
+    graph_executor_module
+    graph_executor
+    common
+    memory
+)
+
+# Build CRT libraries
+foreach(crt_lib_name ${CRT_LIBS})
+  add_library(${crt_lib_name})
+  file(GLOB_RECURSE crt_lib_srcs ${CRT_LIB_BASE}/${crt_lib_name}/*.c ${CRT_LIB_BASE}/${crt_lib_name}/*.cc)
+  target_sources(${crt_lib_name} PRIVATE ${crt_lib_srcs})
+  target_include_directories(${crt_lib_name} PRIVATE crt_config crt/include)
+  target_compile_definitions(${crt_lib_name} PRIVATE -DTVM_HOST_USE_GRAPH_EXECUTOR_MODULE)
+  target_link_libraries(main PRIVATE ${crt_lib_name})
+endforeach(crt_lib_name ${CRT_LIBS})
+
+# Build model files
+add_library(tvm_model)
+file(GLOB_RECURSE tvm_model_srcs model/codegen/host/src/*.c model/codegen/host/lib/*.o)
+target_sources(tvm_model PRIVATE ${tvm_model_srcs})
+target_include_directories(tvm_model PRIVATE ${CMAKE_SOURCE_DIR}/include crt_config crt/include)
+target_compile_options(tvm_model PRIVATE -Wno-error=unused-variable -Wno-error=missing-braces -Wno-error=unused-const-variable -Wno-unused-variable)
+set_target_properties(tvm_model PROPERTIES LINKER_LANGUAGE C)
+target_link_libraries(main PRIVATE tvm_model)
+
+file(GLOB_RECURSE app_srcs src/**.cc)
+target_sources(main PRIVATE ${app_srcs} ${cmsis_lib_srcs})
+target_compile_definitions(main PRIVATE -DTVM_HOST_USE_GRAPH_EXECUTOR_MODULE)
+target_include_directories(main PRIVATE crt_config include ${CMAKE_SOURCE_DIR}/include crt/include)
diff --git a/apps/microtvm/riscv/template_project/microtvm_api_server.py b/apps/microtvm/riscv/template_project/microtvm_api_server.py
new file mode 100644
index 000000000000..1495073d8583
--- /dev/null
+++ b/apps/microtvm/riscv/template_project/microtvm_api_server.py
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import fcntl
+import os
+import os.path
+import pathlib
+import select
+import shutil
+import subprocess
+import tarfile
+import time
+import re
+
+from tvm.micro.project_api import server
+
+
+PROJECT_DIR = pathlib.Path(os.path.dirname(__file__) or os.getcwd())
+
+
+MODEL_LIBRARY_FORMAT_RELPATH = "model.tar"
+
+
+IS_TEMPLATE = not os.path.exists(os.path.join(PROJECT_DIR, MODEL_LIBRARY_FORMAT_RELPATH))
+
+# This workspace size passes most CRT tests in TVM.
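+# 2 MiB default; projects can override it via the workspace_size_bytes option.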
+WORKSPACE_SIZE_BYTES = 2 * 1024 * 1024 + +CMAKEFILE_FILENAME = "CMakeLists.txt" + +# The build target given to make +BUILD_TARGET = "build/main" + + +class Handler(server.ProjectAPIHandler): + BUILD_TARGET = "build/main" + + def __init__(self): + super(Handler, self).__init__() + self._proc = None + + def server_info_query(self, tvm_version): + return server.ServerInfo( + platform_name="host", + is_template=IS_TEMPLATE, + model_library_format_path="" + if IS_TEMPLATE + else PROJECT_DIR / MODEL_LIBRARY_FORMAT_RELPATH, + project_options=[ + server.ProjectOption( + "verbose", + optional=["build"], + type="bool", + default=False, + help="Run make with verbose output", + ), + server.ProjectOption( + "workspace_size_bytes", + optional=["generate_project"], + type="int", + default=WORKSPACE_SIZE_BYTES, + help="Sets the value of TVM_WORKSPACE_SIZE_BYTES.", + ), + server.ProjectOption( + "toolchain_path", + optional=["generate_project"], + type="str", + default="/opt/riscv", + help="Sets the value of toolchain path.", + ), + server.ProjectOption( + "target", + optional=["generate_project"], + type="str", + default="riscv64-unknown-linux-gnu", + help="Sets the value of target.", + ), + server.ProjectOption( + "march", + optional=["generate_project"], + type="str", + default="rv64gc", + help="Sets the value of target architecture.", + ), + ], + ) + + # These files and directories will be recursively copied into generated projects from the CRT. + CRT_COPY_ITEMS = ("include", "CMakeLists.txt", "src") + + def _populate_cmake( + self, + cmakefile_template_path: pathlib.Path, + cmakefile_path: pathlib.Path, + memory_size: int, + verbose: bool, + ): + """Generate CMakeList file from template.""" + + regex = re.compile(r"([A-Z_]+) := (<[A-Z_]+>)") + with open(cmakefile_path, "w") as cmakefile_f: + with open(cmakefile_template_path, "r") as cmakefile_template_f: + for line in cmakefile_template_f: + cmakefile_f.write(line) + cmakefile_f.write( + f"target_compile_definitions(main PUBLIC -DTVM_WORKSPACE_SIZE_BYTES={memory_size})\n" + ) + if verbose: + cmakefile_f.write(f"set(CMAKE_VERBOSE_MAKEFILE TRUE)\n") + + def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): + # Make project directory. + project_dir.mkdir(parents=True) + current_dir = pathlib.Path(__file__).parent.absolute() + + # Copy ourselves to the generated project. TVM may perform further build steps on the generated project + # by launching the copy. + shutil.copy2(__file__, project_dir / os.path.basename(__file__)) + + # Place Model Library Format tarball in the special location, which this script uses to decide + # whether it's being invoked in a template or generated project. + project_model_library_format_path = project_dir / MODEL_LIBRARY_FORMAT_RELPATH + shutil.copy2(model_library_format_path, project_model_library_format_path) + + # Extract Model Library Format tarball.into /model. + extract_path = project_dir / project_model_library_format_path.stem + with tarfile.TarFile(project_model_library_format_path) as tf: + os.makedirs(extract_path) + tf.extractall(path=extract_path) + + # Populate CRT. 
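+        # Copy the standalone CRT pieces listed in CRT_COPY_ITEMS so the
+        # generated project builds without referencing the TVM source tree.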
+        crt_path = project_dir / "crt"
+        os.mkdir(crt_path)
+        for item in self.CRT_COPY_ITEMS:
+            src_path = standalone_crt_dir / item
+            dst_path = crt_path / item
+            if os.path.isdir(src_path):
+                shutil.copytree(src_path, dst_path)
+            else:
+                shutil.copy2(src_path, dst_path)
+
+        # Populate CMake file.
+        self._populate_cmake(
+            current_dir / f"{CMAKEFILE_FILENAME}.template",
+            project_dir / CMAKEFILE_FILENAME,
+            options.get("workspace_size_bytes", WORKSPACE_SIZE_BYTES),
+            options.get("verbose"),
+        )
+
+        # Populate crt_config.h.
+        crt_config_dir = project_dir / "crt_config"
+        crt_config_dir.mkdir()
+        shutil.copy2(
+            current_dir / "crt_config" / "crt_config.h",
+            crt_config_dir / "crt_config.h",
+        )
+
+        # Populate src/.
+        src_dir = project_dir / "src"
+        src_dir.mkdir()
+        shutil.copy2(
+            current_dir / "src" / "main.cc",
+            src_dir / "main.cc",
+        )
+        shutil.copy2(
+            current_dir / "src" / "platform.cc",
+            src_dir / "platform.cc",
+        )
+
+    def build(self, options):
+        cmake_args = []
+        toolchain_path = options.get("toolchain_path")
+        target = options.get("target")
+        march = options.get("march")
+        cmake_args.append(f"-DTOOLCHAIN_PATH={toolchain_path}")
+        cmake_args.append(f"-DTARGET={target}")
+        cmake_args.append(f"-DMARCH={march}")
+
+        build_dir = PROJECT_DIR / "build"
+        build_dir.mkdir()
+        subprocess.check_call(["cmake", *cmake_args, ".."], cwd=build_dir)
+        subprocess.check_call(["make"], cwd=build_dir)
+
+    def flash(self, options):
+        pass  # Flashing does nothing on host.
+
+    def _set_nonblock(self, fd):
+        flag = fcntl.fcntl(fd, fcntl.F_GETFL)
+        fcntl.fcntl(fd, fcntl.F_SETFL, flag | os.O_NONBLOCK)
+        new_flag = fcntl.fcntl(fd, fcntl.F_GETFL)
+        assert (new_flag & os.O_NONBLOCK) != 0, f"Cannot set file descriptor {fd} to non-blocking"
+
+    def open_transport(self, options):
+        toolchain_path = options.get("toolchain_path")
+        target = options.get("target")
+        march = options.get("march")
+        spike_args = ["spike", f"--isa={march}", os.path.join(toolchain_path, target, "bin/pk")]
+        self._proc = subprocess.Popen(
+            spike_args + [self.BUILD_TARGET],
+            stdin=subprocess.PIPE,
+            stdout=subprocess.PIPE,
+            bufsize=0,
+        )
+        self._set_nonblock(self._proc.stdin.fileno())
+        self._set_nonblock(self._proc.stdout.fileno())
+        return server.TransportTimeouts(
+            session_start_retry_timeout_sec=0,
+            session_start_timeout_sec=0,
+            session_established_timeout_sec=0,
+        )
+
+    def close_transport(self):
+        if self._proc is not None:
+            proc = self._proc
+            self._proc = None
+            proc.kill()
+            proc.wait()
+
+    def _await_ready(self, rlist, wlist, timeout_sec=None, end_time=None):
+        if timeout_sec is None and end_time is not None:
+            timeout_sec = max(0, end_time - time.monotonic())
+
+        rlist, wlist, xlist = select.select(rlist, wlist, rlist + wlist, timeout_sec)
+        if not rlist and not wlist and not xlist:
+            raise server.IoTimeoutError()
+
+        return True
+
+    def read_transport(self, n, timeout_sec):
+        if self._proc is None:
+            raise server.TransportClosedError()
+
+        fd = self._proc.stdout.fileno()
+        end_time = None if timeout_sec is None else time.monotonic() + timeout_sec
+
+        try:
+            self._await_ready([fd], [], end_time=end_time)
+            to_return = os.read(fd, n)
+        except BrokenPipeError:
+            to_return = 0
+
+        if not to_return:
+            self.close_transport()
+            raise server.TransportClosedError()
+
+        return to_return
+
+    def write_transport(self, data, timeout_sec):
+        if self._proc is None:
+            raise server.TransportClosedError()
+
+        fd = self._proc.stdin.fileno()
+        end_time = None if timeout_sec is None else time.monotonic() + timeout_sec
+
+        data_len = len(data)
+        while data:
+            self._await_ready([], [fd], end_time=end_time)
+            try:
+                num_written = os.write(fd, data)
+            except BrokenPipeError:
+                num_written = 0
+
+            if not num_written:
+                self.close_transport()
+                raise server.TransportClosedError()
+
+            data = data[num_written:]
+
+
+if __name__ == "__main__":
+    server.main(Handler())
diff --git a/apps/microtvm/riscv/template_project/src/main.cc b/apps/microtvm/riscv/template_project/src/main.cc
new file mode 100644
index 000000000000..0607d4b28719
--- /dev/null
+++ b/apps/microtvm/riscv/template_project/src/main.cc
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file main.cc
+ * \brief main entry point for host subprocess-based CRT
+ */
+#include <inttypes.h>
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <unistd.h>
+
+#include <tvm/runtime/crt/microtvm_rpc_server.h>
+
+#include "crt_config.h"
+
+#ifdef TVM_HOST_USE_GRAPH_EXECUTOR_MODULE
+#include <tvm/runtime/crt/graph_executor_module.h>
+#endif
+
+extern "C" {
+
+ssize_t MicroTVMWriteFunc(void* context, const uint8_t* data, size_t num_bytes) {
+  ssize_t to_return = write(STDOUT_FILENO, data, num_bytes);
+  fflush(stdout);
+  fsync(STDOUT_FILENO);
+  return to_return;
+}
+}
+
+static char** g_argv = NULL;
+
+int testonly_reset_server(TVMValue* args, int* type_codes, int num_args, TVMValue* out_ret_value,
+                          int* out_ret_tcode, void* resource_handle) {
+  execvp(g_argv[0], g_argv);
+  perror("microTVM runtime: error restarting");
+  return -1;
+}
+
+int main(int argc, char** argv) {
+  g_argv = argv;
+  TVMPlatformInitialize();
+  microtvm_rpc_server_t rpc_server = MicroTVMRpcServerInit(&MicroTVMWriteFunc, nullptr);
+
+#ifdef TVM_HOST_USE_GRAPH_EXECUTOR_MODULE
+  CHECK_EQ(TVMGraphExecutorModule_Register(), kTvmErrorNoError,
+           "failed to register GraphExecutor TVMModule");
+#endif
+
+  int error = TVMFuncRegisterGlobal("tvm.testing.reset_server",
+                                    (TVMFunctionHandle)&testonly_reset_server, 0);
+  if (error) {
+    fprintf(
+        stderr,
+        "microTVM runtime: internal error (error#: %x) registering global packedfunc; exiting\n",
+        error);
+    return 2;
+  }
+
+  setbuf(stdin, NULL);
+  setbuf(stdout, NULL);
+
+  for (;;) {
+    uint8_t c;
+    int ret_code = read(STDIN_FILENO, &c, 1);
+    if (ret_code < 0) {
+      perror("microTVM runtime: read failed");
+      return 2;
+    } else if (ret_code == 0) {
+      fprintf(stderr, "microTVM runtime: 0-length read, exiting!\n");
+      return 2;
+    }
+    uint8_t* cursor = &c;
+    size_t bytes_to_process = 1;
+    while (bytes_to_process > 0) {
+      tvm_crt_error_t err = MicroTVMRpcServerLoop(rpc_server, &cursor, &bytes_to_process);
+      if (err == kTvmErrorPlatformShutdown) {
+        break;
+      } else if (err != kTvmErrorNoError) {
+        char buf[1024];
+        snprintf(buf, sizeof(buf), "microTVM runtime: MicroTVMRpcServerLoop error: %08x", err);
+        perror(buf);
+        return 2;
+      }
+    }
+  }
+  return 0;
+}
diff --git a/apps/microtvm/riscv/template_project/src/platform.cc b/apps/microtvm/riscv/template_project/src/platform.cc
new file mode 100644
index 000000000000..f5af08a9be88
--- /dev/null
+++ b/apps/microtvm/riscv/template_project/src/platform.cc
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Implementation of TVMPlatform functions in tvm/runtime/crt/platform.h
+ */
+
+#include <dlpack/dlpack.h>
+#include <stdarg.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <time.h>
+#include <tvm/runtime/crt/error_codes.h>
+#include <tvm/runtime/crt/logging.h>
+#include <tvm/runtime/crt/page_allocator.h>
+#include <tvm/runtime/crt/platform.h>
+
+#include <chrono>
+#include <iostream>
+
+using namespace std::chrono;
+
+extern "C" {
+
+uint8_t memory[TVM_WORKSPACE_SIZE_BYTES];
+MemoryManagerInterface* memory_manager;
+
+steady_clock::time_point g_microtvm_start_time;
+int g_microtvm_timer_running = 0;
+
+// Called when an internal error occurs and execution cannot continue.
+void TVMPlatformAbort(tvm_crt_error_t error_code) {
+  std::cerr << "TVMPlatformAbort: " << error_code << std::endl;
+  throw "Aborted";
+}
+
+// Called by the microTVM RPC server to implement TVMLogf.
+size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt,
+                                va_list args) {
+  return vsnprintf(out_buf, out_buf_size_bytes, fmt, args);
+}
+
+// Allocate memory for use by TVM.
+tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) {
+  return memory_manager->Allocate(memory_manager, num_bytes, dev, out_ptr);
+}
+
+// Free memory used by TVM.
+tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) {
+  return memory_manager->Free(memory_manager, ptr, dev);
+}
+
+// Start a device timer.
+tvm_crt_error_t TVMPlatformTimerStart() {
+  if (g_microtvm_timer_running) {
+    std::cerr << "timer already running" << std::endl;
+    return kTvmErrorPlatformTimerBadState;
+  }
+  g_microtvm_start_time = std::chrono::steady_clock::now();
+  g_microtvm_timer_running = 1;
+  return kTvmErrorNoError;
+}
+
+// Stop the running device timer and get the elapsed time (in seconds).
+tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
+  if (!g_microtvm_timer_running) {
+    std::cerr << "timer not running" << std::endl;
+    return kTvmErrorPlatformTimerBadState;
+  }
+  auto microtvm_stop_time = std::chrono::steady_clock::now();
+  std::chrono::microseconds time_span = std::chrono::duration_cast<std::chrono::microseconds>(
+      microtvm_stop_time - g_microtvm_start_time);
+  *elapsed_time_seconds = static_cast<double>(time_span.count()) / 1e6;
+  g_microtvm_timer_running = 0;
+  return kTvmErrorNoError;
+}
+
+// Platform-specific before measurement call.
+tvm_crt_error_t TVMPlatformBeforeMeasurement() { return kTvmErrorNoError; }
+
+// Platform-specific after measurement call.
+tvm_crt_error_t TVMPlatformAfterMeasurement() { return kTvmErrorNoError; }
+
+static_assert(RAND_MAX >= (1 << 8), "RAND_MAX is smaller than acceptable");
+unsigned int random_seed = 0;
+// Fill a buffer with random data.
+tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) {
+  if (random_seed == 0) {
+    random_seed = (unsigned int)time(NULL);
+  }
+  for (size_t i = 0; i < num_bytes; ++i) {
+    int random = rand_r(&random_seed);
+    buffer[i] = (uint8_t)random;
+  }
+  return kTvmErrorNoError;
+}
+
+// Initialize TVM inference.
+tvm_crt_error_t TVMPlatformInitialize() {
+  int status =
+      PageMemoryManagerCreate(&memory_manager, memory, sizeof(memory), 8 /* page_size_log2 */);
+  if (status != 0) {
+    fprintf(stderr, "error initializing memory manager\n");
+    return kTvmErrorPlatformMemoryManagerInitialized;
+  }
+  return kTvmErrorNoError;
+}
+
+}  // extern "C"
diff --git a/cmake/modules/Micro.cmake b/cmake/modules/Micro.cmake
index d887486d2e98..7172d91f1132 100644
--- a/cmake/modules/Micro.cmake
+++ b/cmake/modules/Micro.cmake
@@ -64,6 +64,14 @@ if(USE_MICRO)
         "src/runtime/crt/host CMakeLists.txt.template -> crt"
         "src/runtime/crt/host **.cc -> crt/src"
       )
+    elseif("${platform}" STREQUAL "riscv")
+      list(
+        APPEND
+        PLATFORM_FILE_COPY_JOBS
+        "apps/microtvm/riscv/template_project microtvm_api_server.py -> riscv"
+        "apps/microtvm/riscv/template_project CMakeLists.txt.template -> riscv"
+        "apps/microtvm/riscv/template_project/src/ *.cc -> riscv/src/"
+      )
     else()
       message(FATAL_ERROR "${platform} not supported.")
     endif()
@@ -111,7 +119,7 @@ if(USE_MICRO)
     add_custom_target(${platform} DEPENDS ${platform_template_deps})
   endfunction()
 
-  set(PLATFORMS crt;zephyr;arduino)
+  set(PLATFORMS crt;zephyr;arduino;riscv)
   foreach(platform IN LISTS PLATFORMS)
     message(STATUS "Add ${platform} template project.")
     microtvm_add_platform_project_api(${platform})
diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py
index df7d1fc7196d..c8f9f77665cd 100644
--- a/python/tvm/micro/build.py
+++ b/python/tvm/micro/build.py
@@ -40,6 +40,7 @@ class MicroTVMTemplateProject(enum.Enum):
     ZEPHYR = "zephyr"
     ARDUINO = "arduino"
     CRT = "crt"
+    RISCV = "riscv"
 
     @classmethod
     def list(cls):
diff --git a/tests/python/relay/strategy/riscv_cpu/test_conv2d_int8_nchw.py b/tests/python/relay/strategy/riscv_cpu/test_conv2d_int8_nchw.py
index a5bce822084a..753ea3337e4a 100644
--- a/tests/python/relay/strategy/riscv_cpu/test_conv2d_int8_nchw.py
+++ b/tests/python/relay/strategy/riscv_cpu/test_conv2d_int8_nchw.py
@@ -15,13 +15,29 @@
 # specific language governing permissions and limitations
 # under the License.
+import pathlib
+
 import numpy as np
-import re
+
 import tvm
 import tvm.testing
 from tvm import relay
-from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data
-from tvm.micro.testing.aot_test_utils import AOTTestRunner
+from tvm.relay.backend import Executor, Runtime
+from tvm.testing.aot import generate_ref_data
+
+
+def _make_session(temp_dir, mod):
+    template_project_dir = pathlib.Path(tvm.micro.get_microtvm_template_projects("riscv"))
+    options = {
+        "toolchain_path": "/opt/riscv",
+        "target": "riscv64-unknown-linux-gnu",
+        "march": "rv64gcv",
+        "verbose": True,
+    }
+    project = tvm.micro.generate_project(template_project_dir, mod, temp_dir / "project", options)
+    project.build()
+    project.flash()
+    return tvm.micro.Session(project.transport())
 
 
 class RISCVConv2dInt8:
@@ -36,7 +52,6 @@ def test_conv2d_int8(
         padding,
         dtype,
         wtype,
-        schedule_name,
     ):
 
         weight_shape = (num_filter, data_shape[1], *kernel_size)
@@ -89,33 +104,32 @@ def test_conv2d_int8(
         mod = relay.Function(relay.analysis.free_vars(func), func)
         mod = tvm.IRModule.from_expr(mod)
 
-        target_opts = {
-            "-keys": "riscv_cpu",
-            "-march": "rv64gcv",
-        }
+        temp_dir = tvm.contrib.utils.tempdir()
+        target = "c -keys=riscv_cpu -march=rv64gcv"
+        target = tvm.target.Target(target, host="c")
+        runtime = Runtime("crt", {"system-lib": True})
+        executor = Executor("aot")
+        with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
+            factory = tvm.relay.build(mod, target=target, runtime=runtime, executor=executor)
 
-        def checker(base_path: str) -> bool:
-            def read_file(path):
-                with open(path) as f:
-                    return f.read()
-
-            default_lib1 = read_file(base_path + "/codegen/host/src/default_lib1.c")
-            regex = r"(?s)dot_uint8_int8_int32_update(.*?)"
-            return re.search(regex, default_lib1) is not None
-
-        assert compile_and_run(
-            AOTTestModel(module=mod, inputs=inputs, outputs=output_list),
-            runner=AOTTestRunner(makefile="riscv"),
-            interface_api="c",
-            use_unpacked_api=True,
-            target_opts=target_opts,
-            schedule_name=schedule_name,
-            checker=checker,
-        )
+        def do_test():
+            aot_executor = tvm.micro.create_local_aot_executor(sess)
+            aot_executor.get_input("input").copyfrom(inputs["input"])
+            aot_executor.run()
+
+            out = aot_executor.get_output(0)
+            assert (out.numpy() == output_list["output"]).all()
+
+        with _make_session(temp_dir, factory) as sess:
+            do_test()
 
 
 class TestConv2d_NCHW(RISCVConv2dInt8):
-    (data_shape, kernel_size, num_filter,) = tvm.testing.parameters(
+    (
+        data_shape,
+        kernel_size,
+        num_filter,
+    ) = tvm.testing.parameters(
         ((1, 128, 14, 14), (3, 3), 128),
         ((1, 128, 14, 14), (1, 1), 256),
         ((1, 256, 7, 7), (1, 1), 512),
@@ -127,7 +141,6 @@ class TestConv2d_NCHW(RISCVConv2dInt8):
     kernel_layout = tvm.testing.parameter("OIHW")
     dtype = tvm.testing.parameter("uint8")
     wtype = tvm.testing.parameter("int8")
-    schedule_name = tvm.testing.parameter("conv2d_int8_NCHW.riscv_cpu")
 
 
 if __name__ == "__main__":

From 7b013218a3197cae57f01dda6c7d16ad8efeb038 Mon Sep 17 00:00:00 2001
From: katebern-grovety <132359118+katebern-grovety@users.noreply.github.com>
Date: Wed, 16 Aug 2023 13:17:19 +0400
Subject: [PATCH 08/10] Add select implementation test case

---
 .../strategy/test_select_implementation.py | 47 +++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/tests/python/relay/strategy/test_select_implementation.py b/tests/python/relay/strategy/test_select_implementation.py
index 2bf1548d41d8..239cc88e8430 100644
--- a/tests/python/relay/strategy/test_select_implementation.py
+++ 
b/tests/python/relay/strategy/test_select_implementation.py @@ -104,6 +104,53 @@ def test_int8_conv2d(target, expected_impl): assert impl.name == expected_impl +@pytest.mark.parametrize( + "target,expected_impl", + [ + ( + "c -keys=riscv_cpu -march=rv32gcv", + "conv2d_nchw_int8.riscv", + ), + ( + "c -keys=riscv_cpu -march=rv64gcv", + "conv2d_nchw_int8.riscv", + ), + ], +) +def test_riscv_conv2d(target, expected_impl): + target = tvm.target.Target(target) + + data_type = "uint8" + weight_type = "int8" + data_shape = (1, 128, 1, 1) + weight_shape = (128, 128, 1, 1) + data_layout = "NCHW" + kernel_layout = "OIHW" + channels = 128 + kernel_size = (1, 1) + + out = relay.nn.conv2d( + relay.var("data", shape=data_shape, dtype=data_type), + relay.var("weight", shape=weight_shape, dtype=weight_type), + kernel_size=kernel_size, + channels=channels, + data_layout=data_layout, + kernel_layout=kernel_layout, + ) + out = run_infer_type(out) + + with target: + impl, _ = relay.backend.te_compiler.select_implementation( + out.op, + out.attrs, + [te.placeholder(data_shape, data_type), te.placeholder(weight_shape, weight_type)], + out.checked_type, + target, + ) + + assert impl.name == expected_impl + + @pytest.mark.parametrize( "target,expected_impl", [ From f7fb4251164f735d37f1f021fa728a34b02b389e Mon Sep 17 00:00:00 2001 From: katebern-grovety <132359118+katebern-grovety@users.noreply.github.com> Date: Wed, 16 Aug 2023 13:24:31 +0400 Subject: [PATCH 09/10] Delete riscv.mk --- tests/python/relay/aot/riscv.mk | 85 --------------------------------- 1 file changed, 85 deletions(-) delete mode 100644 tests/python/relay/aot/riscv.mk diff --git a/tests/python/relay/aot/riscv.mk b/tests/python/relay/aot/riscv.mk deleted file mode 100644 index 9282760df7a7..000000000000 --- a/tests/python/relay/aot/riscv.mk +++ /dev/null @@ -1,85 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -AOT_ROOT ?= $(CRT_ROOT)/aot - -ENABLE_TVM_PLATFORM_ABORT_BACKTRACE = 0 -DMLC_CORE=$(TVM_ROOT)/3rdparty/dmlc-core -TOOLCHAIN_PATH=$(shell dirname $(shell which riscv64-unknown-linux-gnu-gcc))/.. 
- -CC = clang-16 -CXX = clang++-16 - -TARGET_CFLAGS = --target=riscv64-unknown-linux-gnu -march=rv64gcv -static -RUNNER = spike -RUNNER_OPT = --isa=rv64gcv $(shell which pk) - -PKG_CFLAGS = ${PKG_COMPILE_OPTS} ${TARGET_CFLAGS} -O2 \ - -I$(build_dir)/../include \ - -I$(CODEGEN_ROOT)/host/include \ - -include $(CODEGEN_ROOT)/host/include/tvmgen_default.h \ - -isystem$(STANDALONE_CRT_DIR)/include \ - --sysroot=$(TOOLCHAIN_PATH)/sysroot \ - --gcc-toolchain=$(TOOLCHAIN_PATH) - -$(ifeq VERBOSE,1) -QUIET ?= -$(else) -QUIET ?= @ -$(endif) - -aot_test_runner: $(build_dir)/aot_test_runner - -c_source_libs = $(wildcard $(build_dir)/../codegen/host/src/*.c) -cc_source_libs = $(wildcard $(build_dir)/../codegen/host/src/*.cc) -c_lib_objs = $(addprefix $(build_dir)/, $(notdir $(c_source_libs:.c=.o))) -cc_lib_objs = $(cc_source_libs:.cc=.o) - -$(build_dir)/aot_test_runner: $(build_dir)/test.c $(c_lib_objs) $(cc_lib_objs) $(build_dir)/stack_allocator.o $(build_dir)/crt_backend_api.o - $(QUIET)mkdir -p $(@D) - $(QUIET)$(CC) $(CFLAGS) $(PKG_CFLAGS) -o $@ $^ $(PKG_LDFLAGS) $(BACKTRACE_LDFLAGS) $(BACKTRACE_CFLAGS) -lm - -$(build_dir)/%.o: $(build_dir)/../codegen/host/src/%.c - $(QUIET)mkdir -p $(@D) - $(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $< $(BACKTRACE_CFLAGS) - -$(build_dir)/%.o: $(build_dir)/../codegen/host/src/%.cc - $(QUIET)mkdir -p $(@D) - $(QUIET)$(CXX) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $< $(BACKTRACE_CFLAGS) - -$(build_dir)/stack_allocator.o: $(STANDALONE_CRT_DIR)/src/runtime/crt/memory/stack_allocator.c - $(QUIET)mkdir -p $(@D) - $(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS) - -$(build_dir)/crt_backend_api.o: $(STANDALONE_CRT_DIR)/src/runtime/crt/common/crt_backend_api.c - $(QUIET)mkdir -p $(@D) - $(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS) - -clean: - $(QUIET)rm -rf $(build_dir)/crt -cleanall: - $(QUIET)rm -rf $(build_dir) - -run: $(build_dir)/aot_test_runner - $(RUNNER) $(RUNNER_OPT) $(build_dir)/aot_test_runner - -# # Don't define implicit rules; they tend to match on logical target names that aren't targets (i.e. 
bundle_static)
-.SUFFIXES:
-
-.DEFAULT: aot_test_runner
-
-.PHONY: run

From fe48a1b10df6afbd6478db8480ddd0ac1f2834ff Mon Sep 17 00:00:00 2001
From: katebern-grovety <132359118+katebern-grovety@users.noreply.github.com>
Date: Mon, 21 Aug 2023 10:42:18 +0400
Subject: [PATCH 10/10] move Project API changes to new PR

---
 CMakeLists.txt                                |   1 -
 apps/microtvm/riscv/ReadMe.md                 |  18 --
 .../template_project/CMakeLists.txt.template  |  64 ----
 .../template_project/microtvm_api_server.py   | 285 ------------------
 .../riscv/template_project/src/main.cc        | 106 -------
 .../riscv/template_project/src/platform.cc    | 126 --------
 cmake/modules/Micro.cmake                     |  10 +-
 python/tvm/micro/build.py                     |   1 -
 .../riscv_cpu/test_conv2d_int8_nchw.py        |   6 +-
 9 files changed, 2 insertions(+), 615 deletions(-)
 delete mode 100644 apps/microtvm/riscv/ReadMe.md
 delete mode 100644 apps/microtvm/riscv/template_project/CMakeLists.txt.template
 delete mode 100644 apps/microtvm/riscv/template_project/microtvm_api_server.py
 delete mode 100644 apps/microtvm/riscv/template_project/src/main.cc
 delete mode 100644 apps/microtvm/riscv/template_project/src/platform.cc

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5d3e9761076a..67e87d907141 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -590,7 +590,6 @@ if(USE_MICRO)
     add_dependencies(tvm_runtime arduino)
     add_dependencies(tvm_runtime crt)
     add_dependencies(tvm_runtime host_standalone_crt)
-    add_dependencies(tvm_runtime riscv)
     add_dependencies(tvm_runtime zephyr)
   endif()
 
diff --git a/apps/microtvm/riscv/ReadMe.md b/apps/microtvm/riscv/ReadMe.md
deleted file mode 100644
index 5bffc9218a93..000000000000
--- a/apps/microtvm/riscv/ReadMe.md
+++ /dev/null
@@ -1,18 +0,0 @@
-<!--- Licensed to the Apache Software Foundation (ASF) under one -->
-<!--- or more contributor license agreements. See the NOTICE file -->
-<!--- distributed with this work for additional information -->
-<!--- regarding copyright ownership. The ASF licenses this file -->
-<!--- to you under the Apache License, Version 2.0 (the -->
-<!--- "License"); you may not use this file except in compliance -->
-<!--- with the License. You may obtain a copy of the License at -->
-
-<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
-
-<!--- Unless required by applicable law or agreed to in writing, -->
-<!--- software distributed under the License is distributed on an -->
-<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
-<!--- KIND, either express or implied. See the License for the -->
-<!--- specific language governing permissions and limitations -->
-<!--- under the License. -->
-
-This directory contains code to interface microTVM with [RISC-V](https://riscv.org/).
diff --git a/apps/microtvm/riscv/template_project/CMakeLists.txt.template b/apps/microtvm/riscv/template_project/CMakeLists.txt.template
deleted file mode 100644
index 6a25b4e9a5ee..000000000000
--- a/apps/microtvm/riscv/template_project/CMakeLists.txt.template
+++ /dev/null
@@ -1,64 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
- -# SPDX-License-Identifier: Apache-2.0 - -cmake_minimum_required(VERSION 3.18) -set(CMAKE_CXX_STANDARD 11) -set(CMAKE_C_COMPILER "clang") -set(CMAKE_CXX_COMPILER "clang++") -set(FLAGS "--target=${TARGET} --sysroot=${TOOLCHAIN_PATH}/sysroot --gcc-toolchain=${TOOLCHAIN_PATH} -march=${MARCH} -static") -set(CMAKE_C_FLAGS ${FLAGS}) -set(CMAKE_CXX_FLAGS ${FLAGS}) - -project(crt_autogenerated_project C CXX) -add_executable(main) - -set(CRT_LIB_BASE crt/src/runtime/crt) -set(CRT_LIBS microtvm_rpc_server - microtvm_rpc_common - aot_executor_module - aot_executor - graph_executor_module - graph_executor - common - memory -) - -# Build CRT libraries -foreach(crt_lib_name ${CRT_LIBS}) - add_library(${crt_lib_name}) - file(GLOB_RECURSE crt_lib_srcs ${CRT_LIB_BASE}/${crt_lib_name}/*.c ${CRT_LIB_BASE}/${crt_lib_name}/*.cc) - target_sources(${crt_lib_name} PRIVATE ${crt_lib_srcs}) - target_include_directories(${crt_lib_name} PRIVATE crt_config crt/include) - target_compile_definitions(${crt_lib_name} PRIVATE -DTVM_HOST_USE_GRAPH_EXECUTOR_MODULE) - target_link_libraries(main PRIVATE ${crt_lib_name}) -endforeach(crt_lib_name ${CRT_LIBS}) - -# Build model files -add_library(tvm_model) -file(GLOB_RECURSE tvm_model_srcs model/codegen/host/src/*.c model/codegen/host/lib/*.o) -target_sources(tvm_model PRIVATE ${tvm_model_srcs}) -target_include_directories(tvm_model PRIVATE ${CMAKE_SOURCE_DIR}/include crt_config crt/include) -target_compile_options(tvm_model PRIVATE -Wno-error=unused-variable -Wno-error=missing-braces -Wno-error=unused-const-variable -Wno-unused-variable) -set_target_properties(tvm_model PROPERTIES LINKER_LANGUAGE C) -target_link_libraries(main PRIVATE tvm_model) - -file(GLOB_RECURSE app_srcs src/**.cc) -target_sources(main PRIVATE ${app_srcs} ${cmsis_lib_srcs}) -target_compile_definitions(main PRIVATE -DTVM_HOST_USE_GRAPH_EXECUTOR_MODULE) -target_include_directories(main PRIVATE crt_config include ${CMAKE_SOURCE_DIR}/include crt/include) diff --git a/apps/microtvm/riscv/template_project/microtvm_api_server.py b/apps/microtvm/riscv/template_project/microtvm_api_server.py deleted file mode 100644 index 1495073d8583..000000000000 --- a/apps/microtvm/riscv/template_project/microtvm_api_server.py +++ /dev/null @@ -1,285 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import fcntl -import os -import os.path -import pathlib -import select -import shutil -import subprocess -import tarfile -import time -import re - -from tvm.micro.project_api import server - - -PROJECT_DIR = pathlib.Path(os.path.dirname(__file__) or os.path.getcwd()) - - -MODEL_LIBRARY_FORMAT_RELPATH = "model.tar" - - -IS_TEMPLATE = not os.path.exists(os.path.join(PROJECT_DIR, MODEL_LIBRARY_FORMAT_RELPATH)) - -# Used this size to pass most CRT tests in TVM. 
-WORKSPACE_SIZE_BYTES = 2 * 1024 * 1024 - -CMAKEFILE_FILENAME = "CMakeLists.txt" - -# The build target given to make -BUILD_TARGET = "build/main" - - -class Handler(server.ProjectAPIHandler): - BUILD_TARGET = "build/main" - - def __init__(self): - super(Handler, self).__init__() - self._proc = None - - def server_info_query(self, tvm_version): - return server.ServerInfo( - platform_name="host", - is_template=IS_TEMPLATE, - model_library_format_path="" - if IS_TEMPLATE - else PROJECT_DIR / MODEL_LIBRARY_FORMAT_RELPATH, - project_options=[ - server.ProjectOption( - "verbose", - optional=["build"], - type="bool", - default=False, - help="Run make with verbose output", - ), - server.ProjectOption( - "workspace_size_bytes", - optional=["generate_project"], - type="int", - default=WORKSPACE_SIZE_BYTES, - help="Sets the value of TVM_WORKSPACE_SIZE_BYTES.", - ), - server.ProjectOption( - "toolchain_path", - optional=["generate_project"], - type="str", - default="/opt/riscv", - help="Sets the value of toolchain path.", - ), - server.ProjectOption( - "target", - optional=["generate_project"], - type="str", - default="riscv64-unknown-linux-gnu", - help="Sets the value of target.", - ), - server.ProjectOption( - "march", - optional=["generate_project"], - type="str", - default="rv64gc", - help="Sets the value of target architecture.", - ), - ], - ) - - # These files and directories will be recursively copied into generated projects from the CRT. - CRT_COPY_ITEMS = ("include", "CMakeLists.txt", "src") - - def _populate_cmake( - self, - cmakefile_template_path: pathlib.Path, - cmakefile_path: pathlib.Path, - memory_size: int, - verbose: bool, - ): - """Generate CMakeList file from template.""" - - regex = re.compile(r"([A-Z_]+) := (<[A-Z_]+>)") - with open(cmakefile_path, "w") as cmakefile_f: - with open(cmakefile_template_path, "r") as cmakefile_template_f: - for line in cmakefile_template_f: - cmakefile_f.write(line) - cmakefile_f.write( - f"target_compile_definitions(main PUBLIC -DTVM_WORKSPACE_SIZE_BYTES={memory_size})\n" - ) - if verbose: - cmakefile_f.write(f"set(CMAKE_VERBOSE_MAKEFILE TRUE)\n") - - def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): - # Make project directory. - project_dir.mkdir(parents=True) - current_dir = pathlib.Path(__file__).parent.absolute() - - # Copy ourselves to the generated project. TVM may perform further build steps on the generated project - # by launching the copy. - shutil.copy2(__file__, project_dir / os.path.basename(__file__)) - - # Place Model Library Format tarball in the special location, which this script uses to decide - # whether it's being invoked in a template or generated project. - project_model_library_format_path = project_dir / MODEL_LIBRARY_FORMAT_RELPATH - shutil.copy2(model_library_format_path, project_model_library_format_path) - - # Extract Model Library Format tarball.into /model. - extract_path = project_dir / project_model_library_format_path.stem - with tarfile.TarFile(project_model_library_format_path) as tf: - os.makedirs(extract_path) - tf.extractall(path=extract_path) - - # Populate CRT. 
- crt_path = project_dir / "crt" - os.mkdir(crt_path) - for item in self.CRT_COPY_ITEMS: - src_path = standalone_crt_dir / item - dst_path = crt_path / item - if os.path.isdir(src_path): - shutil.copytree(src_path, dst_path) - else: - shutil.copy2(src_path, dst_path) - - # Populate CMake file - self._populate_cmake( - current_dir / f"{CMAKEFILE_FILENAME}.template", - project_dir / CMAKEFILE_FILENAME, - options.get("workspace_size_bytes", WORKSPACE_SIZE_BYTES), - options.get("verbose"), - ) - - # Populate crt-config.h - crt_config_dir = project_dir / "crt_config" - crt_config_dir.mkdir() - shutil.copy2( - current_dir / "crt_config" / "crt_config.h", - crt_config_dir / "crt_config.h", - ) - - # Populate src/ - src_dir = project_dir / "src" - src_dir.mkdir() - shutil.copy2( - current_dir / "src" / "main.cc", - src_dir / "main.cc", - ) - shutil.copy2( - current_dir / "src" / "platform.cc", - src_dir / "platform.cc", - ) - - def build(self, options): - cmake_args = [] - toolchain_path = options.get("toolchain_path") - target = options.get("target") - march = options.get("march") - cmake_args.append(f"-DTOOLCHAIN_PATH={toolchain_path}") - cmake_args.append(f"-DTARGET={target}") - cmake_args.append(f"-DMARCH={march}") - - build_dir = PROJECT_DIR / "build" - build_dir.mkdir() - subprocess.check_call(["cmake", *cmake_args, ".."], cwd=build_dir) - subprocess.check_call(["make"], cwd=build_dir) - - def flash(self, options): - pass # Flashing does nothing on host. - - def _set_nonblock(self, fd): - flag = fcntl.fcntl(fd, fcntl.F_GETFL) - fcntl.fcntl(fd, fcntl.F_SETFL, flag | os.O_NONBLOCK) - new_flag = fcntl.fcntl(fd, fcntl.F_GETFL) - assert (new_flag & os.O_NONBLOCK) != 0, "Cannot set file descriptor {fd} to non-blocking" - - def open_transport(self, options): - toolchain_path = options.get("toolchain_path") - target = options.get("target") - march = options.get("march") - spike_args = ["spike", f"--isa={march}", os.path.join(toolchain_path, target, "bin/pk")] - self._proc = subprocess.Popen( - spike_args + [self.BUILD_TARGET], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - bufsize=0, - ) - self._set_nonblock(self._proc.stdin.fileno()) - self._set_nonblock(self._proc.stdout.fileno()) - return server.TransportTimeouts( - session_start_retry_timeout_sec=0, - session_start_timeout_sec=0, - session_established_timeout_sec=0, - ) - - def close_transport(self): - if self._proc is not None: - proc = self._proc - self._proc = None - proc.kill() - proc.wait() - - def _await_ready(self, rlist, wlist, timeout_sec=None, end_time=None): - if timeout_sec is None and end_time is not None: - timeout_sec = max(0, end_time - time.monotonic()) - - rlist, wlist, xlist = select.select(rlist, wlist, rlist + wlist, timeout_sec) - if not rlist and not wlist and not xlist: - raise server.IoTimeoutError() - - return True - - def read_transport(self, n, timeout_sec): - if self._proc is None: - raise server.TransportClosedError() - - fd = self._proc.stdout.fileno() - end_time = None if timeout_sec is None else time.monotonic() + timeout_sec - - try: - self._await_ready([fd], [], end_time=end_time) - to_return = os.read(fd, n) - except BrokenPipeError: - to_return = 0 - - if not to_return: - self.close_transport() - raise server.TransportClosedError() - - return to_return - - def write_transport(self, data, timeout_sec): - if self._proc is None: - raise server.TransportClosedError() - - fd = self._proc.stdin.fileno() - end_time = None if timeout_sec is None else time.monotonic() + timeout_sec - - data_len = len(data) - 
while data:
-            self._await_ready([], [fd], end_time=end_time)
-            try:
-                num_written = os.write(fd, data)
-            except BrokenPipeError:
-                num_written = 0
-
-            if not num_written:
-                self.close_transport()
-                raise server.TransportClosedError()
-
-            data = data[num_written:]
-
-
-if __name__ == "__main__":
-    server.main(Handler())
diff --git a/apps/microtvm/riscv/template_project/src/main.cc b/apps/microtvm/riscv/template_project/src/main.cc
deleted file mode 100644
index 0607d4b28719..000000000000
--- a/apps/microtvm/riscv/template_project/src/main.cc
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file main.cc
- * \brief main entry point for host subprocess-based CRT
- */
-#include <inttypes.h>
-#include <stdarg.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <string.h>
-#include <unistd.h>
-
-#include <tvm/runtime/crt/microtvm_rpc_server.h>
-
-#include "crt_config.h"
-
-#ifdef TVM_HOST_USE_GRAPH_EXECUTOR_MODULE
-#include <tvm/runtime/crt/graph_executor_module.h>
-#endif
-
-extern "C" {
-
-ssize_t MicroTVMWriteFunc(void* context, const uint8_t* data, size_t num_bytes) {
-  ssize_t to_return = write(STDOUT_FILENO, data, num_bytes);
-  fflush(stdout);
-  fsync(STDOUT_FILENO);
-  return to_return;
-}
-}
-
-static char** g_argv = NULL;
-
-int testonly_reset_server(TVMValue* args, int* type_codes, int num_args, TVMValue* out_ret_value,
-                          int* out_ret_tcode, void* resource_handle) {
-  execvp(g_argv[0], g_argv);
-  perror("microTVM runtime: error restarting");
-  return -1;
-}
-
-int main(int argc, char** argv) {
-  g_argv = argv;
-  TVMPlatformInitialize();
-  microtvm_rpc_server_t rpc_server = MicroTVMRpcServerInit(&MicroTVMWriteFunc, nullptr);
-
-#ifdef TVM_HOST_USE_GRAPH_EXECUTOR_MODULE
-  CHECK_EQ(TVMGraphExecutorModule_Register(), kTvmErrorNoError,
-           "failed to register GraphExecutor TVMModule");
-#endif
-
-  int error = TVMFuncRegisterGlobal("tvm.testing.reset_server",
-                                    (TVMFunctionHandle)&testonly_reset_server, 0);
-  if (error) {
-    fprintf(
-        stderr,
-        "microTVM runtime: internal error (error#: %x) registering global packedfunc; exiting\n",
-        error);
-    return 2;
-  }
-
-  setbuf(stdin, NULL);
-  setbuf(stdout, NULL);
-
-  for (;;) {
-    uint8_t c;
-    int ret_code = read(STDIN_FILENO, &c, 1);
-    if (ret_code < 0) {
-      perror("microTVM runtime: read failed");
-      return 2;
-    } else if (ret_code == 0) {
-      fprintf(stderr, "microTVM runtime: 0-length read, exiting!\n");
-      return 2;
-    }
-    uint8_t* cursor = &c;
-    size_t bytes_to_process = 1;
-    while (bytes_to_process > 0) {
-      tvm_crt_error_t err = MicroTVMRpcServerLoop(rpc_server, &cursor, &bytes_to_process);
-      if (err == kTvmErrorPlatformShutdown) {
-        break;
-      } else if (err != kTvmErrorNoError) {
-        char buf[1024];
-        snprintf(buf, sizeof(buf), "microTVM runtime: MicroTVMRpcServerLoop error: %08x", err);
-        perror(buf);
-        return 2;
-      }
-    }
-  }
-  return 0;
-}
diff --git a/apps/microtvm/riscv/template_project/src/platform.cc b/apps/microtvm/riscv/template_project/src/platform.cc
deleted file mode 100644
index f5af08a9be88..000000000000
--- a/apps/microtvm/riscv/template_project/src/platform.cc
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Implementation of TVMPlatform functions in tvm/runtime/crt/platform.h
- */
-
-#include <dlpack/dlpack.h>
-#include <stdarg.h>
-#include <stdint.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <time.h>
-#include <tvm/runtime/crt/error_codes.h>
-#include <tvm/runtime/crt/logging.h>
-#include <tvm/runtime/crt/page_allocator.h>
-#include <tvm/runtime/crt/platform.h>
-
-#include <chrono>
-#include <iostream>
-
-using namespace std::chrono;
-
-extern "C" {
-
-uint8_t memory[TVM_WORKSPACE_SIZE_BYTES];
-MemoryManagerInterface* memory_manager;
-
-steady_clock::time_point g_microtvm_start_time;
-int g_microtvm_timer_running = 0;
-
-// Called when an internal error occurs and execution cannot continue.
-void TVMPlatformAbort(tvm_crt_error_t error_code) {
-  std::cerr << "TVMPlatformAbort: " << error_code << std::endl;
-  throw "Aborted";
-}
-
-// Called by the microTVM RPC server to implement TVMLogf.
-size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt,
-                                va_list args) {
-  return vsnprintf(out_buf, out_buf_size_bytes, fmt, args);
-}
-
-// Allocate memory for use by TVM.
-tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) {
-  return memory_manager->Allocate(memory_manager, num_bytes, dev, out_ptr);
-}
-
-// Free memory used by TVM.
-tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) {
-  return memory_manager->Free(memory_manager, ptr, dev);
-}
-
-// Start a device timer.
-tvm_crt_error_t TVMPlatformTimerStart() {
-  if (g_microtvm_timer_running) {
-    std::cerr << "timer already running" << std::endl;
-    return kTvmErrorPlatformTimerBadState;
-  }
-  g_microtvm_start_time = std::chrono::steady_clock::now();
-  g_microtvm_timer_running = 1;
-  return kTvmErrorNoError;
-}
-
-// Stop the running device timer and get the elapsed time (in seconds).
-tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
-  if (!g_microtvm_timer_running) {
-    std::cerr << "timer not running" << std::endl;
-    return kTvmErrorPlatformTimerBadState;
-  }
-  auto microtvm_stop_time = std::chrono::steady_clock::now();
-  std::chrono::microseconds time_span = std::chrono::duration_cast<std::chrono::microseconds>(
-      microtvm_stop_time - g_microtvm_start_time);
-  *elapsed_time_seconds = static_cast<double>(time_span.count()) / 1e6;
-  g_microtvm_timer_running = 0;
-  return kTvmErrorNoError;
-}
-
-// Platform-specific before measurement call.
-tvm_crt_error_t TVMPlatformBeforeMeasurement() { return kTvmErrorNoError; }
-
-// Platform-specific after measurement call.
-tvm_crt_error_t TVMPlatformAfterMeasurement() { return kTvmErrorNoError; }
-
-static_assert(RAND_MAX >= (1 << 8), "RAND_MAX is smaller than acceptable");
-unsigned int random_seed = 0;
-// Fill a buffer with random data.
-tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) {
-  if (random_seed == 0) {
-    random_seed = (unsigned int)time(NULL);
-  }
-  for (size_t i = 0; i < num_bytes; ++i) {
-    int random = rand_r(&random_seed);
-    buffer[i] = (uint8_t)random;
-  }
-  return kTvmErrorNoError;
-}
-
-// Initialize TVM inference.
-tvm_crt_error_t TVMPlatformInitialize() {
-  int status =
-      PageMemoryManagerCreate(&memory_manager, memory, sizeof(memory), 8 /* page_size_log2 */);
-  if (status != 0) {
-    fprintf(stderr, "error initializing memory manager\n");
-    return kTvmErrorPlatformMemoryManagerInitialized;
-  }
-  return kTvmErrorNoError;
-}
-
-}  // extern "C"
diff --git a/cmake/modules/Micro.cmake b/cmake/modules/Micro.cmake
index 7172d91f1132..d887486d2e98 100644
--- a/cmake/modules/Micro.cmake
+++ b/cmake/modules/Micro.cmake
@@ -64,14 +64,6 @@ if(USE_MICRO)
         "src/runtime/crt/host CMakeLists.txt.template -> crt"
         "src/runtime/crt/host **.cc -> crt/src"
       )
-    elseif("${platform}" STREQUAL "riscv")
-      list(
-        APPEND
-        PLATFORM_FILE_COPY_JOBS
-        "apps/microtvm/riscv/template_project microtvm_api_server.py -> riscv"
-        "apps/microtvm/riscv/template_project CMakeLists.txt.template -> riscv"
-        "apps/microtvm/riscv/template_project/src/ *.cc -> riscv/src/"
-      )
     else()
       message(FATAL_ERROR "${platform} not supported.")
     endif()
@@ -119,7 +111,7 @@ if(USE_MICRO)
     add_custom_target(${platform} DEPENDS ${platform_template_deps})
   endfunction()
 
-  set(PLATFORMS crt;zephyr;arduino;riscv)
+  set(PLATFORMS crt;zephyr;arduino)
   foreach(platform IN LISTS PLATFORMS)
     message(STATUS "Add ${platform} template project.")
     microtvm_add_platform_project_api(${platform})
diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py
index c8f9f77665cd..df7d1fc7196d 100644
--- a/python/tvm/micro/build.py
+++ b/python/tvm/micro/build.py
@@ -40,7 +40,6 @@ class MicroTVMTemplateProject(enum.Enum):
     ZEPHYR = "zephyr"
     ARDUINO = "arduino"
     CRT = "crt"
-    RISCV = "riscv"
 
     @classmethod
     def list(cls):
diff --git a/tests/python/relay/strategy/riscv_cpu/test_conv2d_int8_nchw.py b/tests/python/relay/strategy/riscv_cpu/test_conv2d_int8_nchw.py
index 753ea3337e4a..acfa89281271 100644
--- a/tests/python/relay/strategy/riscv_cpu/test_conv2d_int8_nchw.py
+++ b/tests/python/relay/strategy/riscv_cpu/test_conv2d_int8_nchw.py
@@ -125,11 +125,7 @@ def do_test():
 
 
 class TestConv2d_NCHW(RISCVConv2dInt8):
-    (
-        data_shape,
-        kernel_size,
-        num_filter,
-    ) = tvm.testing.parameters(
+    (data_shape, kernel_size, num_filter,) = tvm.testing.parameters(
         ((1, 128, 14, 14), (3, 3), 128),
         ((1, 128, 14, 14), (1, 1), 256),
         ((1, 256, 7, 7), (1, 1), 512),
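
---

Note on exercising the dispatch rule: the selection logic that patch 08 tests can also be checked interactively. The following is a minimal sketch, not part of the series, assuming a TVM build with these patches applied; run_infer_type is the same helper used by tests/python/relay/strategy/test_select_implementation.py, and the shapes come from the test matrix above (the RVV int8 path requires uint8 data, int8 weights, and at least 128 input channels).

    import tvm
    from tvm import relay, te
    from tvm.relay.testing import run_infer_type

    # Shapes from the test matrix; in-channels must be >= 128 for the RVV path.
    data_shape = (1, 128, 14, 14)
    weight_shape = (128, 128, 3, 3)

    out = relay.nn.conv2d(
        relay.var("data", shape=data_shape, dtype="uint8"),
        relay.var("weight", shape=weight_shape, dtype="int8"),
        kernel_size=(3, 3),
        channels=128,
        data_layout="NCHW",
        kernel_layout="OIHW",
    )
    out = run_infer_type(out)

    target = tvm.target.Target("c -keys=riscv_cpu -march=rv64gcv")
    with target:
        impl, _ = relay.backend.te_compiler.select_implementation(
            out.op,
            out.attrs,
            [te.placeholder(data_shape, "uint8"), te.placeholder(weight_shape, "int8")],
            out.checked_type,
            target,
        )

    # conv2d_nchw_int8.riscv is registered with plevel=15, so it outranks the
    # x86 fallback whenever is_int8_hw_support() accepts the dtype combination.
    assert impl.name == "conv2d_nchw_int8.riscv"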