From 72fa426537ce2bb28db43cf76ced3ed5c573edcd Mon Sep 17 00:00:00 2001 From: Suraj Sudhir Date: Wed, 28 Oct 2020 17:28:26 -0700 Subject: [PATCH] TOSA MLIR Dialect --- mlir/include/mlir/Dialect/CMakeLists.txt | 1 + mlir/include/mlir/Dialect/Tosa/CMakeLists.txt | 2 + .../mlir/Dialect/Tosa/IR/CMakeLists.txt | 17 + .../mlir/Dialect/Tosa/IR/TosaInterfaces.td | 34 + .../mlir/Dialect/Tosa/IR/TosaOpBase.td | 428 +++++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h | 56 + mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 1560 +++++++++++++++++ .../include/mlir/Dialect/Tosa/IR/TosaTraits.h | 33 + mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.h | 31 + .../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 129 ++ .../Dialect/Tosa/Transforms/CMakeLists.txt | 6 + .../mlir/Dialect/Tosa/Transforms/Passes.h | 36 + .../mlir/Dialect/Tosa/Transforms/Passes.td | 18 + .../mlir/Dialect/Tosa/Utils/QuantUtils.h | 84 + mlir/lib/Dialect/CMakeLists.txt | 1 + mlir/lib/Dialect/Tosa/CMakeLists.txt | 24 + mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 134 ++ .../Dialect/Tosa/Transforms/CMakeLists.txt | 13 + .../Tosa/Transforms/TosaMakeBroadcastable.cpp | 222 +++ mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp | 364 ++++ 20 files changed, 3193 insertions(+) create mode 100644 mlir/include/mlir/Dialect/Tosa/CMakeLists.txt create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/TosaTraits.h create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.h create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td create mode 100644 mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt create mode 100644 mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h create mode 100644 
mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td create mode 100644 mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h create mode 100644 mlir/lib/Dialect/Tosa/CMakeLists.txt create mode 100644 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp create mode 100644 mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp create mode 100644 mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt index 103225948238f58..09c6ae569c18d7a 100644 --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -13,4 +13,5 @@ add_subdirectory(SCF) add_subdirectory(Shape) add_subdirectory(SPIRV) add_subdirectory(StandardOps) +add_subdirectory(Tosa) add_subdirectory(Vector) diff --git a/mlir/include/mlir/Dialect/Tosa/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/CMakeLists.txt new file mode 100644 index 000000000000000..9f57627c321fb0c --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt new file mode 100644 index 000000000000000..5416bc777059909 --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt @@ -0,0 +1,17 @@ +set(LLVM_TARGET_DEFINITIONS TosaOps.td) +mlir_tablegen(TosaOps.h.inc -gen-op-decls) +mlir_tablegen(TosaOps.cc.inc -gen-op-defs) +add_public_tablegen_target(MLIRTosaOpsIncGen) + +set(LLVM_TARGET_DEFINITIONS TosaOps.td) +mlir_tablegen(TosaStructs.h.inc -gen-struct-attr-decls) +mlir_tablegen(TosaStructs.cc.inc -gen-struct-attr-defs) +add_public_tablegen_target(MLIRTosaStructsIncGen) + + +set(LLVM_TARGET_DEFINITIONS TosaInterfaces.td) +mlir_tablegen(TosaInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(TosaInterfaces.cc.inc -gen-op-interface-defs) 
+add_public_tablegen_target(MLIRTosaInterfaceIncGen) + +add_mlir_doc(TosaOps -gen-op-doc TosaOps Dialects/) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td new file mode 100644 index 000000000000000..931ea05a5f6c006 --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td @@ -0,0 +1,34 @@ +//===-- TosaInterfaces.td - TOSA dialect interfaces *- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the dialect op interfaces for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOSA_OP_INTERFACES +#define TOSA_OP_INTERFACES + +include "mlir/IR/OpBase.td" + +def TosaOpInterface : OpInterface<"TosaOp"> { + let description = [{ + Implements interfaces for general Tosa op utility + }]; + + let methods = [ + InterfaceMethod< + [{Returns the TOSA version.}], + "StringRef", "getTOSAVersion", (ins), [{ + return "0.20"; + }] + >, + ]; + +} + +#endif diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td new file mode 100644 index 000000000000000..a8485030b963ef8 --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -0,0 +1,428 @@ +//===-- TosaOpBase.td - TOSA dialect op builders *- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the operation builders for the TOSA dialect. 
+// +//===----------------------------------------------------------------------===// + + +#ifdef TOSA_OP_BASE +#else +#define TOSA_OP_BASE + +// Quantization attributes + +def Tosa_UnaryOpQuantizationAttr : StructAttr<"UnaryOpQuantizationAttr", Tosa_Dialect, [ + StructFieldAttr<"input_zp", I32Attr>, + StructFieldAttr<"output_zp", I32Attr>] > { + let description = "Attribute holding quantization information for Unary Ops."; +} + +def Tosa_ConvOpQuantizationAttr : StructAttr<"ConvOpQuantizationAttr", Tosa_Dialect, [ + StructFieldAttr<"input_zp", I32Attr>, + StructFieldAttr<"weight_zp", I32Attr>] > { + let description = "Attribute holding quantization information for Convolution Ops."; +} + +def Tosa_MatMulOpQuantizationAttr : StructAttr<"MatMulOpQuantizationAttr", Tosa_Dialect, [ + StructFieldAttr<"a_zp", I32Attr>, + StructFieldAttr<"b_zp", I32Attr>] > { + let description = "Attribute holding quantization information for MatMul Ops."; +} + +def Tosa_PadOpQuantizationAttr : StructAttr<"PadOpQuantizationAttr", Tosa_Dialect, [ + StructFieldAttr<"input_zp", I32Attr>] > { + let description = "Attribute holding quantization information for Pad Ops."; +} + +/////////////////////////////////// +// Tosa QuantizationInfo Builders +/////////////////////////////////// + +// ConvOp Quantization Info Builders + +// This builder is called on convolution operation types that need to create their +// OptionalAttr quantization_attr parameter. It happens transparently when legalize_something.cc calls +// the rewriter.replaceOpWithNewOp() or similar builder, with just the standard arguments and not the +// additional quantization_attr option. It may be explicitly specified, but is not necessary. +// If it is explicitly specified, a different auto-generated builder will handle it.
+def Tosa_ConvOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$output_type, "Value":$input, "Value":$filter, "Value":$bias, "ArrayAttr":$strides, "ArrayAttr":$dilations, "ArrayAttr":$padding), + [{ + $_state.addOperands(input); + $_state.addOperands(filter); + $_state.addOperands(bias); + $_state.addAttribute("strides", strides); + $_state.addAttribute("dilations", dilations); + $_state.addAttribute("padding", padding); + + auto quantattr = mlir::tosa::buildConvOpQuantizationAttr($_builder, + input, + filter); + if ( quantattr ) { + $_state.addAttribute("quantization_info", quantattr); + unsigned input_bits = input.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + unsigned weight_bits = filter.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + auto output_shape = output_type.dyn_cast().getShape(); + IntegerType acc_element_type; + if(input_bits == 16 && weight_bits == 8) { + acc_element_type = $_builder.getIntegerType(48); + } + else { + acc_element_type = $_builder.getI32Type(); + } + auto acc_type = RankedTensorType::get(output_shape, acc_element_type); + $_state.addTypes(acc_type); + } + else { + $_state.addTypes(output_type); + } + }]>; + +// A variant of ConvOpQuantInfo builder for transpose_conv op, which takes outpad and output_shape attributes instead of padding.
+def Tosa_TransConvOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$output_type, "Value":$input, "Value":$filter, "Value":$bias, "ArrayAttr":$strides, "ArrayAttr":$dilations, "ArrayAttr":$outpad, "ArrayAttr":$output_shape), + [{ + $_state.addOperands(input); + $_state.addOperands(filter); + $_state.addOperands(bias); + $_state.addAttribute("strides", strides); + $_state.addAttribute("dilations", dilations); + $_state.addAttribute("outpad", outpad); + $_state.addAttribute("output_shape", output_shape); + auto quantattr = mlir::tosa::buildConvOpQuantizationAttr($_builder, + input, + filter); + + if ( quantattr ) { + $_state.addAttribute("quantization_info", quantattr); + unsigned input_bits = input.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + unsigned weight_bits = filter.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + auto output_shape = output_type.dyn_cast().getShape(); + IntegerType acc_element_type; + if(input_bits == 16 && weight_bits == 8) { + acc_element_type = $_builder.getIntegerType(48); + } + else { + acc_element_type = $_builder.getI32Type(); + } + auto acc_type = RankedTensorType::get(output_shape, acc_element_type); + $_state.addTypes(acc_type); + } + else { + $_state.addTypes(output_type); + } + }]>; + +// All Conv legalizations are done in C++ so no TableGen builder + +// FullyConnectedOp Quantization Info Builder + +// This builder is called on FC operation that needs to create their +// OptionalAttr quantization_attr parameter. It happens transparently when legalize_something.cc calls +// the rewriter.replaceOpWithNewOp() or similar builder, with just the standard arguments and not the +// additional quantization_attr option. It may be explicitly specified, but is not necessary. +// If it is explicitly specified, a different auto-generated builder will handle it. 
+def Tosa_FCOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$output_type, "Value":$input, "Value":$filter, "Value":$bias), + [{ + $_state.addOperands(input); + $_state.addOperands(filter); + $_state.addOperands(bias); + auto quantattr = mlir::tosa::buildConvOpQuantizationAttr($_builder, + input, + filter); + if ( quantattr ) { + $_state.addAttribute("quantization_info", quantattr); + unsigned input_bits = input.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + unsigned weight_bits = filter.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + auto output_shape = output_type.dyn_cast().getShape(); + IntegerType acc_element_type; + if(input_bits == 16 && weight_bits == 8) { + acc_element_type = $_builder.getIntegerType(48); + } + else { + acc_element_type = $_builder.getI32Type(); + } + auto acc_type = RankedTensorType::get(output_shape, acc_element_type); + $_state.addTypes(acc_type); + } + else { + $_state.addTypes(output_type); + } + }]>; + +// Similar to FCOpQuantInfoBuilder, but drop bias +def Tosa_MatMulOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$output_type, "Value":$a, "Value":$b), + [{ + $_state.addOperands(a); + $_state.addOperands(b); + auto quantattr = mlir::tosa::buildMatMulOpQuantizationAttr($_builder, a, b); + + if ( quantattr ) { + $_state.addAttribute("quantization_info", quantattr); + unsigned input_bits = a.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + auto output_shape = output_type.dyn_cast().getShape(); + IntegerType acc_element_type; + if(input_bits == 16) { + acc_element_type = $_builder.getIntegerType(48); + } + else { + acc_element_type = $_builder.getI32Type(); + } + auto acc_type = RankedTensorType::get(output_shape, acc_element_type); + $_state.addTypes(acc_type); + } + else { + $_state.addTypes(output_type); + } + }]>; + +// FC legalization done in C++ so no TableGen builder + +// This builder is called on pool 
operation types that need to create their +// OptionalAttr quantization_attr parameter. It works like earlier builders, except that pool has +// additional attribute parameters that must be accommodated by the builder even though only +// the input and output parameters matter to the quantization op creation. +def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$output_type, "Value":$input, "ArrayAttr":$kernel_size, "ArrayAttr":$strides, "ArrayAttr":$padding), + [{ + $_state.addOperands(input); + $_state.addAttribute("kernel_size", kernel_size); + $_state.addAttribute("strides", strides); + $_state.addAttribute("padding", padding); + auto quantattr = mlir::tosa::buildUnaryOpQuantizationAttr($_builder, + input, + output_type); + if ( quantattr ) + $_state.addAttribute("quantization_info", quantattr); + $_state.types.push_back(output_type); + }]>; + +// This builder is called on single-parameter unary types that need to create their +// OptionalAttr quantization_attr parameter. +def Tosa_UnaryOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$output_type, "Value":$input), + [{ + $_state.addOperands(input); + auto quantattr = mlir::tosa::buildUnaryOpQuantizationAttr($_builder, + input, + output_type); + if ( quantattr ) + $_state.addAttribute("quantization_info", quantattr); + $_state.types.push_back(output_type); + }]>; + +// This builder is called on pad operations that need to create their +// OptionalAttr quantization_attr parameter.
+def Tosa_PadOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$output_type, "Value":$input, "Value":$paddings), + [{ + $_state.addOperands(input); + $_state.addOperands(paddings); + auto quantattr = mlir::tosa::buildPadOpQuantizationAttr($_builder, + input); + if ( quantattr ) + $_state.addAttribute("quantization_info", quantattr); + $_state.types.push_back(output_type); + }]>; + +def Tosa_BroadcastableBinaryBuilder : OpBuilderDAG< + (ins "Value":$lhs, "Value":$rhs), + [{ + auto result_type = + OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType()); + if (!result_type) + mlir::emitError($_state.location, "Operands are not broadcastable"); + $_state.addOperands(lhs); + $_state.addOperands(rhs); + $_state.types.push_back(result_type); + }]>; + +/////////////////////////////// +// Tosa Operator Definitions +////////////////////////////// + +class Tosa_Op traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return ""; } // TBD + static StringRef getTOSAOpSubtype() { return ""; } // TBD + }]; + +} + +class Tosa_ElemwiseUnaryOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Elemwise"; } + static StringRef getTOSAOpSubtype() { return "Unary"; } + }]; + +} + +class Tosa_ElemwiseBinaryOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Elemwise"; } + static StringRef getTOSAOpSubtype() { return "Binary"; } + }]; + +} + +class Tosa_ElemwiseCompareOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Elemwise"; } + static StringRef getTOSAOpSubtype() { return "Compare"; } + }]; + +} + +class Tosa_ElemwiseTernaryOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Elemwise"; } + static StringRef getTOSAOpSubtype() { return "Ternary"; } + }]; + +} + +class Tosa_DataLayoutOp traits = []> : + Op { + + let 
extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "DataLayout"; } + static StringRef getTOSAOpSubtype() { return ""; } // TBD + }]; + +} + +class Tosa_DataNodeOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "DataNode"; } + static StringRef getTOSAOpSubtype() { return ""; } // TBD + }]; + +} + +class Tosa_AggregationOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Aggregation"; } + static StringRef getTOSAOpSubtype() { return ""; } // TBD + }]; + +} + +class Tosa_TensorArgOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Tensor"; } + static StringRef getTOSAOpSubtype() { return "Arg"; } + }]; + +} + +class Tosa_TensorConvOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Tensor"; } + static StringRef getTOSAOpSubtype() { return "Conv"; } + }]; + +} + +class Tosa_TensorPoolOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Tensor"; } + static StringRef getTOSAOpSubtype() { return "Pool"; } + }]; + +} + +class Tosa_TensorImageOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Image"; } + static StringRef getTOSAOpSubtype() { return ""; } // TBD + }]; + +} + +class Tosa_ActivationOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Activation"; } + static StringRef getTOSAOpSubtype() { return ""; } // TBD + }]; + +} + +class Tosa_ReductionOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Reduction"; } + static StringRef getTOSAOpSubtype() { return ""; } // TBD + }]; + +} + +class Tosa_ImageOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Image"; } + static
StringRef getTOSAOpSubtype() { return ""; } // TBD + }]; + +} + +class Tosa_ConversionOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Conversion"; } + static StringRef getTOSAOpSubtype() { return ""; } // TBD + }]; + +} + +// Specify traits of operators. + +#endif // TOSA_OP_BASE diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h new file mode 100644 index 000000000000000..6199f498ba99206 --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -0,0 +1,56 @@ +//===-- TosaOps.h - TOSA dialect operation definitions *- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the TOSA Dialect in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TOSA_IR_TOSA_OPS_H +#define MLIR_DIALECT_TOSA_IR_TOSA_OPS_H + +#include +#include + +#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Tosa/IR/TosaTraits.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" + +#include "mlir/Dialect/Tosa/IR/TosaStructs.h.inc" + +namespace mlir { +namespace tosa { + +class TosaDialect : public Dialect { + +public: + explicit TosaDialect(MLIRContext *context); + + static StringRef getDialectNamespace() { return "tosa"; } + + Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, + Location loc) 
override; +}; + +#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc" + +} // end namespace tosa +} // end namespace mlir + +#define GET_OP_CLASSES +#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc" + +#endif // MLIR_DIALECT_TOSA_IR_TOSA_OPS_H diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td new file mode 100644 index 000000000000000..58a8ad00cc079ce --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -0,0 +1,1560 @@ +//===-- TosaOps.td - TOSA dialect operation definitions *- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the operation set for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#ifdef TOSA_OPS +#else +#define TOSA_OPS + +include "mlir/IR/OpBase.td" + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Dialect/Tosa/IR/TosaInterfaces.td" + +include "mlir/Dialect/Tosa/IR/TosaTypesBase.td" + +def Tosa_Dialect : Dialect { + let name = "tosa"; + + let description = [{ + The Tosa dialect. + + Invariants: + + * All values are of Tensor type (in particular, scalars are + represented using zero-dimensional tensors); + }]; + + let cppNamespace = "mlir::tosa"; +} + +#ifdef TOSA_OP_BASE +#else +include "mlir/Dialect/Tosa/IR/TosaOpBase.td" +#endif + +/* TOSA Spec Section 2.2 */ +/* Operator Class: Tensor Data Engine Operators */ + +/* Operator: argmax */ + +def Tosa_ArgMaxOp : Tosa_TensorArgOp<"argmax", [NoSideEffect]> { + let summary = "Perform argmax on the input."; + + let description = [{ + This returns the index with the largest value across the given axis of the input tensor.
+ }]; + + let arguments = (ins + Tosa_Tensor: $input, + I64Attr: $axis); + + let results = (outs Tosa_Tensor: $output); + +} + +/* Operator: avg_pool2d */ + +def Tosa_AvgPool2dOp : Tosa_TensorPoolOp<"avg_pool2d", [NoSideEffect]> { + let summary = "Performs average pooling on the input."; + + let description = [{ + This performs an average pooling over the given input tensor. A sliding window of size + given by kernel_size is passed over the input tensor, with the mean value being placed + in the output tensor. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + + Confined]>:$kernel_size, + Confined]>:$strides, + DefaultValuedAttr:$padding, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_AvgPool2dOpQuantInfoBuilder]; +} + +/* Operator: conv2d */ + +def Tosa_Conv2DOp : Tosa_TensorConvOp<"conv2d", [NoSideEffect]> { + let summary = [{ + 2D Convolution Operator + }]; + + let description = [{ + Performs a 2D convolution over the given tensor input, using the weight tensor. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + Tosa_Tensor:$filter, + Tosa_Tensor:$bias, + + DefaultValuedAttr:$strides, + DefaultValuedAttr:$dilations, + DefaultValuedAttr:$padding, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_ConvOpQuantInfoBuilder]; + +} + +/* Operator: conv3d */ + +def Tosa_Conv3DOp : Tosa_TensorConvOp<"conv3d", [NoSideEffect]> { + let summary = [{ + 3D Convolution operator + }]; + + let description = [{ + Performs a 3D convolution over the given input tensor.
+ }]; + + let arguments = (ins + Tosa_Tensor:$input, + Tosa_Tensor:$filter, + Tosa_Tensor:$bias, + + DefaultValuedAttr]>, "{1, 1, 1, 1, 1}">:$strides, + DefaultValuedAttr]>, "{1, 1, 1, 1, 1}">:$dilations, + DefaultValuedAttr:$padding, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_ConvOpQuantInfoBuilder]; + +} + +def Tosa_DepthwiseConv2DOp : Tosa_TensorConvOp<"depthwise_conv2d", [NoSideEffect]> { + let summary = [{ + Depthwise 2D Convolution operator + }]; + + let description = [{ + Performs 2D convolutions separately over each channel of the given tensor input, using the weight tensor. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + Tosa_Tensor:$filter, + Tosa_Tensor:$bias, + + DefaultValuedAttr:$strides, + DefaultValuedAttr:$dilations, + DefaultValuedAttr:$padding, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_ConvOpQuantInfoBuilder]; + +} + +/* Operator: fully_connected */ + +def Tosa_FullyConnectedOp : Tosa_TensorConvOp<"fully_connected", [NoSideEffect]> { + let summary = "Fully Connected operator"; + + let description = [{ + Performs a fully connected network. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + Tosa_Tensor:$filter, + Tosa_TensorOfOrNone<[Tosa_AnyNumber]>:$bias, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_FCOpQuantInfoBuilder]; + +} + +/* Operator: matmul */ + +def Tosa_MatMulOp : Tosa_TensorConvOp<"matmul", [NoSideEffect]> { + let summary = "Matrix multiplication"; + + let description = [{ + Performs a two dimensional matrix multiplication. This allows both inputs to be activations, + rather than reserving weights as an attribute in the FULLY_CONNECTED operator.
+ }]; + + let arguments = (ins + Tosa_Tensor:$a, + Tosa_Tensor:$b, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$c + ); + + let builders = [Tosa_MatMulOpQuantInfoBuilder]; + +} + +/* Operator: max_pool2d */ + +def Tosa_MaxPool2dOp : Tosa_TensorPoolOp<"max_pool2d", [NoSideEffect]> { + let summary = "Performs max pooling on the input."; + + let description = [{ + This performs a max pooling over the given input tensor. A sliding window of size given by + kernel_size is passed over the input tensor, with the maximum value being placed in the + output tensor. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + + Confined]>:$kernel_size, + Confined]>:$strides, + DefaultValuedAttr:$padding + ); + + let results = (outs + Tosa_Tensor:$output + ); +} + +/* Operator: transpose_conv2d */ + +def Tosa_TransposeConv2DOp : Tosa_TensorConvOp<"transpose_conv2d", [NoSideEffect]> { + let summary = [{ + Transpose 2D Convolution operator. + }]; + + let description = [{ + Performs a 2D transposed convolution over the given tensor input, using the weights tensor. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + Tosa_Tensor:$filter, + Tosa_Tensor:$bias, + + I64ArrayAttr:$strides, + I64ArrayAttr:$dilations, + I64ArrayAttr:$outpad, + I64ArrayAttr:$output_shape, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_TransConvOpQuantInfoBuilder]; + +} + +/* TOSA Spec Section 2.3 */ +/* Operator Class: Activation Functions */ + +/* Operator: clamp */ + +def Tosa_ClampOp : Tosa_ActivationOp<"clamp", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes clamp(features, min, max)."; + + let description = [{ + Clamp to an arbitrary minimum and maximum value. Note that the maximum and minimum values are + specified as signed quantized values, no scaling happens before or after this operation.
+ }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$min_int, + I64Attr:$max_int, + F32Attr:$min_fp, + F32Attr:$max_fp + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +/* Operator: reluN */ + +def Tosa_ReluNOp : Tosa_ActivationOp<"reluN", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes rectified linear: `max(features, N)`."; + + let description = [{ + ReLU with a scalar maximum value. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$max_int, + F32Attr:$max_fp + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +/* Operator: sigmoid */ + +def Tosa_SigmoidOp : Tosa_ActivationOp<"sigmoid", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes elementwise sigmoid of input."; + + let description = [{ + Sigmoid function: output = 1 / (1 + exp(-input)) + For quantized integer data types, the TABLE operator should be used instead with the following definition. + The sigmoid table has 513 entries each of 16-bit precision and covering the input range -16.0 to +16.0 + in steps of 1/16. + }]; + + let arguments = (ins + Tosa_Tensor:$input + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +/* Operator: tanh */ + +def Tosa_TanhOp : Tosa_ActivationOp<"tanh", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes elementwise hyperbolic tangent of input"; + + let description = [{ + Parameterized hyperbolic tangent. + For quantized integer data types, the TABLE operator should be used instead with the following definition. + The tanh_table has 513 entries each of 16-bit precision and covering the input range -8.0 to +8.0 in steps of 1/32. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$input + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +/* TOSA Spec Section 2.4 */ +/* Operator Class: Elementwise unary/binary/ternary operators */ +/* Operator Subclass: Elementwise binary ops */ + +/* Operator: add */ + +def Tosa_AddOp : Tosa_ElemwiseBinaryOp<"add", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Elementwise addition operator"; + + let description = [{ + Elementwise addition of input1 and input2. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +/* Operator: arithmetic_right_shift */ + +def Tosa_ArithmeticRightShiftOp : Tosa_ElemwiseBinaryOp<"arithmetic_right_shift", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Elementwise Arithmetic Right Shift"; + + let description = [{ + Elementwise arithmetic right shift of input1 by the amount specified in input2. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +/* Operator: bitwise_and */ + +def Tosa_BitwiseAndOp : Tosa_ElemwiseBinaryOp<"bitwise_and", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Bitwise AND operator"; + + let description = [{ + Elementwise bitwise AND of input tensor 0 and input tensor 1. Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. 
+ }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +/* Operator: bitwise_or */ + +def Tosa_BitwiseOrOp : Tosa_ElemwiseBinaryOp<"bitwise_or", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Bitwise OR operator"; + + let description = [{ + Elementwise bitwise OR of input1 and input2. Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +/* Operator: bitwise_xor */ + +def Tosa_BitwiseXorOp : Tosa_ElemwiseBinaryOp<"bitwise_xor", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Bitwise XOR operator"; + + let description = [{ + Elementwise bitwise XOR of input1 and input2. Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +/* Operator: logical_and */ + +def Tosa_LogicalAndOp : Tosa_ElemwiseBinaryOp<"logical_and", [ResultsBroadcastableShape, Commutative, NoSideEffect]> { + let summary = "Returns the truth value of x AND y element-wise."; + + let description = [{ + Elementwise logical AND of input1 and input2. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. 
+ }]; + + let arguments = (ins + I1Tensor:$lhs, + I1Tensor:$rhs + ); + + let results = (outs + I1Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +/* Operator: logical_left_shift */ + +def Tosa_LogicalLeftShiftOp : Tosa_ElemwiseBinaryOp<"logical_left_shift", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Elementwise Logical Left Shift"; + + let description = [{ + Elementwise logical left shift of input1 by the amount specified in input2. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +/* Operator: logical_right_shift */ + +def Tosa_LogicalRightShiftOp : Tosa_ElemwiseBinaryOp<"logical_right_shift", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Elementwise Logical Right Shift"; + + let description = [{ + Elementwise logical right shift of input1 by the amount specified in input2. + Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +/* Operator: logical_or */ + +def Tosa_LogicalOrOp : Tosa_ElemwiseBinaryOp<"logical_or", [ResultsBroadcastableShape, Commutative, NoSideEffect]> { + let summary = "Returns the truth value of x OR y element-wise."; + + let description = [{ + Elementwise logical OR of input1 and input2. Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. 
+ }]; + + let arguments = (ins + I1Tensor:$lhs, + I1Tensor:$rhs + ); + + let results = (outs + I1Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +/* Operator: logical_xor */ + +def Tosa_LogicalXorOp : Tosa_ElemwiseBinaryOp<"logical_xor", [ResultsBroadcastableShape, Commutative, NoSideEffect]> { + let summary = "Returns the truth value of x XOR y element-wise."; + + let description = [{ + Elementwise logical XOR of input tensor 0 and input tensor 1. + Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + I1Tensor:$lhs, + I1Tensor:$rhs + ); + + let results = (outs + I1Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +/* Operator: maximum */ + +def Tosa_MaximumOp : Tosa_ElemwiseBinaryOp<"maximum", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Elementwise Maximum"; + + let description = [{ + Elementwise max of input1 and input2. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +/* Operator: minimum */ + +def Tosa_MinimumOp : Tosa_ElemwiseBinaryOp<"minimum", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Elementwise Minimum"; + + let description = [{ + Elementwise minimum of input tensor 0 and input tensor 1. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. 
+ }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +/* Operator: mul */ + +def Tosa_MulOp : Tosa_ElemwiseBinaryOp<"mul", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Multiplication operator"; + + let description = [{ + Elementwise multiplication (Hadamard product) of input tensor 0 and input tensor 1. + Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +/* Operator: pow */ + +def Tosa_PowOp : Tosa_ElemwiseBinaryOp<"pow", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Computes the power of one value to another."; + + let description = [{ + Elementwise input tensor 0 value raised to the power of input 1 tensor. + Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs + ); + + let results = (outs + Tosa_Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +/* Operator: sub */ + +def Tosa_SubOp : Tosa_ElemwiseBinaryOp<"sub", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Elementwise subtraction operator"; + + let description = [{ + Elementwise subtraction of input tensor 0 and input tensor 1. + Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +def Tosa_TableOp : Tosa_ElemwiseBinaryOp<"table", [NoSideEffect]> { + let summary = "Table lookup op"; + + let description = [{ + Interpolated table lookup operation. 
Input values are scaled to create a fixed-point 9.7 value. + The high 9 bits are used to index into the table. The fractional bits are used to interpolate + based on the looked up value and the index+1 value in the table. The TABLE operator then returns + a 16.7 interpolated value. Note that there must be 513 values to handle the full range of inputs. + + The TABLE operator is expected to be used as follows: + • A RESCALE node is expected before the TABLE operator to scale the input to a full int16_t range + for the table lookup + • If an int16_t result is required then follow the TABLE operator with a RESCALE with a right + shift of 7 + • If an int8_t result is required then follow the TABLE operator with a RESCALE with a right + shift of 15 + }]; + + let arguments = ( ins + Tosa_Tensor: $input, + Tosa_Tensor: $lut + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +/* TOSA Spec Section 2.5 */ +/* Operator Class: Elementwise unary/binary/ternary operators */ +/* Operator Subclass: Elementwise unary ops */ + +def Tosa_AbsOp : Tosa_ElemwiseUnaryOp<"abs", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise abs op"; + + let description = [{ + Elementwise absolute value operation + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output + ); +} + +/* Operator: bitwise_not */ + +def Tosa_BitwiseNotOp : Tosa_ElemwiseUnaryOp<"bitwise_not", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Bitwise NOT operator"; + + let description = [{ + Elementwise bitwise NOT of input tensor. 
+ }]; + + let arguments = ( + ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output); + +} + +def Tosa_CeilOp : Tosa_ElemwiseUnaryOp<"ceil", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise ceil op"; + + let description = [{ + Elementwise ceiling operation + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output + ); +} + +def Tosa_ClzOp : Tosa_ElemwiseUnaryOp<"clz", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise count leading zero op"; + + let description = [{ + Elementwise count leading zeros operation + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + let results = (outs Tosa_Tensor:$output + ); +} + +def Tosa_ExpOp : Tosa_ElemwiseUnaryOp<"exp", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise exp op"; + + let description = [{ + Elementwise e to the x operation + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output + ); +} + +def Tosa_FloorOp : Tosa_ElemwiseUnaryOp<"floor", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise floor op"; + + let description = [{ + Elementwise floor operation + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output + ); +} + +def Tosa_LogOp : Tosa_ElemwiseUnaryOp<"log", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise log op"; + + let description = [{ + Elementwise natural logarithm operation + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output + ); +} + +/* Operator: logical_not */ + +def Tosa_LogicalNotOp : Tosa_ElemwiseUnaryOp<"logical_not", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Returns the truth value of NOT x element-wise."; + + let description = [{ + Elementwise logical NOT of input. 
+ }]; + + let arguments = (ins + I1Tensor:$x + ); + + let results = (outs + I1Tensor:$y + ); + +} + +def Tosa_NegateOp : Tosa_ElemwiseUnaryOp<"negate", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise negate op"; + + let description = [{ + Elementwise negation operation + }]; + + let arguments = (ins + Tosa_Tensor:$input, + OptionalAttr:$quantization_info + ); + let results = (outs Tosa_Tensor:$output + ); + + let builders = [Tosa_UnaryOpQuantInfoBuilder]; +} + +def Tosa_ReciprocalOp : Tosa_ElemwiseUnaryOp<"reciprocal", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise reciprocal op"; + + let description = [{ + Elementwise reciprocal operation. For integer operation, a TABLE should be used + with the appropriate ranges. + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + let results = (outs Tosa_Tensor:$output + ); +} + +def Tosa_RsqrtOp : Tosa_ElemwiseUnaryOp<"rsqrt", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise 1/sqrt op"; + + let description = [{ + Elementwise reciprocal square root operation. For integer operation, a TABLE should be + used with the appropriate ranges. + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + let results = (outs Tosa_Tensor:$output + ); +} + +/* TOSA Spec Section 2.6 */ +/* Operator Class: Elementwise unary/binary/ternary operators */ +/* Operator Subclass: Elementwise ternary ops */ + +/* Operator: select */ + +def Tosa_SelectOp : Tosa_ElemwiseTernaryOp<"select", [NoSideEffect]> { + let summary = "Elementwise select operator"; + + let description = [{ + Elementwise select of the output based on a condition. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$condition, + Tosa_Tensor:$a, + Tosa_Tensor:$b + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +/* TOSA Spec Section 2.7 */ +/* Operator Class: Logical Operations */ + +/* Operator: equal */ + +def Tosa_EqualOp : Tosa_ElemwiseCompareOp<"equal", [ResultsBroadcastableShape, Commutative, NoSideEffect]> { + let summary = "Returns the truth value of (x == y) element-wise."; + + let description = [{ + Elementwise comparison operation + }]; + + let arguments = (ins + Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs + ); + + let results = (outs + I1Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +/* Operator: greater */ + +def Tosa_GreaterOp : Tosa_ElemwiseCompareOp<"greater", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Returns the truth value of (x > y) element-wise."; + + let description = [{ + Elementwise greater than comparison operation + }]; + + let arguments = (ins + Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs + ); + + let results = (outs + I1Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +/* Operator: greater_equal */ + +def Tosa_GreaterEqualOp : Tosa_ElemwiseCompareOp<"greater_equal", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Returns the truth value of (x >= y) element-wise."; + + let description = [{ + Elementwise comparison operation + }]; + + let arguments = (ins + Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs + ); + + let results = (outs + I1Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +/* TOSA Spec Section 2.8 */ +/* Operator Class: Reduction Ops */ + +/* Operator: reduce_all */ + +def Tosa_ReduceAllOp : Tosa_ReductionOp<"reduce_all", [NoSideEffect]> { + let summary = [{ + Reduce All operator + }]; + + let description = [{ + Reduce a tensor along the given axis with a logical AND operation + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + 
); + +} + +/* Operator: reduce_any */ + +def Tosa_ReduceAnyOp : Tosa_ReductionOp<"reduce_any", [NoSideEffect]> { + let summary = [{ + Reduce Any operator + }]; + + let description = [{ + Reduce a tensor along the given axis with a logical OR operation + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +/* Operator: reduce_max */ + +def Tosa_ReduceMaxOp : Tosa_ReductionOp<"reduce_max", [NoSideEffect]> { + let summary = [{ + Reduce Max operator + }]; + + let description = [{ + Reduce a tensor along the given axis with a maximum operation + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +/* Operator: reduce_min */ + +def Tosa_ReduceMinOp : Tosa_ReductionOp<"reduce_min", [NoSideEffect]> { + let summary = [{ + Reduce Min operator + }]; + + let description = [{ + Reduce a tensor along the given axis with a minimum operation + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +/* Operator: reduce_prod */ + +def Tosa_ReduceProdOp : Tosa_ReductionOp<"reduce_prod", [NoSideEffect]> { + let summary = [{ + Reduce Prod operator + }]; + + let description = [{ + Reduce a tensor along the given axis by computing the product of the axis. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +/* Operator: reduce_sum */ + +def Tosa_ReduceSumOp : Tosa_ReductionOp<"reduce_sum", [NoSideEffect]> { + let summary = [{ + Reduce Sum operator + }]; + + let description = [{ + Reduce a tensor along the given axis by computing the sum of the axis. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +/* TOSA Spec Section 2.9 */ +/* Operator Class: Data Layout / Memory Reinterpretation */ + +/* Operator: concat */ + +def Tosa_ConcatOp : Tosa_DataLayoutOp<"concat", [NoSideEffect]> { + let summary = "Concatenates tensors along one dimension."; + + let description = [{ + Concatenate two tensors along a given axis. No data conversion happens during a concat operation. + }]; + + let arguments = (ins + Tosa_Tensor:$a, + Tosa_Tensor:$b, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +/* Operator: pad */ + +def Tosa_PadOp : Tosa_DataLayoutOp<"pad", [NoSideEffect]> { + let summary = "Pads a tensor with zeros."; + + let description = [{ + Zero-pads a tensor along borders of each dimension. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + Tosa_Int32Or64Tensor:$paddings, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_PadOpQuantInfoBuilder]; + +} + +/* Operator: reshape */ + +def Tosa_ReshapeOp: Tosa_DataLayoutOp<"reshape", [ + NoSideEffect]> { + let summary = "Reshape operator"; + + let description = [{ + Returns a tensor with the same type/values as the input, with a new shape specified by the shape + argument. Reshape may operate on tensors of any rank. No data conversion happens during a reshape + operation. + }]; + + let arguments = ( + ins Tosa_Tensor:$input, + I64ArrayAttr:$shape); + + let results = (outs Tosa_Tensor:$output); +} + +/* Operator: reverse */ + +def Tosa_ReverseOp: Tosa_DataLayoutOp<"reverse", [ + NoSideEffect]> { + let summary = "Reverse operator"; + + let description = [{ + Returns a tensor with the same type/values as the input, with the data reversed along the given + axis. No data conversion happens during a reverse operation. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis); + + let results = (outs + Tosa_Tensor:$output); +} + +/* Operator: slice */ + +def Tosa_SliceOp: Tosa_DataLayoutOp<"slice", [ + NoSideEffect]> { + let summary = "Slice operator"; + + let description = [{ + Extracts a slice of the input tensor 0 on the given axis, beginning at the start coordinates, + and extending for size elements in each direction. No data conversion happens during a slice operation. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64ArrayAttr:$begin, + I64ArrayAttr:$size + ); + + let results = (outs + Tosa_Tensor:$output); +} + +/* Operator: tile */ + +def Tosa_TileOp: Tosa_DataLayoutOp<"tile", [NoSideEffect]> { + let summary = "Tile operator"; + + let description = [{ + Replicates the input `multiples` times along each dimension. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64ArrayAttr:$multiples); + + let results = (outs + Tosa_Tensor:$output); +} + +/* Operator: transpose */ + +def Tosa_TransposeOp : Tosa_DataLayoutOp<"transpose", [NoSideEffect]> { + let summary = "Transpose operator"; + + let description = [{ + Permutes the dimensions based on perm. + }]; + + let arguments = (ins + Tosa_Tensor:$x, + Tosa_Int32Or64Tensor:$perm + ); + + let results = ( + outs Tosa_Tensor:$y + ); + +} + +/* TOSA Spec Section 2.10 */ +/* Operator Class: Scatter/gather Operations */ + +/* Operator: gather */ + +def Tosa_GatherOp : Tosa_AggregationOp<"gather", [NoSideEffect]> { + let summary = [{ + Gather operation + }]; + + let description = [{ + Generate a tensor for which each element in the output is a subtensor of the values tensor along + the given axis, based on the value of indices. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$params, + Tosa_Int32Or64Tensor:$indices, + I64Attr:$axis, + + DefaultValuedAttr:$batch_dims + ); + + let results = (outs + Tosa_Tensor:$z + ); + +} + +/* TOSA Spec Section 2.11 */ +/* Operator Class: Image Frontend Functions */ + +/* Operator: resize */ + +def Tosa_ResizeOp : Tosa_ImageOp<"resize", [NoSideEffect]> { + + let summary = "Resize operation, supports various resize/upsample modes"; + + let description = [{ + Resizes a tensor. Resize is only allowed in the H and W dimensions. In expected use, + stride_y is approximately (IH< { + + let summary = "Cast operation"; + + let description = [{ + Performs a set of permissible cast operations + Mode Input Output + --------------------------------------- + signed 8 to bool int8 Boolean + signed 16 to bool int16 Boolean + signed 32 to bool int32 Boolean + bool to 8 Boolean int8 + bool to 16 Boolean int16 + bool to 32 Boolean int32 + signed 8 to signed 16 int8 int16 + signed 8 to signed 32 int8 int32 + signed 16 to signed 8 int16 int8 + signed 16 to signed 32 int16 int32 + signed 32 to signed 8 int32 int8 + signed 32 to signed 16 int32 int16 + float to signed 8 float int8 + float to signed 16 float int16 + signed 8 to float int8 float + signed 16 to float int16 float + }]; + + let arguments = ( + ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output); + +} + +/* Operator: rescale */ + +def Tosa_RescaleOp: Tosa_ConversionOp<"rescale", [NoSideEffect]> { + let summary = "Tosa rescale operator"; + + let description = [{ + Rescale quantized values into a new domain. 
Supported rescalings are: + Mode Input Output + signed 8 to 8 aint8 aint8 + signed 8 to 16 aint8 int16 + signed 8 to 32 aint8 int32 + signed 16 to 8 int16 aint8 + signed 16 to 16 int16 int16 + signed 16 to 32 int16 int32 + signed 32 to 8 int32 aint8 + signed 32 to 16 int32 int16 + signed 32 to 32 int32 int32 + signed 48 to 8 int48 aint8 + signed 48 to 16 int48 int16 + signed 48 to 32 int48 int32 + unsigned 8 to signed 8 uint8 aint8 + signed 8 to unsigned 8 aint8 uint8 + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I32Attr:$input_zp, + I32Attr:$output_zp, + I32ArrayAttr:$multiplier, + I32ArrayAttr:$shift, + BoolAttr:$scale32, + BoolAttr:$double_round, + BoolAttr:$per_channel + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +/* TOSA Spec Section 2.13 */ +/* Operator Class: Data Node Ops */ + +/* Operator: const */ + +def Tosa_ConstOp : Tosa_DataNodeOp<"const", [ConstantLike, NoSideEffect, FirstAttrDerivedResultType]> { + let summary = "Constant op."; + + let description = [{ + A node containing constant data for use as the input to an operation. May hold data + in any of the supported data formats. + }]; + + let arguments = (ins ElementsAttr:$value); + + let results = (outs AnyTensor:$output); + + let builders = [ + OpBuilderDAG<(ins "Type":$type, "Attribute":$value)>, + ]; + +} + +/* Operator: identity */ + +def Tosa_IdentityOp: Tosa_DataNodeOp<"identity", [NoSideEffect]> { + let summary = "Identity operator"; + let description = [{ + Returns a tensor with the same shape, size, type + and content as the input. + }]; + + let arguments = (ins + Tosa_Tensor:$input + ); + + let results = (outs + Tosa_Tensor:$output); +} + +/* Operator: identityn */ + +def Tosa_IdentityNOp: Tosa_DataNodeOp<"identityn", [NoSideEffect]> { + let summary = "IdentityN operator"; + let description = [{ + Returns a list of tensors with the same shape, type, and contents as the + input list of tensors. 
+ }]; + + let arguments = (ins + Variadic:$input + ); + + let results = (outs + Variadic:$output); +} + +/* Operator: placeholder */ + +def Tosa_PlaceholderOp : Tosa_DataNodeOp<"placeholder", [NoSideEffect]> { + let summary = "Placeholder op"; + + let description = [{ + A node where data will be inserted into the network at runtime. Generally used for inputs to the network. + }]; + + let arguments = (ins + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +/* TOSA Spec Section 2.14 */ +/* Operator Class: Custom Operators */ + +/* Operator: custom */ + +def Tosa_CustomOp : Tosa_Op<"custom"> { + + let summary = "Custom operator wrapper for Tosa"; + + let description = [{ + Hardware implementing TOSA may choose to add additional custom operators that are not expressed in + the existing TOSA operations. These operators are not expected to be portable across TOSA + implementations. The input and output signatures must be expressed in the corresponding TOSA node. + }]; + + let arguments = (ins + StrAttr:$identifier, + Variadic:$inputs + ); + + let results = (outs + Variadic:$outputs + ); + +} + +/* TOSA Spec Section 2.15 */ +/* Operator Class: Control Flow Operators */ + +/* Operator: cond_if */ + +def Tosa_IfOp : Tosa_Op<"cond_if", [ + SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = "Conditional if operator"; + + let description = [{ + Evaluates a Boolean condition and then takes one of two distinct execution paths. This + implements the semantic If-then-else structure. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$cond, + Variadic:$inputs + ); + + let results = (outs + Variadic:$output + ); + + let regions = (region + SizedRegion<1>:$then_branch, + SizedRegion<1>:$else_branch + ); + +} + +/* Operator: while_loop */ + +def Tosa_WhileOp : Tosa_Op<"while_loop", [ + DeclareOpInterfaceMethods, + SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = [{ + output = input; While (Cond(output)) { output = Body(output) } + }]; + + let description = [{ + Generates and evaluates a Bool condition and either executes a loop body or exits to + another control point. This action is performed repeatedly after updating and re-evaluating + the Boolean condition every iteration. This implements the semantic foreach or while + iterative loop structure. + }]; + + let arguments = (ins + Variadic:$inputs + ); + + let results = (outs + Variadic:$output); + + let regions = (region + SizedRegion<1>:$cond, + SizedRegion<1>:$body + ); + +} + +def Tosa_YieldOp : Tosa_Op<"yield", [Terminator]> { + let summary = "yield operator"; + + let description = [{ + return operation within the conditional and body of + structured control flow. Operation takes variadic operands + but produces no results of its own. + }]; + + let arguments = (ins + Variadic:$inputs + ); + +} + +#endif // TOSA_OPS diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTraits.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaTraits.h new file mode 100644 index 000000000000000..b95d745330cbdd3 --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTraits.h @@ -0,0 +1,33 @@ +//===-- TosaTraits.h - TOSA dialect operation traits *- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the TOSA Dialect OpTraits in MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOSA_IR_TOSA_TRAITS_H +#define MLIR_TOSA_IR_TOSA_TRAITS_H + +#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace OpTrait { +namespace tosa { + +// TBD + +} +} // namespace OpTrait +} // namespace mlir + +#endif // MLIR_TOSA_IR_TOSA_TRAITS_H diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.h new file mode 100644 index 000000000000000..db179a96934e8d9 --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.h @@ -0,0 +1,31 @@ +//===-- TosaTypes.h - TOSA dialect type definitions *- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the TOSA Dialect Types in MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_IR_TOSA_TYPES_H +#define TENSORFLOW_COMPILER_MLIR_TOSA_IR_TOSA_TYPES_H + +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" + +namespace mlir { + +namespace tosa { + +// TOSA specific types go here + +} // namespace tosa + +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TOSA_IR_TOSA_TYPES_H diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td new file mode 100644 index 000000000000000..bb9341011f6cfab --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -0,0 +1,129 @@ +//===-- TosaTypesBase.td - TOSA type definitions *- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the type definitions for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#ifdef TOSA_TYPES_BASE +#else +#define TOSA_TYPES_BASE + +include "mlir/IR/OpBase.td" + +/////////////////////////// +// Tosa Type Definitions +/////////////////////////// + +// The base class of a quantized type. 
+// Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end] +// Where low and high ends are 0,255 when unsigned, -128,127 when signed, for the 8-bit case +class Tosa_QuantizedType params, bit signed> + : Type()">, + CPred<"$_self.cast()" # + ".getStorageTypeIntegralWidth() == " # !head(params)>]>, + "Q" # !if (signed, "int", "uint") # !head(params) # " type"> { + string name = n; + string asTraitArgsStr = + StrJoinInt.result # !if(signed, ", true", ", false"); +} + +/* Non-Quantized Signed Integer Types + Used to express accumulator results or compare results */ + +// Booleans are currently assumed to be expressed using int8 or +// built in I1 / I1Tensor type + +def Tosa_Int32 : SI<32>; +def Tosa_Int48 : SI<48>; +def Tosa_Int64 : SI<64>; + +// Any signed integer type +def Tosa_SignedInt : AnyTypeOf<[Tosa_Int32, + Tosa_Int48, + Tosa_Int64]>; + +def Tosa_Int : AnyTypeOf<[Tosa_SignedInt]>; + +def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32, + Tosa_Int64]>; + +// Any integer tensor types +def Tosa_IntTensor : TensorOf<[Tosa_SignedInt]>; + +// Any 32-bit or 64-bit integer tensor types +def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>; + +/* Quantized Integer Types + Datatype for network feature map or weight content */ + +def Tosa_Quint8 : Tosa_QuantizedType<"Uniform", [8], 0>; +def Tosa_Qint8 : Tosa_QuantizedType<"Uniform", [8], 1>; +def Tosa_Qint16 : Tosa_QuantizedType<"Uniform", [16], 1>; + +// Any quantized type +// aint8 : asymmetric per tensor, signed +// uint8: asymmetric per tensor, unsigned +// int4: symmetric per channel, signed +// int8 : symmetric per tensor/per channel, signed +// int16: symmetric per tensor, signed +def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"aint8", [8], 1>, + Tosa_QuantizedType<"uint8", [8], 0>, + Tosa_QuantizedType<"int4", [4, 0], 1>, + Tosa_QuantizedType<"int8", [8, 0], 1>, + Tosa_QuantizedType<"int16", [16, 0], 1>]>; + +/* Floating-point types */ + +def Tosa_Float : AnyTypeOf<[F32, + F16, + BF16]>; + +def Tosa_FpTensor : 
TensorOf<[Tosa_Float]>; + +// Multi-category type constraints + +def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float], + "number">; + +// Tensors exclusively of numerical types +def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>; + +// Tensors exclusively of quantized types +def Tosa_QTypeTensor : TensorOf<[Tosa_QuantizedInt]>; + +class Tosa_TensorOrNone possibleTypes, string description = ""> : + AnyTypeOf<[TensorOf, NoneType], description>; + +// Any tensor element type allowed in Tosa ops +def Tosa_ElementType : Type, + "tosa.dtype">; + +// Tensor or None type. +class Tosa_TensorOfOrNone allowedTypes, string description = ""> : + AnyTypeOf<[TensorOf, NoneType], description>; + +// Any Tosa tensor type including string and bool +def Tosa_AnyTensor : TensorOf<[Tosa_ElementType]>; + +// String attribute constraints + +def Tosa_ResizeTypeAttr : StringBasedAttr< + CPred<"$_self.cast().getValue() == \"TRANSPOSE\" || " # + "$_self.cast().getValue() == \"BILINEAR\" || " # + "$_self.cast().getValue() == \"NEAREST_NEIGHBOR\"">, + "Supported resize/upsampling strategies">; + +def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">; + +// Tensor to buffer +def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>; +def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>; +def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>; + +#endif // TOSA_TYPES_BASE diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt new file mode 100644 index 000000000000000..d9b5375188b8bbc --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt) +add_public_tablegen_target(MLIRTosaPassIncGen) +add_dependencies(mlir-headers MLIRTosaPassIncGen) + +add_mlir_doc(Passes -gen-pass-doc TosaPasses ./) diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h 
b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h new file mode 100644 index 000000000000000..fa572f4d3c90a8a --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -0,0 +1,36 @@ +//===-- Passes.h - TOSA optimization pass declarations *- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the optimization passes for the TOSA Dialect in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H +#define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +class FuncOp; +class ModuleOp; +class Pass; +template +class OperationPass; + +namespace tosa { + +std::unique_ptr> CreateTosaMakeBroadcastablePass(); + +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" + +} // namespace tosa +} // namespace mlir + +#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td new file mode 100644 index 000000000000000..358c8b08b624d1a --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -0,0 +1,18 @@ +//===-- Passes.td - TOSA optimization pass declarations *- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the optimization passes for the TOSA Dialect in MLIR. 
+//
+//===----------------------------------------------------------------------===//
+
+include "mlir/Pass/PassBase.td"
+
+def TosaBinaryInputReshapePass : Pass<"tosa-make-broadcastable", "FuncOp"> {
+  let summary = "TOSA rank Reshape to enable Broadcasting";
+  let constructor = "CreateTosaMakeBroadcastablePass()";
+}
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
new file mode 100644
index 000000000000000..bd4eaa67c12fd4a
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -0,0 +1,84 @@
+//===-- QuantUtils.h - TOSA numerical support declarations *- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Function declarations for TOSA numerical support functions and quantization
+// attribute builders
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TOSA_UTILS_QUANT_UTILS_H
+#define MLIR_DIALECT_TOSA_UTILS_QUANT_UTILS_H
+
+// Utils to support quantization handling in Tosa
+
+#include
+#include
+#include
+#include
+#include
+
+#include "mlir/Dialect/Quant/FakeQuantSupport.h"
+#include "mlir/Dialect/Quant/UniformSupport.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringSwitch.h"
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+
+namespace mlir {
+namespace tosa {
+/// Encode `scale` as a multiplier and shift; scale_width must be 16 or 32.
+void computeMultiplierAndShift(double scale, int32_t &multiplier,
+                               int32_t &shift, int32_t scale_width);
+void computeMultiplierAndShiftGtOne(double scale, int32_t &multiplier,
+                                    int32_t &shift, int32_t scale_width);
+void computeMultiplierAndShiftLtOneExp(double scale, int32_t &multiplier,
+                                       int32_t &shift, int32_t scale_width);
+/// Zero points of input and weight; null unless both are quantized tensors.
+ConvOpQuantizationAttr buildConvOpQuantizationAttr(mlir::OpBuilder &builder,
+                                                   Value input, Value weight);
+/// Zero points of a and b; null unless both are quantized tensors.
+MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(mlir::OpBuilder &builder,
+                                                       Value a, Value b);
+/// Zero points of input and output; null unless both are quantized tensors.
+UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(mlir::OpBuilder &builder,
+                                                     Value input,
+                                                     Type output_raw_type);
+/// Zero point of the input; null unless it is a quantized tensor.
+PadOpQuantizationAttr buildPadOpQuantizationAttr(mlir::OpBuilder &builder,
+                                                 Value input);
+/// Builds a quantized type from min/max attributes; null Type on failure.
+Type buildQTypeFromMinMax(OpBuilder builder, Type input_dtype,
+                          Attribute minattr, Attribute maxattr,
+                          IntegerAttr quant_bits, int filter_quantdim,
+                          bool issigned, BoolAttr narrow_range);
+/// Same as buildQTypeFromMinMax, with the result wrapped in a TypeAttr.
+TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type input_dtype,
+                                  Attribute minattr, Attribute maxattr,
+                                  IntegerAttr quant_bits, int filter_quantdim,
+                                  bool issigned, BoolAttr narrow_range);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TOSA_UTILS_QUANT_UTILS_H
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 24ffb192338a5d3..bc44049e2ef6041 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -14,6 +14,7 @@ add_subdirectory(SDBM)
 add_subdirectory(Shape)
 add_subdirectory(SPIRV)
 add_subdirectory(StandardOps)
+add_subdirectory(Tosa)
 add_subdirectory(Vector)
 
 set(LLVM_OPTIONAL_SOURCES
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
new file mode 100644
index 000000000000000..237c2b238eaf75a
--- /dev/null
+++
b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -0,0 +1,24 @@
+add_mlir_dialect_library(MLIRTosa
+  IR/TosaOps.cpp
+  Utils/QuantUtils.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+
+  DEPENDS
+  MLIRStandardOpsIncGen
+  MLIRTosaOpsIncGen
+  MLIRTosaStructsIncGen
+  MLIRTosaInterfaceIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRStandard
+  MLIRCallInterfaces
+  MLIRControlFlowInterfaces
+  MLIRSideEffectInterfaces
+  MLIRViewLikeInterface
+  )
+
+# add_subdirectory(Serialization)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
new file mode 100644
index 000000000000000..62d848812fec7e9
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -0,0 +1,134 @@
+//===- TosaOps.cpp - MLIR Dialect for TOSA ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the TOSA Specification:
+// https://developer.mlplatform.org/w/tosa/
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Parser.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+#include "mlir/Dialect/Tosa/IR/TosaStructs.cc.inc"
+
+// Tosa dialect interfaces
+
+#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cc.inc"
+
+// Dialect Function Inliner Interface
+
+struct TosaInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+  // `final` on both hooks: a signature that drifts from the base is an error.
+  bool isLegalToInline(Operation *op, Region *region,
+                       BlockAndValueMapping &map) const final {
+    return true;
+  }
+
+  /* This ensures the callable region of the operator can be inlined.
+     Without this, the regions will NOT inline. */
+  bool isLegalToInline(Region *dest, Region *src,
+                       BlockAndValueMapping &map) const final {
+    return (isa(dest->getParentOp()) ||
+            isa(dest->getParentOp()));
+  }
+};
+
+// TOSA control flow support
+
+Region &tosa::WhileOp::getLoopBody() { return body(); }
+
+bool tosa::WhileOp::isDefinedOutsideOfLoop(Value value) {
+  // Always false for now (WIP: awaiting MLIR enhancements with exposed API)
+  return false;
+}
+
+LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef ops) {
+  if (ops.empty())
+    return success();
+
+  Operation *tosa_while_op = this->getOperation();
+  for (auto op : ops)
+    op->moveBefore(tosa_while_op);
+
+  return success();
+}
+
+struct TosaDialectFoldInterface : public DialectFoldInterface {
+  using DialectFoldInterface::DialectFoldInterface;
+
+  bool shouldMaterializeInto(Region *region) const final {
+    return isa(region->getParentOp());
+  }
+};
+
+// Tosa Dialect
+
+TosaDialect::TosaDialect(MLIRContext *context)
+    : Dialect(getDialectNamespace(), context, TypeID::get()) {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Tosa/IR/TosaOps.cc.inc"
+      >();
+  addInterfaces();
+
+  allowUnknownOperations();
+}
+
+//===----------------------------------------------------------------------===//
+// ConstOp
+//===----------------------------------------------------------------------===//
+// Builder: one result of `type`, with `value` stored as the "value" attribute.
+void ConstOp::build(OpBuilder &builder, OperationState &result, Type type,
+                    Attribute value) {
+  result.addTypes(type);
+  result.addAttribute("value", value);
+  return;
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Tosa/IR/TosaOps.cc.inc"
+// Creates an op from (loc, type, value) for accepted attributes, else nullptr.
+Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
+                                            Type type, Location loc) {
+  if (value.isa() ||
+      (value.isa() && value.getType() != type))
+    return builder.create(loc, type, value.cast());
+  return nullptr;
+}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..04acbf6425b75fc
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRTosaTransforms
+  TosaMakeBroadcastable.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
+
+  DEPENDS
+  MLIRTosaPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRPass
+  MLIRTosa
+  )
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
new file mode 100644
index 000000000000000..0a5b6a9a736dd39
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -0,0 +1,222 @@
+//===- TosaMakeBroadcastable.cpp ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Insert reshape to binary op's input if needed to match rank
+//
+//===----------------------------------------------------------------------===//
+
+#include
+#include
+#include
+#include
+#include
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define PASS_NAME "tosa-make-broadcastable"
+#define DEBUG_TYPE PASS_NAME
+
+namespace mlir {
+
+namespace tosa {
+
+namespace {
+// Function pass inserting reshapes so both binary-op operands share a rank.
+class TosaMakeBroadcastable
+    : public PassWrapper {
+public:
+  explicit TosaMakeBroadcastable() {}
+  void runOnFunction() override;
+};
+
+#define REPLACE_OP_LOGICAL(tosa_op, LHS_VALUE, RHS_VALUE)
+
+#define REPLACE_OP(tosa_op, LHS_VALUE, RHS_VALUE) \
+  { \
+    rewriter.replaceOpWithNewOp(op, output_type, LHS_VALUE, \
+                                RHS_VALUE); \
+  }
+
+/* the legalization macro that reshapes lower rank input to output's shape
+ * if lower=[a], target=[a, b, c], [a] reshaped into [a, 1, 1]
+ * if lower=[b], target=[a, b, c], [b] should be reshaped into [1, b, 1], but
+ * this is NOT YET supported (TODO)
+ * if lower=[c], target=[a, b, c], [c] reshaped into [1, 1, c]
+ * if lower=[a, c], target=[a, b, c], [a, c] reshaped into [a, 1, c]
+ * if lower=[a, b], target=[a, b, c], [a, b] reshaped into [a, b, 1]
+ * if lower=[b, c], target=[a, b, c], [b, c] reshaped into [1, b, c]
+ * if lower=[a], target=[a, a], [a] reshaped into [1, a] instead of [a, 1]
+ * if lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]
+ * if lower=[], target=[a, b, c], [] reshaped into [1, 1, 1] */
+
+#define DECL_TOSACONVERT_OP(tosa_op) \
+  struct ConvertTosa##tosa_op##Op : public RewritePattern { \
+    explicit ConvertTosa##tosa_op##Op(MLIRContext *context) \
+        : RewritePattern(tosa::tosa_op##Op::getOperationName(), 1, context) {} \
+    LogicalResult matchAndRewrite(Operation *op, \
+                                  PatternRewriter &rewriter) const { \
+      auto tosa_binary_op = cast(op); \
+ \
+      auto lhs = tosa_binary_op.lhs(); \
+      auto rhs = tosa_binary_op.rhs(); \
+ \
+      int64_t lhs_rank = lhs.getType().dyn_cast().getRank(); \
+      int64_t rhs_rank = rhs.getType().dyn_cast().getRank(); \
+ \
+      auto output_type = \
+          tosa_binary_op.getResult().getType().dyn_cast(); \
+ \
+      int64_t higher_rank, lower_rank; \
+      Value higher_tensor_value, lower_tensor_value; \
+      /* return if rank already match */ \
+      if (lhs_rank == rhs_rank) { \
+        return failure(); \
+      } else if (lhs_rank > rhs_rank) { \
+        higher_rank = lhs_rank; \
+        lower_rank = rhs_rank; \
+        higher_tensor_value = lhs; \
+        lower_tensor_value = rhs; \
+      } else { \
+        higher_rank = rhs_rank; \
+        lower_rank = lhs_rank; \
+        higher_tensor_value = rhs; \
+        lower_tensor_value = lhs; \
+      } \
+ \
+      ArrayRef higher_rank_shape = output_type.getShape(); \
+      ArrayRef lower_rank_shape = lower_tensor_value.getType() \
+                                      .dyn_cast() \
+                                      .getShape(); \
+ \
+      SmallVector reshape_output_shape; \
+      reshape_output_shape.assign(higher_rank, 1); \
+ \
+      int64_t higher_left_index = 0; \
+      int64_t higher_right_index = higher_rank; \
+      int64_t lower_left_index = 0; \
+      int64_t lower_right_index = lower_rank; \
+      int64_t higher_rank_dim, lower_rank_dim; \
+ \
+      if (lower_right_index != 0 && higher_right_index != 0) { \
+        while (true) { \
+          higher_rank_dim = higher_rank_shape[higher_right_index - 1]; \
+          lower_rank_dim = lower_rank_shape[lower_right_index - 1]; \
+          if (higher_rank_dim == lower_rank_dim) { \
+            reshape_output_shape[higher_right_index - 1] = higher_rank_dim; \
+ \
+            if (higher_right_index > 0) { \
+              higher_right_index--; \
+            } \
+ \
+            if (lower_right_index > 0) { \
+              lower_right_index--; \
+            } \
+ \
+            if (higher_right_index == 0 || lower_right_index == 0) { \
+              break; \
+            } \
+          } else { \
+            break; \
+          } \
+        } \
+        if (lower_right_index != 0 && higher_right_index != 0) { \
+          while (true) { \
+            higher_rank_dim = higher_rank_shape[higher_left_index]; \
+            lower_rank_dim = lower_rank_shape[lower_left_index]; \
+            if (higher_rank_dim == lower_rank_dim) { \
+              reshape_output_shape[higher_left_index] = higher_rank_dim; \
+ \
+              if (higher_left_index < higher_right_index) { \
+                higher_left_index++; \
+              } \
+ \
+              if (lower_left_index < lower_right_index) { \
+                lower_left_index++; \
+              } \
+ \
+              if (higher_left_index == higher_right_index || \
+                  lower_left_index == lower_right_index) { \
+                break; \
+              } \
+            } else { \
+              break; \
+            } \
+          } \
+        } \
+      } \
+ \
+      auto reshape_input_type = \
+          lower_tensor_value.getType().dyn_cast(); \
+      auto reshape_output_type = \
+          RankedTensorType::get(ArrayRef(reshape_output_shape), \
+                                reshape_input_type.getElementType()); \
+ \
+      auto reshape_lower = rewriter.create( \
+          op->getLoc(), reshape_output_type, lower_tensor_value, \
+          rewriter.getI64ArrayAttr(reshape_output_shape)); \
+ \
+      if (lhs_rank > rhs_rank) { \
+        REPLACE_OP(tosa_op, higher_tensor_value, reshape_lower.getResult()); \
+      } else { \
+        REPLACE_OP(tosa_op, reshape_lower.getResult(), higher_tensor_value); \
+      } \
+ \
+      return success(); \
+    } \
+  };
+DECL_TOSACONVERT_OP(Add)
+DECL_TOSACONVERT_OP(Sub)
+DECL_TOSACONVERT_OP(Mul)
+DECL_TOSACONVERT_OP(LogicalLeftShift)
+DECL_TOSACONVERT_OP(ArithmeticRightShift)
+DECL_TOSACONVERT_OP(LogicalRightShift)
+#undef DECL_TOSACONVERT_OP
+#undef REPLACE_OP_LOGICAL
+#undef REPLACE_OP
+
+void TosaMakeBroadcastable::runOnFunction() {
+  OwningRewritePatternList patterns;
+  auto *ctx = &getContext();
+  auto func = getFunction();
+
+  // Add the generated patterns to the list.
+  patterns.insert(ctx);
+  patterns.insert(ctx);
+  patterns.insert(ctx);
+  patterns.insert(ctx);
+  patterns.insert(ctx);
+  patterns.insert(ctx);
+  applyPatternsAndFoldGreedily(func, std::move(patterns));
+}
+
+} // anonymous namespace
+
+std::unique_ptr> CreateTosaMakeBroadcastablePass() {
+  return std::make_unique();
+}
+
+static PassRegistration
+    pass(PASS_NAME,
+         "Perform broadcast on elementwise TosaOps to ensure same rank");
+
+} // namespace tosa
+
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
new file mode 100644
index 000000000000000..6fe4638192905a6
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
@@ -0,0 +1,364 @@
+//===- QuantUtils.cpp ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TOSA numerical support functions and quantization attribute builders
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+
+namespace mlir {
+namespace tosa {
+
+std::string tosa_quantized_type = "quint8";
+
+namespace {
+
+void computeMultiplierAndShiftTosaScale16(double scale, int32_t &multiplier,
+                                          int32_t &shift) {
+
+  /* std::frexp splits scale into a mantissa in [0.5, 1.0) (negated for a
+     negative scale) and a power-of-two exponent, stored in shift, such that
+     scale = mantissa * 2^shift */
+  const double mantissa = std::frexp(scale, &shift);
+  auto shifted_m = std::round(mantissa * (int64_t(1) << 15));
+
+  assert(shifted_m <= (int64_t(1) << 15)); // can't be greater than 1.0
+  if (shifted_m == (int64_t(1) << 15)) {
+    shifted_m /= 2;
+    shift++;
+  }
+  // TOSA expects a positive right shift, with the (1 << 15) mantissa scaling
+  // folded into the shift amount
+  shift = (-shift) + 15;
+
+  assert(shifted_m <= std::numeric_limits::max());
+
+  multiplier = static_cast(shifted_m);
+}
+
+void computeMultiplierAndShiftTosaScale32(double scale, int32_t &multiplier,
+                                          int32_t &shift) {
+
+  /* std::frexp splits scale into a mantissa in [0.5, 1.0) (negated for a
+     negative scale) and a power-of-two exponent, stored in shift, such that
+     scale = mantissa * 2^shift */
+  const double mantissa = std::frexp(scale, &shift);
+  auto shifted_m = std::round(mantissa * (int64_t(1) << 31));
+
+  assert(shifted_m <= (int64_t(1) << 31)); // can't be greater than 1.0
+  if (shifted_m == (int64_t(1) << 31)) {
+    shifted_m /= 2;
+    shift++;
+  }
+  // TOSA expects a positive right shift, with the (1 << 31) mantissa scaling
+  // folded into the shift amount
+  shift = (-shift) + 31;
+
+  assert(shifted_m <= std::numeric_limits::max());
+
+  multiplier = static_cast(shifted_m);
+}
+
+} // namespace
+
+/* Generates a quantized multiplier / shift from double */
+void computeMultiplierAndShift(double scale, int32_t &multiplier,
+                               int32_t &shift, int32_t scale_width) {
+
+  switch (scale_width) {
+  case 16:
+    computeMultiplierAndShiftTosaScale16(scale, multiplier, shift);
+    return;
+  case 32:
+    computeMultiplierAndShiftTosaScale32(scale, multiplier, shift);
+    return;
+  default:
+    assert(0 && "Unsupported Tosa quantized_scale regime specified!");
+  }
+}
+/* As computeMultiplierAndShift, for scales strictly greater than one */
+void computeMultiplierAndShiftGtOne(double scale, int32_t &multiplier,
+                                    int32_t &shift, int32_t scale_width) {
+  assert(scale > double(1.0));
+  computeMultiplierAndShift(scale, multiplier, shift, scale_width);
+  assert(shift >= 0);
+}
+/* As computeMultiplierAndShift, for scales strictly between zero and one */
+void computeMultiplierAndShiftLtOneExp(double scale, int32_t &multiplier,
+                                       int32_t &shift, int32_t scale_width) {
+  assert(scale < double(1.0));
+  assert(scale > double(0.0));
+  computeMultiplierAndShift(scale, multiplier, shift, scale_width);
+  assert(shift > 0); // TOSA right shift is (scale_width - 1) - exponent
+}
+
+#define GET_UQTYPE(input) \
+  ((input) \
+       .getType() \
+       .dyn_cast() \
+       .getElementType() \
+       .dyn_cast())
+
+/* method to build ConvOpQuantizationAttr, called from
+ * ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder: input_zp: input zeropoint
+ * weight_zp: weight zeropoint
+ */
+ConvOpQuantizationAttr buildConvOpQuantizationAttr(mlir::OpBuilder &builder,
+                                                   Value input, Value weight) {
+
+  auto input_type = input.getType().dyn_cast();
+  auto weight_type = weight.getType().dyn_cast();
+
+  if (!input_type || !weight_type)
+    return nullptr;
+
+  bool input_is_qtype =
+      input_type.getElementType().isa();
+  bool weight_is_qtype =
+      weight_type.getElementType().isa();
+
+  // Either all quantized or all not quantized
+  assert(!(input_is_qtype ^ weight_is_qtype));
+
+  if (input_is_qtype) {
+
+    auto input_qtype = input_type.getElementType()
+                           .dyn_cast();
+    assert(input_qtype); // We don't support any other kind of input
+                         // quantization here
+
+    int64_t input_zp = input_qtype.getZeroPoint();
+    int64_t weight_zp = 0;
+
+    // per tensor quantization
+    if (auto weight_qtype =
+            weight_type.getElementType()
+                .dyn_cast()) {
+      weight_zp = weight_qtype.getZeroPoint();
+      // per channel quantization
+    } else if (auto weight_qtype =
+                   weight_type.getElementType()
+                       .dyn_cast()) {
+      weight_zp = weight_qtype.getZeroPoints().front();
+    }
+
+    auto quantattr = mlir::tosa::ConvOpQuantizationAttr::get(
+        builder.getI32IntegerAttr(input_zp),
+        builder.getI32IntegerAttr(weight_zp), builder.getContext());
+
+    return quantattr;
+  }
+
+  return nullptr;
+}
+
+/* Builds the MatMulOpQuantizationAttr (zero points of a and b), called from
+ * MatMulOpQuantInfoBuilder. Returns null when a or b does not have the
+ * expected tensor type, or when the operands are not quantized. */
+MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(mlir::OpBuilder &builder,
+                                                       Value a, Value b) {
+
+  auto a_type = a.getType().dyn_cast();
+  auto b_type = b.getType().dyn_cast();
+
+  if (!a_type || !b_type)
+    return nullptr;
+
+  bool a_is_qtype =
+      a_type.getElementType().isa();
+  bool b_is_qtype =
+      b_type.getElementType().isa();
+
+  // Either all quantized or all not quantized
+  assert(!(a_is_qtype ^ b_is_qtype));
+
+  if (a_is_qtype) {
+
+    auto a_qtype = GET_UQTYPE(a);
+    auto b_qtype = GET_UQTYPE(b);
+
+    assert(a_qtype && b_qtype);
+
+    int64_t a_zp = a_qtype.getZeroPoint();
+    int64_t b_zp = b_qtype.getZeroPoint();
+
+    auto quantattr = mlir::tosa::MatMulOpQuantizationAttr::get(
+        builder.getI32IntegerAttr(a_zp), builder.getI32IntegerAttr(b_zp),
+        builder.getContext());
+
+    return quantattr;
+  }
+
+  return nullptr;
+}
+
+/* Builds the UnaryOpQuantizationAttr (zero points of input and output),
+ * called from UnaryOpQuantInfoBuilder. Returns null when input or
+ * output_raw_type does not have the expected tensor type, or when they are
+ * not quantized. */
+UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(mlir::OpBuilder &builder,
+                                                     Value input,
+                                                     Type output_raw_type) {
+
+  auto input_type = input.getType().dyn_cast();
+  auto output_type = output_raw_type.dyn_cast();
+
+  if (!input_type || !output_type)
+    return nullptr;
+
+  bool input_is_qtype =
+      input_type.getElementType().isa();
+  bool output_is_qtype =
+      output_type.getElementType().isa();
+
+  // Either all quantized or all not quantized
+  assert(!(input_is_qtype ^ output_is_qtype));
+
+  if (input_is_qtype) {
+
+    auto input_qtype = input_type.getElementType()
+                           .dyn_cast();
+    auto output_qtype = output_type.getElementType()
+                            .dyn_cast();
+    assert(input_qtype && output_qtype);
+
+    int64_t input_zp = input_qtype.getZeroPoint();
+    int64_t output_zp = output_qtype.getZeroPoint();
+
+    auto quantattr = mlir::tosa::UnaryOpQuantizationAttr::get(
+        builder.getI32IntegerAttr(input_zp),
+        builder.getI32IntegerAttr(output_zp), builder.getContext());
+
+    return quantattr;
+  }
+
+  return nullptr;
+}
+
+/* Builds the PadOpQuantizationAttr (zero point of input), called from
+ * PadOpQuantInfoBuilder. Returns null when input does not have the expected
+ * tensor type or is not quantized. */
+PadOpQuantizationAttr buildPadOpQuantizationAttr(mlir::OpBuilder &builder,
+                                                 Value input) {
+
+  auto input_type = input.getType().dyn_cast();
+
+  if (!input_type)
+    return nullptr;
+
+  bool input_is_qtype =
+      input_type.getElementType().isa();
+
+  if (input_is_qtype) {
+
+    auto input_qtype = input_type.getElementType()
+                           .dyn_cast();
+    assert(input_qtype);
+
+    int64_t input_zp = input_qtype.getZeroPoint();
+
+    auto quantattr = mlir::tosa::PadOpQuantizationAttr::get(
+        builder.getI32IntegerAttr(input_zp), builder.getContext());
+
+    return quantattr;
+  }
+
+  return nullptr;
+}
+/* Quantized counterpart of input_dtype from min/max; null Type on failure */
+Type buildQTypeFromMinMax(OpBuilder builder, Type input_dtype,
+                          Attribute minattr, Attribute maxattr,
+                          IntegerAttr quant_bits, int filter_quantdim,
+                          bool issigned, BoolAttr narrow_range) {
+
+  quant::QuantizedType rettype;
+
+  auto convfunc =
+      quant::ExpressedToQuantizedConverter::forInputType(input_dtype);
+
+  auto minelems = minattr.dyn_cast();
+  auto maxelems = maxattr.dyn_cast();
+
+  SmallVector min, max;
+
+  if (minelems && maxelems) { // both are per-axis quantized elementsattr
+
+    // must have the same number of elements
+    if (minelems.getNumElements() != maxelems.getNumElements())
+      return {};
+
+    min.reserve(minelems.getNumElements());
+    max.reserve(maxelems.getNumElements());
+    for (auto i : minelems) {
+      min.push_back(FloatAttr::getValueAsDouble(i));
+    }
+    for (auto i : maxelems) {
+      max.push_back(FloatAttr::getValueAsDouble(i));
+    }
+  } else { // Single FP values; a mixed elements/scalar pair fails below
+
+    auto minval = minattr.dyn_cast();
+    if (minval)
+      min.push_back(minval.getValueAsDouble());
+    else
+      return {};
+    auto maxval = maxattr.dyn_cast();
+    if (maxval)
+      max.push_back(maxval.getValueAsDouble());
+    else
+      return {};
+  }
+
+  if (min.size() == max.size()) {
+
+    if (min.size() == 1) { // Per-tensor quantization with one min/max pair
+
+      rettype = quant::fakeQuantAttrsToType(
+          builder.getUnknownLoc(), quant_bits.getInt(), min[0], max[0],
+          narrow_range.getValue(), convfunc.expressedType, issigned);
+
+    } else if (min.size() > 1) { // per-axis quant on filter_quantdim
+
+      auto shape = input_dtype.dyn_cast();
+      if (!shape)
+        return {};
+      if ((filter_quantdim) >= 0 && (shape.getRank() > filter_quantdim)) {
+        // Pass every per-axis min/max, not just the first pair
+        rettype = quant::fakeQuantAttrsToType(
+            builder.getUnknownLoc(), quant_bits.getInt(), filter_quantdim,
+            min, max, narrow_range.getValue(), convfunc.expressedType,
+            issigned);
+      }
+
+    } else {
+      return {};
+    }
+  } else {
+    return {};
+  }
+
+  if (!rettype)
+    return {};
+
+  return convfunc.convert(rettype);
+}
+/* TypeAttr wrapper around buildQTypeFromMinMax; null TypeAttr on failure */
+TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type input_dtype,
+                                  Attribute minattr, Attribute maxattr,
+                                  IntegerAttr quant_bits, int filter_quantdim,
+                                  bool issigned, BoolAttr narrow_range) {
+  Type qtype = buildQTypeFromMinMax(builder, input_dtype, minattr, maxattr,
+                                    quant_bits, filter_quantdim, issigned,
+                                    narrow_range);
+  return qtype ? TypeAttr::get(qtype) : TypeAttr();
+}
+
+} // namespace tosa
+} // namespace mlir