Arm(R) Ethos(TM)-U NPU BinaryElementwise operators support (apache#9442)
This commit adds support for the binary elementwise primitive operators for the Arm(R) Ethos(TM)-U NPU and includes a few minor rewording changes.
NicolaLancellotti authored and mehrdadh committed Dec 1, 2021
1 parent c93a267 commit 5cd99b0
Showing 24 changed files with 2,902 additions and 21 deletions.
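
As a rough usage sketch of how the legalization added here might be driven: assuming a quantized Relay module that has already been partitioned for the NPU (the partition_for_ethosu helper and its import path are assumptions, not part of this diff), the new rewriters run as part of LegalizeEthosU:

import tvm
from tvm.relay.op.contrib.ethosu import partition_for_ethosu  # assumed helper, not added by this commit
from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU


def legalize_for_npu(mod: tvm.IRModule, params=None) -> tvm.IRModule:
    # Offload supported composite operators to the NPU, then rewrite the
    # composite functions into ethosu_* primitive operators.
    mod = partition_for_ethosu(mod, params)
    mod = LegalizeEthosU()(mod)
    return mod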
227 changes: 227 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -413,6 +413,224 @@ def __call__(self, *args, **kwargs):
pass


class BinaryElementwiseRewriter(DFPatternCallback):
"""Convert ethosu binary elementwise composite functions to
ethosu_binary_elementwise operators"""

def __init__(
self,
params_class: Type,
pattern: CallPattern,
):
super().__init__(require_type=True)
self.params_class = params_class
self.pattern = pattern

def callback(
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
params = self.params_class(post.op.body)
params.ifm.tensor = post.args[1] if params.reversed_operands else post.args[0]
params.ifm2.tensor = post.args[0] if params.reversed_operands else post.args[1]
channels_map = {
"NHWC": 3,
}
if str(params.ofm.layout) not in channels_map.keys():
raise UnsupportedLayout(str(params.ofm.layout))

activation_map = {"clip": "CLIP"}
if params.activation:
activation = activation_map[params.activation.op.name]
clip_min = int(params.activation.attrs.a_min)
clip_max = int(params.activation.attrs.a_max)
else:
activation = "NONE"
clip_min = 0
clip_max = 0

# We don't yet support activation functions that need to get legalized to LUTs.
lut = relay.const([], dtype="int8")

return ethosu_ops.ethosu_binary_elementwise(
ifm=params.ifm.tensor,
ifm2=params.ifm2.tensor,
lut=lut,
operator_type=params.operator_type,
ifm_scale=float(params.ifm.q_params.scale_f32),
ifm_zero_point=int(params.ifm.q_params.zero_point),
ifm2_scale=float(params.ifm2.q_params.scale_f32),
ifm2_zero_point=int(params.ifm2.q_params.zero_point),
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=int(params.ofm.q_params.zero_point),
ifm_channels=params.ifm.shape[3],
ifm2_channels=params.ifm2.shape[3],
reversed_operands=params.reversed_operands,
ofm_dtype=params.ofm.dtype,
activation=activation,
clip_min=clip_min,
clip_max=clip_max,
ifm_layout=str(params.ifm.layout),
ifm2_layout=str(params.ifm2.layout),
ofm_layout=str(params.ofm.layout),
)


class AddRewriter(BinaryElementwiseRewriter):
def __init__(self):
super().__init__(
params_class=ethosu_patterns.AddParams,
pattern=(wildcard().has_attr({"Composite": ethosu_patterns.AddParams.composite_name}))(
wildcard(), wildcard()
),
)


@ir.transform.module_pass(opt_level=1)
class LegalizeAdd:
"""This is the pass that wraps the AddRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(AddRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


class SubRewriter(BinaryElementwiseRewriter):
def __init__(self):
super().__init__(
params_class=ethosu_patterns.SubParams,
pattern=(wildcard().has_attr({"Composite": ethosu_patterns.SubParams.composite_name}))(
wildcard(), wildcard()
),
)


@ir.transform.module_pass(opt_level=1)
class LegalizeSub:
"""This is the pass that wraps the SubRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(SubRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


class MulRewriter(BinaryElementwiseRewriter):
def __init__(self):
super().__init__(
params_class=ethosu_patterns.MulParams,
pattern=(wildcard().has_attr({"Composite": ethosu_patterns.MulParams.composite_name}))(
wildcard(), wildcard()
),
)


@ir.transform.module_pass(opt_level=1)
class LegalizeMul:
"""This is the pass that wraps the MulRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(MulRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


class MinRewriter(BinaryElementwiseRewriter):
def __init__(self):
super().__init__(
params_class=ethosu_patterns.MinParams,
pattern=(wildcard().has_attr({"Composite": ethosu_patterns.MinParams.composite_name}))(
wildcard(), wildcard()
),
)


@ir.transform.module_pass(opt_level=1)
class LegalizeMin:
"""This is the pass that wraps the MinRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(MinRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


class MaxRewriter(BinaryElementwiseRewriter):
def __init__(self):
super().__init__(
params_class=ethosu_patterns.MaxParams,
pattern=(wildcard().has_attr({"Composite": ethosu_patterns.MaxParams.composite_name}))(
wildcard(), wildcard()
),
)


@ir.transform.module_pass(opt_level=1)
class LegalizeMax:
"""This is the pass that wraps the MaxRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(MaxRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


class ShlRewriter(BinaryElementwiseRewriter):
def __init__(self):
super().__init__(
params_class=ethosu_patterns.ShlParams,
pattern=(wildcard().has_attr({"Composite": ethosu_patterns.ShlParams.composite_name}))(
wildcard(), wildcard()
),
)


@ir.transform.module_pass(opt_level=1)
class LegalizeShl:
"""This is the pass that wraps the ShlRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(ShlRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


@ir.transform.module_pass(opt_level=1)
class LegalizeEthosU:
"""This is the pass to call graph-rewrites to perform graph transformation
@@ -423,11 +641,20 @@ class LegalizeEthosU:
def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
"""This is the method that replaces the operations with hardware/codegen supported
operations.
"""
mod = LegalizeSplit()(mod)
mod = LegalizeConv2D()(mod)
mod = LegalizeDepthwiseConv2D()(mod)
mod = LegalizeMaxPooling()(mod)
mod = LegalizeAvgPooling()(mod)
mod = LegalizeAdd()(mod)
mod = LegalizeSub()(mod)
mod = LegalizeMul()(mod)
mod = LegalizeMin()(mod)
mod = LegalizeMax()(mod)
mod = LegalizeShl()(mod)
return mod

def __call__(self, *args, **kwargs):
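
The heart of legalize.py above is BinaryElementwiseRewriter.callback, which rebuilds each matched composite function as a single ethosu_binary_elementwise call. For illustration, the call it would emit for a quantized ADD of two 1x16x16x8 int8 NHWC tensors looks roughly as follows; the scales, zero points and the "ADD" operator-type string are illustrative assumptions, not values taken from this diff:

from tvm import relay
from tvm.relay.backend.contrib.ethosu.op import ethosu_binary_elementwise

ifm = relay.var("ifm", shape=(1, 16, 16, 8), dtype="int8")
ifm2 = relay.var("ifm2", shape=(1, 16, 16, 8), dtype="int8")

add_call = ethosu_binary_elementwise(
    ifm=ifm,
    ifm2=ifm2,
    lut=relay.const([], dtype="int8"),  # no LUT-based activation yet
    operator_type="ADD",                # assumed operator-type string for the Add case
    ifm_scale=0.0039,                   # example quantization parameters
    ifm_zero_point=-128,
    ifm2_scale=0.0039,
    ifm2_zero_point=-128,
    ofm_scale=0.0078,
    ofm_zero_point=-128,
    ifm_channels=8,
    ifm2_channels=8,
    reversed_operands=False,
    ofm_dtype="int8",
    activation="NONE",
    clip_min=0,
    clip_max=0,
    ifm_layout="NHWC",
    ifm2_layout="NHWC",
    ofm_layout="NHWC",
)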
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/op/__init__.py
@@ -19,3 +19,4 @@
from .convolution import ethosu_conv2d
from .depthwise import ethosu_depthwise_conv2d
from .pooling import ethosu_pooling
from .binary_elementwise import ethosu_binary_elementwise
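
With this one-line export, the new operator is reachable from the op package, which is what legalize.py relies on through its ethosu_ops alias; a minimal, illustrative check:

from tvm.relay.backend.contrib.ethosu import op as ethosu_ops

assert callable(ethosu_ops.ethosu_binary_elementwise)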