diff --git a/model_compression_toolkit/target_platform_capabilities/schema/v1.py b/model_compression_toolkit/target_platform_capabilities/schema/v1.py
index 29dc647f9..31ea36458 100644
--- a/model_compression_toolkit/target_platform_capabilities/schema/v1.py
+++ b/model_compression_toolkit/target_platform_capabilities/schema/v1.py
@@ -25,7 +25,7 @@ class OperatorSetNames(Enum):
     OPSET_CONV = "Conv"
     OPSET_DEPTHWISE_CONV = "DepthwiseConv2D"
-    OPSET_CONV_TRANSPOSE = "ConvTraspose"
+    OPSET_CONV_TRANSPOSE = "ConvTranspose"
     OPSET_FULLY_CONNECTED = "FullyConnected"
     OPSET_CONCATENATE = "Concatenate"
     OPSET_STACK = "Stack"
@@ -41,7 +41,8 @@ class OperatorSetNames(Enum):
     OPSET_SUB = "Sub"
     OPSET_MUL = "Mul"
     OPSET_DIV = "Div"
-    OPSET_MIN_MAX = "MinMax"
+    OPSET_MIN = "Min"
+    OPSET_MAX = "Max"
     OPSET_PRELU = "PReLU"
     OPSET_SWISH = "Swish"
     OPSET_SIGMOID = "Sigmoid"
diff --git a/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2fw.py b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2fw.py
new file mode 100644
index 000000000..04afa236e
--- /dev/null
+++ b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2fw.py
@@ -0,0 +1,56 @@
+from typing import Dict, Tuple, List, Any, Optional
+
+from model_compression_toolkit import DefaultDict
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
+from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, \
+    OperationsSetToLayers
+
+
+class AttachTpModelToFw:
+
+    def __init__(self):
+        # Mapping of operation sets to their corresponding framework-specific layers.
+        self._opset2layer = None
+
+        # A mapping that associates each layer type in the operation set (with weight attributes and a quantization
+        # configuration in the target platform model) to its framework-specific attribute name. If not all layer types
+        # in the operation set are provided in the mapping, a DefaultDict should be supplied to handle missing entries.
+        self._opset2attr_mapping = None
+
+    def attach(self, tpc_model: TargetPlatformModel,
+               custom_opset2layer: Optional[Dict[str, Tuple[List[Any], Optional[Dict[str, DefaultDict]]]]] = None
+               ) -> TargetPlatformCapabilities:
+        """
+        Attaches a TargetPlatformModel, which includes a platform capabilities description, to the operators of a
+        specific framework.
+
+        Args:
+            tpc_model: a TargetPlatformModel object.
+            custom_opset2layer: an optional dictionary of custom operator sets, which allows adding to or overriding
+                the built-in mapping of framework operators in order to define a specific behavior for those
+                operators. This dictionary should map an operator set's unique name to a pair of: a list of framework
+                operators and an optional mapping of the operators' attribute names.
+
+        Returns: a TargetPlatformCapabilities object.
+
+        """
+
+        tpc = TargetPlatformCapabilities(tpc_model)
+
+        with tpc:
+            for opset_name, operators in self._opset2layer.items():
+                attr_mapping = self._opset2attr_mapping.get(opset_name)
+                OperationsSetToLayers(opset_name, operators, attr_mapping=attr_mapping)
+
+            if custom_opset2layer is not None:
+                for opset_name, operators in custom_opset2layer.items():
+                    if len(operators) == 1:
+                        OperationsSetToLayers(opset_name, operators[0])
+                    elif len(operators) == 2:
+                        OperationsSetToLayers(opset_name, operators[0], attr_mapping=operators[1])
+                    else:
+                        raise ValueError(f"A custom operator-set-to-layers mapping should include up to 2 elements - "
+                                         f"a list of layers to attach to the operator set and an optional mapping of "
+                                         f"attribute names, but the given mapping contains {len(operators)} elements.")
+
+        return tpc
+
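For clarity, the following sketch (illustrative only, not part of the diff) shows the shape of the custom_opset2layer argument described in the docstring above: each custom operator-set name maps to a 1-element tuple (a list of framework layers) or a 2-element tuple (the layers plus an attribute-name mapping). The operator-set names "Elu" and "CustomDense", the attribute values "kernel"/"bias", and the variables keras_attacher and tpc_model are hypothetical placeholders, not values defined by this diff.

import tensorflow as tf

from model_compression_toolkit import DefaultDict
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR

custom_opset2layer = {
    # 1-element tuple: only the framework layers to attach to the operator set.
    "Elu": ([tf.nn.elu],),
    # 2-element tuple: framework layers plus an attribute-name mapping (str -> DefaultDict).
    "CustomDense": ([tf.keras.layers.Dense],
                    {KERNEL_ATTR: DefaultDict(default_value="kernel"),
                     BIAS_ATTR: DefaultDict(default_value="bias")}),
}

# keras_attacher stands for a framework-specific subclass of AttachTpModelToFw,
# e.g. the AttachTpModelToKeras class added later in this diff; tpc_model is an
# existing TargetPlatformModel instance.
tpc = keras_attacher.attach(tpc_model, custom_opset2layer=custom_opset2layer)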
diff --git a/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2keras.py b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2keras.py
new file mode 100644
index 000000000..f7c8a524c
--- /dev/null
+++ b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2keras.py
@@ -0,0 +1,107 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import tensorflow as tf
+from packaging import version
+
+from model_compression_toolkit.verify_packages import FOUND_SONY_CUSTOM_LAYERS
+
+if FOUND_SONY_CUSTOM_LAYERS:
+    from sony_custom_layers.keras.object_detection.ssd_post_process import SSDPostProcess
+
+if version.parse(tf.__version__) >= version.parse("2.13"):
+    from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
+        MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
+        Conv2DTranspose, Identity, Concatenate, BatchNormalization, Minimum, Maximum
+else:
+    from keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
+        MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
+        Conv2DTranspose, Concatenate, BatchNormalization, Minimum, Maximum
+
+from model_compression_toolkit import DefaultDict
+from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS, \
+    BIAS_ATTR, KERAS_KERNEL, KERAS_DEPTHWISE_KERNEL
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
+from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams
+from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2fw import \
+    AttachTpModelToFw
+
+
+class AttachTpModelToKeras(AttachTpModelToFw):
+    def __init__(self):
+        super().__init__()
+
+        self._opset2layer = {
+            OperatorSetNames.OPSET_CONV.value: [Conv2D, tf.nn.conv2d],
+            OperatorSetNames.OPSET_DEPTHWISE_CONV.value: [DepthwiseConv2D, tf.nn.depthwise_conv2d],
+            OperatorSetNames.OPSET_CONV_TRANSPOSE.value: [Conv2DTranspose, tf.nn.conv2d_transpose],
+            OperatorSetNames.OPSET_FULLY_CONNECTED.value: [Dense],
+            OperatorSetNames.OPSET_CONCATENATE.value: [tf.concat, Concatenate],
+            OperatorSetNames.OPSET_STACK.value: [tf.stack],
+            OperatorSetNames.OPSET_UNSTACK.value: [tf.unstack],
+            OperatorSetNames.OPSET_GATHER.value: [tf.gather, tf.compat.v1.gather],
+            OperatorSetNames.OPSET_EXPAND.value: [],
+            OperatorSetNames.OPSET_BATCH_NORM.value: [BatchNormalization],
+            OperatorSetNames.OPSET_RELU.value: [tf.nn.relu, ReLU],
+            OperatorSetNames.OPSET_RELU6.value: [tf.nn.relu6],
+            OperatorSetNames.OPSET_LEAKY_RELU.value: [tf.nn.leaky_relu, LeakyReLU],
+            OperatorSetNames.OPSET_HARD_TANH.value: [LayerFilterParams(Activation, activation="hard_tanh")],
+            OperatorSetNames.OPSET_ADD.value: [tf.add, Add],
+            OperatorSetNames.OPSET_SUB.value: [tf.subtract, Subtract],
+            OperatorSetNames.OPSET_MUL.value: [tf.math.multiply, Multiply],
+            OperatorSetNames.OPSET_DIV.value: [tf.math.divide, tf.math.truediv],
+            OperatorSetNames.OPSET_MIN.value: [tf.math.minimum, Minimum],
+            OperatorSetNames.OPSET_MAX.value: [tf.math.maximum, Maximum],
+            OperatorSetNames.OPSET_PRELU.value: [PReLU],
+            OperatorSetNames.OPSET_SWISH.value: [tf.nn.swish, LayerFilterParams(Activation, activation="swish")],
+            OperatorSetNames.OPSET_SIGMOID.value: [tf.nn.sigmoid, LayerFilterParams(Activation, activation="sigmoid")],
+            OperatorSetNames.OPSET_TANH.value: [tf.nn.tanh, LayerFilterParams(Activation, activation="tanh")],
+            OperatorSetNames.OPSET_GELU.value: [tf.nn.gelu, LayerFilterParams(Activation, activation="gelu")],
+            OperatorSetNames.OPSET_HARDSIGMOID.value: [tf.keras.activations.hard_sigmoid,
+                                                       LayerFilterParams(Activation, activation="hard_sigmoid")],
activation="hard_sigmoid")], + OperatorSetNames.OPSET_FLATTEN.value: [Flatten], + OperatorSetNames.OPSET_GET_ITEM.value: [tf.__operators__.getitem], + OperatorSetNames.OPSET_RESHAPE.value: [Reshape, tf.reshape], + OperatorSetNames.OPSET_PERMUTE.value: [Permute], + OperatorSetNames.OPSET_TRANSPOSE.value: [tf.transpose], + OperatorSetNames.OPSET_DROPOUT.value: [Dropout], + OperatorSetNames.OPSET_SPLIT.value: [tf.split], + OperatorSetNames.OPSET_MAXPOOL.value: [MaxPooling2D], + OperatorSetNames.OPSET_SHAPE.value: [tf.shape, tf.compat.v1.shape], + OperatorSetNames.OPSET_EQUAL.value: [tf.math.equal], + OperatorSetNames.OPSET_ARGMAX.value: [tf.math.argmax], + OperatorSetNames.OPSET_TOPK.value: [tf.nn.top_k], + OperatorSetNames.OPSET_FAKE_QUANT_WITH_MIN_MAX_VARS.value: [tf.quantization.fake_quant_with_min_max_vars], + OperatorSetNames.OPSET_COMBINED_NON_MAX_SUPPRESSION.value: [tf.image.combined_non_max_suppression], + OperatorSetNames.OPSET_CROPPING2D.value: [Cropping2D], + OperatorSetNames.OPSET_ZERO_PADDING2d.value: [ZeroPadding2D], + OperatorSetNames.OPSET_CAST.value: [tf.cast], + OperatorSetNames.OPSET_STRIDED_SLICE.value: [tf.strided_slice] + } + + if FOUND_SONY_CUSTOM_LAYERS: + self._opset2layer[OperatorSetNames.OPSET_POST_PROCESS] = [SSDPostProcess] + + self._opset2attr_mapping = {OperatorSetNames.OPSET_CONV.value: { + KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL), + BIAS_ATTR: DefaultDict(default_value=BIAS)}, + OperatorSetNames.OPSET_DEPTHWISE_CONV.value: { + KERNEL_ATTR: DefaultDict({ + DepthwiseConv2D: KERAS_DEPTHWISE_KERNEL, + tf.nn.depthwise_conv2d: KERAS_DEPTHWISE_KERNEL}, default_value=KERAS_KERNEL), + BIAS_ATTR: DefaultDict(default_value=BIAS)}, + OperatorSetNames.OPSET_FULLY_CONNECTED.value: { + KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL), + BIAS_ATTR: DefaultDict(default_value=BIAS)}} diff --git a/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py new file mode 100644 index 000000000..e68043596 --- /dev/null +++ b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py @@ -0,0 +1,91 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py
new file mode 100644
index 000000000..e68043596
--- /dev/null
+++ b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py
@@ -0,0 +1,91 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import operator
+
+import torch
+from torch import add, sub, mul, div, divide, flatten, reshape, split, unsqueeze, dropout, sigmoid, tanh, \
+    chunk, unbind, topk, gather, equal, transpose, permute, argmax, squeeze, multiply, subtract, minimum, \
+    maximum
+from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d, BatchNorm2d
+from torch.nn import Dropout, Flatten, Hardtanh
+from torch.nn import ReLU, ReLU6, PReLU, SiLU, Sigmoid, Tanh, Hardswish, Hardsigmoid, LeakyReLU, GELU
+import torch.nn.functional as F
+from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, hardsigmoid, leaky_relu, gelu
+
+from model_compression_toolkit import DefaultDict
+from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, PYTORCH_KERNEL, BIAS, \
+    BIAS_ATTR
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
+from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams
+from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2fw import \
+    AttachTpModelToFw
+
+
+class AttachTpModelToPytorch(AttachTpModelToFw):
+    def __init__(self):
+        super().__init__()
+
+        self._opset2layer = {
+            OperatorSetNames.OPSET_CONV.value: [Conv2d],
+            OperatorSetNames.OPSET_CONV_TRANSPOSE.value: [ConvTranspose2d],
+            OperatorSetNames.OPSET_FULLY_CONNECTED.value: [Linear],
+            OperatorSetNames.OPSET_CONCATENATE.value: [torch.cat, torch.concat, torch.concatenate],
+            OperatorSetNames.OPSET_STACK.value: [torch.stack],
+            OperatorSetNames.OPSET_UNSTACK.value: [unbind],
+            OperatorSetNames.OPSET_GATHER.value: [gather],
+            OperatorSetNames.OPSET_EXPAND.value: [torch.Tensor.expand],
+            OperatorSetNames.OPSET_BATCH_NORM.value: [BatchNorm2d],
+            OperatorSetNames.OPSET_RELU.value: [torch.relu, ReLU, relu],
+            OperatorSetNames.OPSET_RELU6.value: [ReLU6, relu6],
+            OperatorSetNames.OPSET_LEAKY_RELU.value: [LeakyReLU, leaky_relu],
+            OperatorSetNames.OPSET_HARD_TANH.value: [LayerFilterParams(Hardtanh, min_val=0),
+                                                     LayerFilterParams(hardtanh, min_val=0)],
+            OperatorSetNames.OPSET_ADD.value: [operator.add, add],
+            OperatorSetNames.OPSET_SUB.value: [operator.sub, sub, subtract],
+            OperatorSetNames.OPSET_MUL.value: [operator.mul, mul, multiply],
+            OperatorSetNames.OPSET_DIV.value: [operator.truediv, div, divide],
+            OperatorSetNames.OPSET_MIN.value: [minimum],
+            OperatorSetNames.OPSET_MAX.value: [maximum],
+            OperatorSetNames.OPSET_PRELU.value: [PReLU, prelu],
+            OperatorSetNames.OPSET_SWISH.value: [SiLU, silu],
+            OperatorSetNames.OPSET_SIGMOID.value: [Sigmoid, sigmoid, F.sigmoid],
+            OperatorSetNames.OPSET_TANH.value: [Tanh, tanh, F.tanh],
+            OperatorSetNames.OPSET_GELU.value: [GELU, gelu],
+            OperatorSetNames.OPSET_HARDSIGMOID.value: [Hardsigmoid, hardsigmoid],
+            OperatorSetNames.OPSET_HARDSWISH.value: [Hardswish, hardswish],
+            OperatorSetNames.OPSET_FLATTEN.value: [Flatten, flatten],
+            OperatorSetNames.OPSET_GET_ITEM.value: [operator.getitem],
+            OperatorSetNames.OPSET_RESHAPE.value: [reshape],
+            OperatorSetNames.OPSET_UNSQUEEZE.value: [unsqueeze],
+            OperatorSetNames.OPSET_SQUEEZE.value: [squeeze],
+            OperatorSetNames.OPSET_PERMUTE.value: [permute],
+            OperatorSetNames.OPSET_TRANSPOSE.value: [transpose],
+            OperatorSetNames.OPSET_DROPOUT.value: [Dropout, dropout],
+            OperatorSetNames.OPSET_SPLIT.value: [split],
+            OperatorSetNames.OPSET_CHUNK.value: [chunk],
+            OperatorSetNames.OPSET_MAXPOOL.value: [MaxPool2d],
+            OperatorSetNames.OPSET_SIZE.value: [torch.Tensor.size],
+            OperatorSetNames.OPSET_SHAPE.value: [torch.Tensor.shape],
+            OperatorSetNames.OPSET_EQUAL.value: [equal],
+            OperatorSetNames.OPSET_ARGMAX.value: [argmax],
+            OperatorSetNames.OPSET_TOPK.value: [topk],
+        }
+
+        pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),
+                                       BIAS_ATTR: DefaultDict(default_value=BIAS)}
+        self._opset2attr_mapping = {
+            OperatorSetNames.OPSET_CONV.value: pytorch_linear_attr_mapping,
+            OperatorSetNames.OPSET_CONV_TRANSPOSE.value: pytorch_linear_attr_mapping,
+            OperatorSetNames.OPSET_FULLY_CONNECTED.value: pytorch_linear_attr_mapping}
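Taken together, the intended use of the new attachers is roughly the following (illustrative only, not part of the diff; tpc_model stands for an existing TargetPlatformModel instance obtained elsewhere).

from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
    AttachTpModelToKeras
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
    AttachTpModelToPytorch

keras_tpc = AttachTpModelToKeras().attach(tpc_model)      # TargetPlatformCapabilities bound to Keras layers
pytorch_tpc = AttachTpModelToPytorch().attach(tpc_model)  # TargetPlatformCapabilities bound to PyTorch layers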