Skip to content

Commit

Permalink
Add Arm(R) Ethos(TM)-U codegen support on tvmc
Browse files Browse the repository at this point in the history
*change vela imports to lazy imports

Change-Id: Id300953ea0ce252730ef97d5db7e96147640ca1c
  • Loading branch information
manupak committed Aug 27, 2021
1 parent 9a8204e commit 2a8aef1
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 23 deletions.
26 changes: 25 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from enum import auto
from enum import Enum
import numpy as np
import ethosu.vela.api as vapi

import tvm
from tvm.tir import stmt_functor
Expand Down Expand Up @@ -83,6 +82,7 @@ def translate(tir_module, params):
scratch_size : int
The size of the scratch buffer needed.
"""
import ethosu.vela.api as vapi

buffer_info = extract_buffer_info(tir_module, params)
extern_calls = extract_extern_calls(tir_module)
Expand Down Expand Up @@ -215,20 +215,26 @@ def replace_npu_fm_with_address(npu_fm):
return npu_fm

def replace_npu_address_range_with_address(npu_addr_range):
import ethosu.vela.api as vapi

assert isinstance(npu_addr_range.address, tvm.tir.Load)
buffer = npu_addr_range.address.buffer_var
assert buffer in buffer_addresses.keys()
address, buffer_type = buffer_addresses[buffer]
return vapi.NpuAddressRange(_REGION_MAP[buffer_type], address, npu_addr_range.length)

def replace_tir_loads(npu_object):
import ethosu.vela.api as vapi

if isinstance(npu_object, vapi.NpuFeatureMap):
return replace_npu_fm_with_address(npu_object)
if isinstance(npu_object, vapi.NpuAddressRange):
return replace_npu_address_range_with_address(npu_object)
return npu_object

def classify_io(buffer):
import ethosu.vela.api as vapi

for _npu_op in npu_ops:
if issubclass(type(_npu_op), vapi.NpuBlockOperation):
if _npu_op.ifm and _npu_op.ifm.tiles.addresses[0].buffer_var == buffer:
Expand Down Expand Up @@ -374,6 +380,8 @@ def _create_npu_op_conv2d(serial_2d_convolution):
"""This is a helper function to capture a list
of arguments to create Vela NpuConv2DOperation object
"""
import ethosu.vela.api as vapi

npu_conv2d_op = vapi.NpuConv2DOperation()
npu_conv2d_op.ifm = _create_npu_feature_map(serial_2d_convolution.ifm)
npu_conv2d_op.ofm = _create_npu_feature_map(serial_2d_convolution.ofm)
Expand Down Expand Up @@ -412,6 +420,8 @@ def _create_npu_feature_map(serial_feature_map):
"""This is a helper function to capture a list
of arguments to create Vela NpuFeatureMap object
"""
import ethosu.vela.api as vapi

layout_map = {"NHWC": vapi.NpuLayout.NHWC, "NHCWB16": vapi.NpuLayout.NHCWB16}
datatype_map = {
"uint8": vapi.NpuDataType.UINT8,
Expand Down Expand Up @@ -458,6 +468,8 @@ def _create_npu_kernel(serial_kernel):
"""This is a helper function to capture a list
of arguments to create Vela NpuKernel object
"""
import ethosu.vela.api as vapi

nknl = vapi.NpuKernel(
w=int(serial_kernel.width.value),
h=int(serial_kernel.height.value),
Expand All @@ -473,6 +485,8 @@ def _create_npu_address_range(serial_address_range):
"""This is a helper function to capture a list
of arguments to create Vela NpuAddressRange object
"""
import ethosu.vela.api as vapi

addr_range = vapi.NpuAddressRange(
# region will be updated later
region=0,
Expand All @@ -489,6 +503,8 @@ def _create_npu_quantization(
"""This is a helper function to capture a list
of arguments to create Vela NpuQuantization object
"""
import ethosu.vela.api as vapi

# Scale could be an ndarray if per-channel quantization is available
if not isinstance(scale, tvm.tir.expr.Load):
if isinstance(scale.value, float):
Expand All @@ -510,6 +526,8 @@ def _create_npu_weights_zero_point(
def _create_npu_padding(serial_padding):
"""This is a helper function to capture a list
of arguments to create Vela NpuPadding object"""
import ethosu.vela.api as vapi

padding = vapi.NpuPadding(
top=int(serial_padding.top.value),
left=int(serial_padding.left.value),
Expand All @@ -522,6 +540,8 @@ def _create_npu_padding(serial_padding):
def _create_npu_activation(serial_activation):
"""This is a helper function to capture a list
of arguments to create Vela NpuActivation object"""
import ethosu.vela.api as vapi

if serial_activation.op == "NONE":
return None
if (
Expand All @@ -548,6 +568,8 @@ def _create_npu_resampling_mode(
):
"""This is a helper function to capture a list
of arguments to create Vela NpuResamplingMode object"""
import ethosu.vela.api as vapi

mode_map = {
"NONE": vapi.NpuResamplingMode.NONE,
"NEAREST": vapi.NpuResamplingMode.NEAREST,
Expand All @@ -561,6 +583,8 @@ def _create_npu_resampling_mode(
def _create_npu_dma_op(serial_copy):
"""This is a helper function to capture the list of arguments
to create a NpuDmaOperation object"""
import ethosu.vela.api as vapi

src = vapi.NpuAddressRange(
# region will be updated later
region=0,
Expand Down
20 changes: 11 additions & 9 deletions python/tvm/relay/backend/contrib/ethosu/vela_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,13 @@
import logging
import math
import numpy as np
from ethosu.vela import api as vapi

from tvm.relay.backend.contrib.ethosu import util
from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator as tirtocs

# pylint: disable=invalid-name
logger = logging.getLogger("Ethos-U")

VELA_TO_NP_DTYPES = {
vapi.NpuDataType.UINT8: np.uint8,
vapi.NpuDataType.UINT16: np.uint16,
vapi.NpuDataType.INT8: np.int8,
vapi.NpuDataType.INT16: np.int16,
vapi.NpuDataType.INT32: np.int32,
}

SCALE_BIAS_LENGTH = 10


Expand All @@ -62,13 +53,17 @@ def get_optimal_block_config(npu_op, accel_type):
ethosu.vela.api.NpuShape3d :
The optimal block config for the operator
"""
from ethosu.vela import api as vapi

all_valid_block_configs = vapi.npu_find_block_configs(npu_op, accel_type)
return _get_optimal_block_config(all_valid_block_configs)


def _get_optimal_block_config(all_valid_block_configs):
"""An internal function to get block config with largest depth
and then highest volume/area"""
from ethosu.vela import api as vapi

assert isinstance(all_valid_block_configs, list)
for block_cfg in all_valid_block_configs:
assert isinstance(block_cfg, vapi.NpuShape3D)
Expand Down Expand Up @@ -194,6 +189,8 @@ def compress_weights(
compressed_weights : bytearray
Compressed weights
"""
from ethosu.vela import api as vapi

layout_transform_indices = {"HWIO": (3, 0, 1, 2), "HWOI": (2, 0, 1, 3), "OHWI": (0, 1, 2, 3)}
assert weights_layout in layout_transform_indices.keys()
assert isinstance(weights_zp, np.int64)
Expand Down Expand Up @@ -223,6 +220,7 @@ def calculate_block_traversal_mode(is_depthwise, weights_shape_ohwi, ifm_bitdept
"""Calculate a block traversal mode given whether the op is depthwise convolution,
shape of weights and bit-depth of the ifm.
"""
from ethosu.vela import api as vapi

if is_depthwise:
return vapi.NpuBlockTraversal.DEPTH_FIRST
Expand Down Expand Up @@ -276,6 +274,8 @@ def pack_biases(
scale_bias : numpy.ndarray
Packed scales/biases as the hardware requires them.
"""
from ethosu.vela import api as vapi

# The BYOC infra should not partition anything else.
supported_ifm_dtypes = (np.uint8, np.int8, np.int16)
assert ifm_dtype in supported_ifm_dtypes
Expand Down Expand Up @@ -350,6 +350,8 @@ def _calculate_hw_bias_scales(

def get_target_accel_type():
"""This is a helper function to convert TVMC command line argument to NpuAccelerator type"""
from ethosu.vela import api as vapi

npu_accel_str_map = {
"ethos-u55-256": vapi.NpuAccelerator.Ethos_U55_256,
"ethos-u55-128": vapi.NpuAccelerator.Ethos_U55_128,
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/relay/op/contrib/ethosu.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs
from tvm.relay.backend.contrib.ethosu.util import RequantArgs
from tvm.relay.backend.contrib.ethosu.util import get_dim_value
from ethosu.vela import api as vapi


def check_strides(strides):
Expand Down Expand Up @@ -123,6 +122,8 @@ class TensorParams:
"""

def __init__(self, tensor, layout=None, scale=None, zero_point=None):
from ethosu.vela import api as vapi

self.tensor = tensor
if isinstance(tensor, Constant):
self.values = tensor.data.asnumpy()
Expand Down
16 changes: 4 additions & 12 deletions tests/python/contrib/test_ethosu/test_vela_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,7 @@ def verify(test_vec, mock_obj):
assert mock_obj.call_args[1]["block_traversal"] == test_vec["block_traversal"]

def create_mock(test_vec):
with patch(
"tvm.relay.backend.contrib.ethosu.vela_api.vapi.npu_encode_weights"
) as mock_npu_encode_weights:
with patch("ethosu.vela.api.npu_encode_weights") as mock_npu_encode_weights:
ifm_bitdepth = np.iinfo(test_vec["ifm_dtype"]).bits
ifm_dtype = test_vec["ifm_dtype"]
max = np.iinfo(ifm_dtype).max
Expand Down Expand Up @@ -427,9 +425,7 @@ def verify(test_vec, mock_obj, packed_biases):
assert test_vec["hw_shifts"][idx] == mock_obj.call_args_list[idx][0][2]

def create_mock(test_vec):
with patch(
"tvm.relay.backend.contrib.ethosu.vela_api.vapi.npu_encode_bias"
) as mock_npu_encode_bias:
with patch("ethosu.vela.api.npu_encode_bias") as mock_npu_encode_bias:
mock_npu_encode_bias.return_value = bytearray(10)
ifm_dtype = test_vec["ifm_dtype"]
max = np.iinfo(ifm_dtype).max
Expand Down Expand Up @@ -507,12 +503,8 @@ def test_encode_weights(accel):
]

def create_mock(test_vec):
with patch(
"tvm.relay.backend.contrib.ethosu.vela_api.vapi.npu_encode_weights"
) as mock_enc_w:
with patch(
"tvm.relay.backend.contrib.ethosu.vela_api.vapi.npu_find_block_configs"
) as mock_blk_cfg:
with patch("ethosu.vela.api.npu_encode_weights") as mock_enc_w:
with patch("ethosu.vela.api.npu_find_block_configs") as mock_blk_cfg:
mock_blk_cfg.return_value = [vapi.NpuShape3D(8, 8, 8)]
ethosu_conv2d_calls = extract_ethosu_conv2d_extern_calls(test_vec["tir_module"])
buffer_info = tirtocs.extract_buffer_info(
Expand Down

0 comments on commit 2a8aef1

Please sign in to comment.