Skip to content
Closed
1 change: 1 addition & 0 deletions python/tvm/relay/op/strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@
from . import intel_graphics
from . import hexagon
from . import adreno
from . import riscv_cpu
95 changes: 95 additions & 0 deletions python/tvm/relay/op/strategy/riscv_cpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
"""Definition of RISCV CPU operator strategy."""

import logging

from tvm import topi
from .. import op as _op
from .generic import *
from .x86 import conv2d_strategy_cpu

logger = logging.getLogger("strategy")


@schedule_injective.register("riscv_cpu")
def schedule_injective_riscv_cpu(_, outs, target):
"""schedule injective ops for riscv_cpu"""
with target:
return topi.riscv_cpu.schedule_injective(outs)


@schedule_reduce.register("riscv_cpu")
def schedule_reduce_riscv_cpu(_, outs, target):
"""schedule reduction ops for riscv_cpu"""
with target:
return topi.x86.schedule_reduce(outs)


@conv2d_strategy.register("riscv_cpu")
def conv2d_strategy_riscv_cpu(attrs, inputs, out_type, target):
"""conv2d riscv_cpu strategy"""
strategy = _op.OpStrategy()
data, kernel = inputs
dilation_h, dilation_w = attrs.get_int_tuple("dilation")
groups = attrs.groups
layout = attrs.data_layout
kernel_layout = attrs.kernel_layout
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")

if groups == 1:
if layout == "NCHW":
assert kernel_layout == "OIHW"
is_int8 = topi.riscv_cpu.is_int8_hw_support(data.dtype, kernel.dtype)
# Vector instructions with int8 show more performance at a larger size.
if is_int8 and kernel.shape[1] >= 128:
strategy.add_implementation(
wrap_compute_conv2d(topi.riscv_cpu.conv2d_nchw_int8),
wrap_topi_schedule(topi.riscv_cpu.schedule_conv2d_nchw_int8),
name="conv2d_nchw_int8.riscv",
plevel=15,
)
return strategy

return conv2d_strategy_cpu(attrs, inputs, out_type, target)


@conv2d_NCHWc_strategy.register("riscv_cpu")
def conv2d_NCHWc_strategy_riscv_cpu(attrs, inputs, out_type, target):
"""conv2d_NCHWc adopted from x86"""
strategy = _op.OpStrategy()
data, kernel = inputs
is_int8 = topi.riscv_cpu.is_int8_hw_support(data.dtype, kernel.dtype)
# Vector instructions with int8 show more performance at a larger size.
if is_int8 and kernel.shape[1] >= 128:
strategy.add_implementation(
wrap_compute_conv2d(
topi.riscv_cpu.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True
),
wrap_topi_schedule(topi.riscv_cpu.schedule_conv2d_NCHWc_int8),
name="conv2d_NCHWc_int8.riscv_cpu",
)
else:
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.x86",
)
return strategy
14 changes: 14 additions & 0 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,20 @@ def _corstone300_compile_time_check():
requires_vitis_ai = Feature("vitis_ai", "Vitis AI", cmake_flag="USE_VITIS_AI")


def _riscv_spike_run_time_check():
if shutil.which("spike") is None:
return "Spike RISC-V ISA Simulator unavailable"
return True


# Mark a test as requiring Spike to run
requires_riscv_spike = Feature(
"spike",
"Spike RISC-V ISA Simulator",
run_time_check=_riscv_spike_run_time_check,
)


def _arm_dot_supported():
arch = platform.machine()

Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from . import random
from . import hexagon
from . import adreno
from . import riscv_cpu

# error reporting
from .utils import InvalidShapeError
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/topi/riscv_cpu/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=redefined-builtin, wildcard-import
"""RISCV specific declaration and schedules."""

from .injective import *
from .conv2d_int8 import *
Loading