diff --git a/CMakeLists.txt b/CMakeLists.txt
index e7e02b2d2973..6ce0cdc129db 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -77,6 +77,8 @@ tvm_option(USE_COREML "Build with coreml support" OFF)
 tvm_option(USE_TARGET_ONNX "Build with ONNX Codegen support" OFF)
 tvm_option(USE_ARM_COMPUTE_LIB "Build with Arm Compute Library" OFF)
 tvm_option(USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME "Build with Arm Compute Library graph runtime" OFF)
+tvm_option(USE_TENSORRT_CODEGEN "Build with TensorRT Codegen support" OFF)
+tvm_option(USE_TENSORRT_RUNTIME "Build with TensorRT runtime" OFF)

 # include directories
 include_directories(${CMAKE_INCLUDE_PATH})
@@ -347,6 +349,7 @@ include(cmake/modules/contrib/TF_TVMDSOOP.cmake)
 include(cmake/modules/contrib/CoreML.cmake)
 include(cmake/modules/contrib/ONNX.cmake)
 include(cmake/modules/contrib/ArmComputeLib.cmake)
+include(cmake/modules/contrib/TensorRT.cmake)
 include(cmake/modules/Git.cmake)
 include(cmake/modules/LibInfo.cmake)
diff --git a/cmake/config.cmake b/cmake/config.cmake
index 9754385fa014..db125124ce70 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -218,6 +218,16 @@ set(USE_ETHOSN OFF)
 # otherwise use ETHOSN_HW (OFF) to use the software test infrastructure
 set(USE_ETHOSN_HW OFF)

+# Whether to build with TensorRT codegen or runtime
+# Examples are available here: docs/deploy/tensorrt.rst.
+#
+# USE_TENSORRT_CODEGEN - Support for compiling a relay graph where supported operators are
+#                        offloaded to TensorRT. OFF/ON
+# USE_TENSORRT_RUNTIME - Support for running TensorRT compiled modules, requires presence of
+#                        TensorRT library. OFF/ON/"path/to/TensorRT"
+set(USE_TENSORRT_CODEGEN OFF)
+set(USE_TENSORRT_RUNTIME OFF)
+
 # Build ANTLR parser for Relay text format
 # Possible values:
 # - ON: enable ANTLR by searching default locations (cmake find_program for antlr4 and /usr/local for jar)
diff --git a/cmake/modules/contrib/TensorRT.cmake b/cmake/modules/contrib/TensorRT.cmake
new file mode 100644
index 000000000000..1536d23205a7
--- /dev/null
+++ b/cmake/modules/contrib/TensorRT.cmake
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# TensorRT Codegen only. This can be enabled independently of USE_TENSORRT_RUNTIME to enable
+# compilation of TensorRT modules without requiring TensorRT to be installed. The compiled modules
+# will only be able to be executed using a TVM built with USE_TENSORRT_RUNTIME=ON.
+if(USE_TENSORRT_CODEGEN) + message(STATUS "Build with TensorRT codegen") + file(GLOB COMPILER_TENSORRT_SRCS src/relay/backend/contrib/tensorrt/*.cc) + set_source_files_properties(${COMPILER_TENSORRT_SRCS} PROPERTIES COMPILE_FLAGS "-Wno-deprecated-declarations") + file(GLOB RUNTIME_TENSORRT_SRCS src/runtime/contrib/tensorrt/tensorrt_runtime.cc) + set_source_files_properties(${RUNTIME_TENSORRT_SRCS} PROPERTIES COMPILE_FLAGS "-Wno-deprecated-declarations") + list(APPEND COMPILER_SRCS ${COMPILER_TENSORRT_SRCS}) + list(APPEND COMPILER_SRCS ${RUNTIME_TENSORRT_SRCS}) +endif() + +# TensorRT Runtime +if(USE_TENSORRT_RUNTIME) + if(IS_DIRECTORY ${USE_TENSORRT_RUNTIME}) + set(TENSORRT_ROOT_DIR ${USE_TENSORRT_RUNTIME}) + message(STATUS "Custom TensorRT path: " ${TENSORRT_ROOT_DIR}) + endif() + find_path(TENSORRT_INCLUDE_DIR NvInfer.h HINTS ${TENSORRT_ROOT_DIR} PATH_SUFFIXES include) + find_library(TENSORRT_LIB_DIR nvinfer HINTS ${TENSORRT_ROOT_DIR} PATH_SUFFIXES lib) + find_package_handle_standard_args(TENSORRT DEFAULT_MSG TENSORRT_INCLUDE_DIR TENSORRT_LIB_DIR) + if(NOT TENSORRT_FOUND) + message(ERROR "Could not find TensorRT.") + endif() + message(STATUS "TENSORRT_LIB_DIR: " ${TENSORRT_LIB_DIR}) + include_directories(${TENSORRT_INCLUDE_DIR}) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${TENSORRT_LIB_DIR}) + + # TRT runtime sources + file(GLOB RUNTIME_TENSORRT_SRCS src/runtime/contrib/tensorrt/*.cc) + set_source_files_properties(${RUNTIME_TENSORRT_SRCS} PROPERTIES COMPILE_FLAGS "-Wno-deprecated-declarations") + list(APPEND RUNTIME_SRCS ${RUNTIME_TENSORRT_SRCS}) + + # Set defines + add_definitions(-DTVM_GRAPH_RUNTIME_TENSORRT) +endif() diff --git a/docs/deploy/index.rst b/docs/deploy/index.rst index b38a7f561ab3..68843ba18248 100644 --- a/docs/deploy/index.rst +++ b/docs/deploy/index.rst @@ -69,3 +69,4 @@ target device without relying on RPC. see the following resources on how to do s integrate hls arm_compute_lib + tensorrt diff --git a/docs/deploy/tensorrt.rst b/docs/deploy/tensorrt.rst new file mode 100644 index 000000000000..27f11e9b5377 --- /dev/null +++ b/docs/deploy/tensorrt.rst @@ -0,0 +1,297 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Relay TensorRT Integration +========================== +**Author**: `Trevor Morris `_ + +Introduction +------------ + +NVIDIA TensorRT is a library for optimized deep learning inference. This integration will offload as +many operators as possible from Relay to TensorRT, providing a performance boost on NVIDIA GPUs +without the need to tune schedules. + +This guide will demonstrate how to install TensorRT and build TVM with TensorRT BYOC and runtime +enabled. It will also provide example code to compile and run a ResNet-18 model using TensorRT and +how to configure the compilation and runtime settings. 
+Finally, we document the supported operators and how to extend the integration to support other
+operators.
+
+Installing TensorRT
+-------------------
+
+In order to download TensorRT, you will need to create an NVIDIA Developer program account. Please
+see NVIDIA's documentation for more info:
+https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html. If you have a Jetson device
+such as a TX1, TX2, Xavier, or Nano, TensorRT will already be installed on the device via the
+JetPack SDK.
+
+There are two methods to install TensorRT:
+
+* System install via deb or rpm package.
+* Tar file installation.
+
+With the tar file installation method, you must provide the path of the extracted tar archive to
+USE_TENSORRT_RUNTIME=/path/to/TensorRT. With the system install method,
+USE_TENSORRT_RUNTIME=ON will automatically locate your installation.
+
+Building TVM with TensorRT support
+----------------------------------
+
+There are two separate build flags for TensorRT integration in TVM. These flags also enable
+cross-compilation: USE_TENSORRT_CODEGEN=ON will allow you to build a module with TensorRT support
+on a host machine, while USE_TENSORRT_RUNTIME=ON will enable the TVM runtime on an edge device to
+execute the TensorRT module. You should enable both if you want to compile and also execute models
+with the same TVM build.
+
+* USE_TENSORRT_CODEGEN=ON/OFF - This flag will enable compiling a TensorRT module, which does not
+  require any TensorRT library.
+* USE_TENSORRT_RUNTIME=ON/OFF/path-to-TensorRT - This flag will enable the TensorRT runtime module.
+  This will build TVM against the installed TensorRT library.
+
+Example setting in config.cmake file:
+
+.. code:: cmake
+
+    set(USE_TENSORRT_CODEGEN ON)
+    set(USE_TENSORRT_RUNTIME /home/ubuntu/TensorRT-7.0.0.11)
+
+
+Build and Deploy ResNet-18 with TensorRT
+----------------------------------------
+
+Create a Relay graph from an MXNet ResNet-18 model.
+
+.. code:: python
+
+    import tvm
+    from tvm import relay
+    import mxnet
+    from mxnet.gluon.model_zoo.vision import get_model
+
+    dtype = "float32"
+    input_shape = (1, 3, 224, 224)
+    block = get_model('resnet18_v1', pretrained=True)
+    mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+
+
+Annotate and partition the graph for TensorRT. All ops which are supported by the TensorRT
+integration will be marked and offloaded to TensorRT. The rest of the ops will go through the
+regular TVM CUDA compilation and code generation.
+
+.. code:: python
+
+    from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt
+    mod, config = partition_for_tensorrt(mod, params)
+
+
+Build the Relay graph, using the new module and config returned by partition_for_tensorrt. The
+target must always be a cuda target. ``partition_for_tensorrt`` will automatically fill out the
+required values in the config, so there is no need to modify it - just pass it along to the
+PassContext so the values can be read during compilation.
+
+.. code:: python
+
+    target = "cuda"
+    with tvm.transform.PassContext(opt_level=3, config={'relay.ext.tensorrt.options': config}):
+        lib = relay.build(mod, target=target, params=params)
+
+
+Export the module.
+
+.. code:: python
+
+    lib.export_library('compiled.so')
+
+
+Load the module and run inference on the target machine, which must be built with
+``USE_TENSORRT_RUNTIME`` enabled. The first run will take longer because the TensorRT engine will
+have to be built.
+
+.. code:: python
+
+    import numpy as np
+
+    ctx = tvm.gpu(0)
+    loaded_lib = tvm.runtime.load_module('compiled.so')
+    gen_module = tvm.contrib.graph_runtime.GraphModule(loaded_lib['default'](ctx))
+    input_data = np.random.uniform(0, 1, input_shape).astype(dtype)
+    gen_module.run(data=input_data)
+
+
+Partitioning and Compilation Settings
+-------------------------------------
+
+There are some options which can be configured in ``partition_for_tensorrt``.
+
+* ``version`` - TensorRT version to target as tuple of (major, minor, patch). If TVM is compiled
+  with USE_TENSORRT_RUNTIME=ON, the linked TensorRT version will be used instead. The version
+  will affect which ops can be partitioned to TensorRT.
+* ``use_implicit_batch`` - Use TensorRT implicit batch mode (default true). Setting to false will
+  enable explicit batch mode which will widen supported operators to include those which modify the
+  batch dimension, but may reduce performance for some models.
+* ``remove_no_mac_subgraphs`` - A heuristic to improve performance. Removes subgraphs which have
+  been partitioned for TensorRT if they do not have any multiply-accumulate operations. The removed
+  subgraphs will go through TVM's standard compilation instead.
+* ``max_workspace_size`` - How many bytes of workspace size to allow each subgraph to use for
+  TensorRT engine creation. See TensorRT documentation for more info. Can be overridden at runtime.
+
+
+Runtime Settings
+----------------
+
+There are some additional options which can be configured at runtime using environment variables.
+
+* Automatic FP16 Conversion - Environment variable ``TVM_TENSORRT_USE_FP16=1`` can be set to
+  automatically convert the TensorRT components of your model to 16-bit floating point precision.
+  This can greatly increase performance, but may cause some slight loss in the model accuracy.
+* Caching TensorRT Engines - During the first inference, the runtime will invoke the TensorRT API
+  to build an engine. This can be time consuming, so you can set ``TVM_TENSORRT_CACHE_DIR`` to
+  point to a directory to save these built engines to on the disk. The next time you load the model
+  and give it the same directory, the runtime will load the already built engines to avoid the long
+  warmup time. A unique directory is required for each model.
+* TensorRT has a parameter to configure the maximum amount of scratch space that each layer in the
+  model can use. It is generally best to use the highest value which does not cause you to run out
+  of memory. You can use ``TVM_TENSORRT_MAX_WORKSPACE_SIZE`` to override this by specifying the
+  workspace size in bytes you would like to use. A short usage sketch for these variables is shown
+  below.
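+
+The example below is a minimal sketch (the cache directory and input shape are illustrative, not
+part of the guide above) showing how these environment variables can be set from Python before the
+compiled module is loaded:
+
+.. code:: python
+
+    import os
+    import numpy as np
+    import tvm
+    from tvm.contrib import graph_runtime
+
+    # Convert the TensorRT subgraphs to FP16 and cache the built engines on disk.
+    os.environ["TVM_TENSORRT_USE_FP16"] = "1"
+    os.makedirs("/tmp/trt_cache_resnet18", exist_ok=True)  # illustrative path, one directory per model
+    os.environ["TVM_TENSORRT_CACHE_DIR"] = "/tmp/trt_cache_resnet18"
+    # Allow up to 2 GiB of TensorRT workspace (the value is given in bytes).
+    os.environ["TVM_TENSORRT_MAX_WORKSPACE_SIZE"] = str(2 << 30)
+
+    ctx = tvm.gpu(0)
+    loaded_lib = tvm.runtime.load_module('compiled.so')
+    gen_module = graph_runtime.GraphModule(loaded_lib['default'](ctx))
+    input_data = np.random.uniform(0, 1, (1, 3, 224, 224)).astype("float32")
+    gen_module.run(data=input_data)  # the first run builds and caches the TensorRT engines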
+ + +Operator support +---------------- ++------------------------+------------------------------------+ +| Relay Node | Remarks | ++========================+====================================+ +| nn.relu | | ++------------------------+------------------------------------+ +| sigmoid | | ++------------------------+------------------------------------+ +| tanh | | ++------------------------+------------------------------------+ +| nn.batch_norm | | ++------------------------+------------------------------------+ +| nn.softmax | | ++------------------------+------------------------------------+ +| nn.conv2d | | ++------------------------+------------------------------------+ +| nn.dense | | ++------------------------+------------------------------------+ +| nn.bias_add | | ++------------------------+------------------------------------+ +| add | | ++------------------------+------------------------------------+ +| subtract | | ++------------------------+------------------------------------+ +| multiply | | ++------------------------+------------------------------------+ +| divide | | ++------------------------+------------------------------------+ +| power | | ++------------------------+------------------------------------+ +| maximum | | ++------------------------+------------------------------------+ +| minimum | | ++------------------------+------------------------------------+ +| nn.max_pool2d | | ++------------------------+------------------------------------+ +| nn.avg_pool2d | | ++------------------------+------------------------------------+ +| nn.global_max_pool2d | | ++------------------------+------------------------------------+ +| nn.global_avg_pool2d | | ++------------------------+------------------------------------+ +| exp | | ++------------------------+------------------------------------+ +| log | | ++------------------------+------------------------------------+ +| sqrt | | ++------------------------+------------------------------------+ +| abs | | ++------------------------+------------------------------------+ +| negative | | ++------------------------+------------------------------------+ +| nn.batch_flatten | | ++------------------------+------------------------------------+ +| expand_dims | | ++------------------------+------------------------------------+ +| squeeze | | ++------------------------+------------------------------------+ +| concatenate | | ++------------------------+------------------------------------+ +| nn.conv2d_transpose | | ++------------------------+------------------------------------+ +| transpose | | ++------------------------+------------------------------------+ +| layout_transform | | ++------------------------+------------------------------------+ +| reshape | | ++------------------------+------------------------------------+ +| nn.pad | | ++------------------------+------------------------------------+ +| sum | | ++------------------------+------------------------------------+ +| prod | | ++------------------------+------------------------------------+ +| max | | ++------------------------+------------------------------------+ +| min | | ++------------------------+------------------------------------+ +| mean | | ++------------------------+------------------------------------+ +| nn.adaptive_max_pool2d | | ++------------------------+------------------------------------+ +| nn.adaptive_avg_pool2d | | ++------------------------+------------------------------------+ +| clip | Requires TensorRT 5.1.5 or greater | 
++------------------------+------------------------------------+ +| nn.leaky_relu | Requires TensorRT 5.1.5 or greater | ++------------------------+------------------------------------+ +| sin | Requires TensorRT 5.1.5 or greater | ++------------------------+------------------------------------+ +| cos | Requires TensorRT 5.1.5 or greater | ++------------------------+------------------------------------+ +| atan | Requires TensorRT 5.1.5 or greater | ++------------------------+------------------------------------+ +| ceil | Requires TensorRT 5.1.5 or greater | ++------------------------+------------------------------------+ +| floor | Requires TensorRT 5.1.5 or greater | ++------------------------+------------------------------------+ +| strided_slice | Requires TensorRT 5.1.5 or greater | ++------------------------+------------------------------------+ +| nn.conv3d | Requires TensorRT 6.0.1 or greater | ++------------------------+------------------------------------+ +| nn.max_pool3d | Requires TensorRT 6.0.1 or greater | ++------------------------+------------------------------------+ +| nn.avg_pool3d | Requires TensorRT 6.0.1 or greater | ++------------------------+------------------------------------+ +| nn.conv3d_transpose | Requires TensorRT 6.0.1 or greater | ++------------------------+------------------------------------+ + + +Adding a new operator +--------------------- +To add support for a new operator, there are a series of files we need to make changes to: + +* `src/runtime/contrib/tensorrt/tensorrt_ops.cc` Create a new op converter class which + implements the ``TensorRTOpConverter`` interface. You must implement the constructor to specify how + many inputs there are and whether they are tensors or weights. You must also implement the + ``Convert`` method to perform the conversion. This is done by using the inputs, attributes, and + network from params to add the new TensorRT layers and push the layer outputs. You can use the + existing converters as an example. Finally, register your new op conventer in the + ``GetOpConverters()`` map. +* `python/relay/op/contrib/tensorrt.py` This file contains the annotation rules for TensorRT. These + determine which operators and their attributes that are supported. You must register an annotation + function for the relay operator and specify which attributes are supported by your converter, by + checking the attributes are returning true or false. +* `tests/python/contrib/test_tensorrt.py` Add unit tests for the given operator. diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py index dbcd8055d30b..49abf36134b4 100644 --- a/python/tvm/relay/op/contrib/__init__.py +++ b/python/tvm/relay/op/contrib/__init__.py @@ -22,3 +22,4 @@ from .dnnl import * from .coreml import * from .ethosn import * +from .tensorrt import * diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py new file mode 100644 index 000000000000..a0e23a043a72 --- /dev/null +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -0,0 +1,769 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""TensorRT supported operators.""" +import logging +import numpy as np +import tvm +from tvm import relay +from tvm.relay import transform +from tvm.relay.build_module import bind_params_by_name +from tvm.relay.expr import Call, Constant, Tuple, GlobalVar +from tvm.relay.expr_functor import ExprMutator + +logger = logging.getLogger("TensorRT") + + +def is_tensorrt_runtime_enabled(): + """Check if the TensorRT graph runtime is present. + Returns + ------- + ret: bool + True if present, False if not. + """ + check_enabled = tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True) + if check_enabled: + return check_enabled() + return False + + +def get_tensorrt_version(): + """Gets the version of TensorRT that TVM is built against or is targeting. + + Returns + ------- + ret: Tuple[int, int, int] + TensorRT version as a tuple of major, minor, and patch number. If TVM + is not built with TensorRT, the value set by set_tensorrt_version() is returned instead. + """ + pass_ctx = tvm.transform.PassContext.current() + if "relay.ext.tensorrt.options" in pass_ctx.config: + return tuple(pass_ctx.config["relay.ext.tensorrt.options"].tensorrt_version) + return tuple(tvm.get_global_func("relay.op.get_tensorrt_version")()) + + +def get_tensorrt_use_implicit_batch_mode(): + pass_ctx = tvm.transform.PassContext.current() + if "relay.ext.tensorrt.options" in pass_ctx.config: + return pass_ctx.config["relay.ext.tensorrt.options"].use_implicit_batch + logger.warning( + "PassContext has no relay.ext.tensorrt.options config, using default value " + "use_implicit_batch=True." + ) + return True + + +def get_tensorrt_remove_no_mac_subgraphs(): + pass_ctx = tvm.transform.PassContext.current() + if "relay.ext.tensorrt.options" in pass_ctx.config: + return pass_ctx.config["relay.ext.tensorrt.options"].remove_no_mac_subgraphs + logger.warning( + "PassContext has no relay.ext.tensorrt.options config, using default value " + "remove_no_mac_subgraphs=False." + ) + return False + + +def partition_for_tensorrt( + mod, + params=None, + version=None, + use_implicit_batch=True, + remove_no_mac_subgraphs=False, + max_workspace_size=1 << 30, +): + """Partition the graph greedily offloading supported operators to TensorRT. + + Parameters + ---------- + mod : Module + The module to run passes on. + params : Optional[Dict[str, NDArray]] + Constant input parameters. + version : Optional[Tuple[int, int, int]] + TensorRT version to target as tuple of (major, minor, patch). If TVM is compiled with + USE_TENSORRT_RUNTIME=ON, the linked TensorRT version will be used instead. + use_implicit_batch : Optional[bool] + Use TensorRT implicit batch mode (default true). Setting to false will enable explicit batch + mode which will widen supported operators to include those which modify the batch dimension, + but may reduce performance for some models. + remove_no_mac_subgraphs : Optional[bool] + Removes subgraphs which have been partitioned for TensorRT if they do not have any + multiply-accumulate operations. 
The removed subgraphs will go through TVM's standard + compilation instead. Can improve performance. + max_workspace_size : Optional[int] + How many bytes of workspace size to allow each subgraph to use for TensorRT engine creation. + See TensorRT documentation for more info. + Returns + ------- + mod_and_config : Tuple[Module, Dict[str, Any]] + A tuple of 1) annotated and partitioned module and 2) "relay.ext.tensorrt.options" + configuration which should be given to PassContext when building. + """ + config = { + "use_implicit_batch": use_implicit_batch, + "max_workspace_size": max_workspace_size, + "remove_no_mac_subgraphs": remove_no_mac_subgraphs, + } + if version: + assert isinstance(version, tuple) and len(version) == 3 + config["tensorrt_version"] = version + else: + linked_version = tuple(tvm.get_global_func("relay.op.get_tensorrt_version")()) + if not linked_version: + logger.warning( + "TVM was not built against TensorRT and no version was provided to " + "partition_for_tensorrt. Defaulting to 6.0.1" + ) + linked_version = (6, 0, 1) + config["tensorrt_version"] = linked_version + + if params: + mod["main"] = bind_params_by_name(mod["main"], params) + seq = tvm.transform.Sequential( + [ + transform.InferType(), + RemoveDropoutPass(), + transform.RemoveUnusedFunctions(), + transform.ConvertLayout( + {"nn.conv2d": ["NCHW", "default"], "nn.conv3d": ["NCDHW", "default"]} + ), + transform.FoldConstant(), + transform.AnnotateTarget("tensorrt"), + transform.MergeCompilerRegions(), + transform.PartitionGraph(), + transform.InferType(), + ] + ) + with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): + mod = seq(mod) + mod = prune_tensorrt_subgraphs(mod) + return mod, config + + +def _register_external_op_helper_with_checker(op_name, checker): + @tvm.ir.register_op_attr(op_name, "target.tensorrt") + def _func_wrapper(attrs, args): + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + return checker(attrs, args, op_name) + + return _func_wrapper + + +def _register_external_op_helper(op_name, supported=True): + return _register_external_op_helper_with_checker( + op_name, lambda attrs, args, op_name: supported + ) + + +# Ops which are always supported +_register_external_op_helper("nn.relu") +_register_external_op_helper("sigmoid") +_register_external_op_helper("tanh") +_register_external_op_helper("subtract") +_register_external_op_helper("multiply") +_register_external_op_helper("divide") +_register_external_op_helper("power") +_register_external_op_helper("maximum") +_register_external_op_helper("minimum") +_register_external_op_helper("exp") +_register_external_op_helper("log") +_register_external_op_helper("sqrt") +_register_external_op_helper("abs") +_register_external_op_helper("negative") +_register_external_op_helper("nn.batch_flatten") +_register_external_op_helper("clip") + + +@tvm.ir.register_op_attr("add", "target.tensorrt") +def add_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if add is supported by TensorRT.""" + + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if ( + not get_tensorrt_use_implicit_batch_mode() + and (isinstance(args[0], Constant) or isinstance(args[1], Constant)) + and args[0].checked_type.shape[0] == args[1].checked_type.shape[0] + and args[0].checked_type.shape[0] != 1 + and (len(args[0].checked_type.shape) > 3 or 
len(args[1].checked_type.shape) > 3) + ): + logger.info("add: bug in TRT with adding batched constants.") + return False + return True + + +@tvm.ir.register_op_attr("nn.batch_norm", "target.tensorrt") +def batch_norm_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.batch_norm is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if int(attrs.axis) not in (1, 3): + logger.info("nn.batch_norm: axis is %d but must be 1 or 3.", int(attrs.axis)) + return False + return True + + +@tvm.ir.register_op_attr("nn.softmax", "target.tensorrt") +def softmax_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.softmax is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0: + logger.info("nn.softmax: can't modify batch dimension.") + return False + return True + + +@tvm.ir.register_op_attr("nn.conv2d", "target.tensorrt") +def conv2d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.conv2d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if attrs.data_layout != "NCHW": + logger.info("nn.conv2d: data_layout is %s but must be NCHW.", attrs.data_layout) + return False + if attrs.kernel_layout != "OIHW": + logger.info("nn.conv2d: kernel_layout is %s but must be OIHW.", attrs.kernel_layout) + return False + if attrs.out_layout and attrs.out_layout != "NCHW": + logger.info("nn.conv2d: out_layout is %s but must be NCHW.", attrs.out_layout) + return False + return True + + +@tvm.ir.register_op_attr("nn.dense", "target.tensorrt") +def dense_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if dense is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + input_rank = len(args[0].checked_type.shape) + weight_rank = len(args[1].checked_type.shape) + if input_rank not in (2, 3, 4): + logger.info("nn.dense: input has rank %d but must be 2, 3 or 4.", input_rank) + return False + if weight_rank != 2: + logger.info("nn.dense: weight has rank %d but must be 2.", weight_rank) + return False + return True + + +@tvm.ir.register_op_attr("nn.bias_add", "target.tensorrt") +def bias_add_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.bias_add is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + input_rank = len(args[0].checked_type.shape) + if input_rank not in (2, 3, 4): + logger.info("nn.bias_add: input rank is %d but must be 2, 3 or 4.", input_rank) + return False + return True + + +@tvm.ir.register_op_attr("nn.max_pool2d", "target.tensorrt") +def max_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.max_pool2d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if attrs.layout != "NCHW": + logger.info("nn.max_pool2d: layout is %s but must be NCHW.", attrs.layout) + return False + if attrs.ceil_mode 
and get_tensorrt_version() < (5, 1, 5): + logger.info("nn.avg_pool2d: ceil_mode=True requires TensorRT 5.1.5 or greater.") + return False + return True + + +@tvm.ir.register_op_attr("nn.avg_pool2d", "target.tensorrt") +def avg_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.avg_pool2d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if attrs.layout != "NCHW": + logger.info("nn.avg_pool2d: layout is %d but must be NCHW.", attrs.layout) + return False + if ( + attrs.count_include_pad + and len(attrs.padding) == 4 + and ( + int(attrs.padding[0]) != int(attrs.padding[2]) + or int(attrs.padding[1]) != int(attrs.padding[3]) + ) + ): + logger.info( + "nn.avg_pool2d: inclusive-counted blended or average " + "pooling is not supported in combination with asymmetric padding" + ) + return False + if attrs.ceil_mode and get_tensorrt_version() < (5, 1, 5): + logger.info("nn.avg_pool2d: ceil_mode=True requires TensorRT 5.1.5 or greater.") + return False + return True + + +@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.tensorrt") +def global_max_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.global_max_pool2d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if attrs.layout != "NCHW": + logger.info("nn.global_max_pool2d: layout is %s but must be NCHW.", attrs.layout) + return False + return True + + +@tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.tensorrt") +def global_avg_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.global_avg_pool2d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if attrs.layout != "NCHW": + logger.info("nn.global_avg_pool2d: layout is %s but must be NCHW.", attrs.layout) + return False + return True + + +@tvm.ir.register_op_attr("expand_dims", "target.tensorrt") +def expand_dims_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if expand_dims is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0: + logger.info("expand_dims: can't modify batch dimension.") + return False + return True + + +@tvm.ir.register_op_attr("squeeze", "target.tensorrt") +def squeeze_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if squeeze is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if not attrs.axis: + logger.info("squeeze: must explicitly set axis.") + return False + if get_tensorrt_use_implicit_batch_mode() and any([axis == 0 for axis in map(int, attrs.axis)]): + logger.info("squeeze: can't modify batch dimension.") + return False + return True + + +@tvm.ir.register_op_attr("concatenate", "target.tensorrt") +def concatenate_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if concatenate is supported by TensorRT.""" + if any([x.dtype != "float32" for x in args[0].checked_type.fields]): + logger.info("Only float32 inputs are supported for 
TensorRT.") + return False + if not get_tensorrt_use_implicit_batch_mode(): + return True + if int(attrs.axis) == 0: + logger.info("concatenate: can't modify batch dimension.") + return False + if isinstance(args[0], Tuple): + for tuple_input in args[0].fields: + if isinstance(tuple_input, Constant): + logger.info("concatenate: can't concatenate tensors with constants.") + return False + return True + + +@tvm.ir.register_op_attr("nn.conv2d_transpose", "target.tensorrt") +def conv2d_transpose_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.conv2d_transpose is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if attrs.data_layout != "NCHW": + logger.info("nn.conv2d_transpose: data_layout is %s but must be NCHW.", attrs.data_layout) + return False + if attrs.kernel_layout != "OIHW": + logger.info( + "nn.conv2d_transpose: kernel_layout is %s but must be OIHW.", attrs.kernel_layout + ) + return False + if attrs.out_layout and attrs.out_layout != "NCHW": + logger.info("nn.conv2d_transpose: out_layout is %s but must be NCHW.", attrs.out_layout) + return False + if attrs.dilation and any([rate != 1 for rate in map(int, attrs.dilation)]): + logger.info("nn.conv2d_transpose: dilation rate must be 1.") + return False + return True + + +@tvm.ir.register_op_attr("transpose", "target.tensorrt") +def transpose_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if transpose is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if get_tensorrt_use_implicit_batch_mode() and int(attrs.axes[0]) != 0: + logger.info("transpose: can't modify batch dimension.") + return False + return True + + +@tvm.ir.register_op_attr("layout_transform", "target.tensorrt") +def layout_transform_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if layout_transform is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if (attrs.src_layout, attrs.dst_layout) not in [ + ("NCHW", "NHWC"), + ("NHWC", "NCHW"), + ("NDHWC", "NCDHW"), + ("NCDHW", "NDHWC"), + ]: + logger.info( + "layout_transform: %s to %s is not supported.", attrs.src_layout, attrs.dst_layout + ) + return False + return True + + +@tvm.ir.register_op_attr("reshape", "target.tensorrt") +def reshape_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if reshape is supported by TensorRT.""" + if args[0].checked_type.dtype != "float32": + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if any([x < -1 for x in map(int, attrs.newshape)]): + logger.info("reshape: new shape dims must be explicit.") + return False + if get_tensorrt_use_implicit_batch_mode(): + shape = list(map(int, args[0].checked_type.shape)) + new_shape = list(map(int, attrs.newshape)) + if len(new_shape) == 0 or len(shape) == 0: + logger.info("reshape: Can't reshape to or from scalar.") + return False + # TRT cannot modify batch dimension. + original_volume = np.prod(shape) + # First, resolve 0. + for i, value in enumerate(new_shape): + if value == 0: + new_shape[i] = shape[i] + # Resolve -1. 
+ for i, value in enumerate(new_shape): + if value == -1: + new_shape[i] = original_volume // np.prod([x for x in new_shape if x != -1]) + if shape[0] != new_shape[0]: + logger.info("reshape: can't modify batch dimension.") + return False + return True + + +@tvm.ir.register_op_attr("nn.pad", "target.tensorrt") +def pad_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.pad is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if attrs.pad_mode != "constant": + logger.info("nn.pad: pad mode is %s but must be constant.", attrs.pad_mode) + return False + if float(attrs.pad_value) != 0.0: + logger.info("nn.pad: pad value is %f but must be 0.0.", float(attrs.pad_value)) + return False + if any([x != 0 for x in attrs.pad_width[0]]) or any([x != 0 for x in attrs.pad_width[1]]): + logger.info("nn.pad: can't pad batch or channel dimensions.") + return False + if len(attrs.pad_width) == 5 and any([x != 0 for x in attrs.pad_width[2]]): + logger.info("nn.pad: can only pad last two dimensions for 5D inputs.") + return True + + +def reduce_annotate_fn(attrs, args, op_name): + """Helper for reduce operations.""" + if not attrs.axis or len(attrs.axis) == 0: + logger.info("%s: cannot reduce to scalar.", op_name) + return False + if attrs.exclude: + logger.info("%s: exclude not supported.", op_name) + return False + if get_tensorrt_use_implicit_batch_mode() and any([x == 0 for x in map(int, attrs.axis)]): + logger.info("%s: can't modify batch dimension.", op_name) + return False + return True + + +_register_external_op_helper_with_checker("sum", reduce_annotate_fn) +_register_external_op_helper_with_checker("prod", reduce_annotate_fn) +_register_external_op_helper_with_checker("max", reduce_annotate_fn) +_register_external_op_helper_with_checker("min", reduce_annotate_fn) +_register_external_op_helper_with_checker("mean", reduce_annotate_fn) + + +def trt_version_annotate_fn(version): + """Helper for ops which require a minimum TRT version""" + + def _func_wrapper(attrs, args, op_name): + if get_tensorrt_version() < version: + logger.info( + "%s: requires TensorRT version %s or higher.", op_name, ".".join(map(str, version)) + ) + return False + return True + + return _func_wrapper + + +_register_external_op_helper_with_checker("nn.leaky_relu", trt_version_annotate_fn((5, 1, 5))) +_register_external_op_helper_with_checker("sin", trt_version_annotate_fn((5, 1, 5))) +_register_external_op_helper_with_checker("cos", trt_version_annotate_fn((5, 1, 5))) +_register_external_op_helper_with_checker("atan", trt_version_annotate_fn((5, 1, 5))) +_register_external_op_helper_with_checker("ceil", trt_version_annotate_fn((5, 1, 5))) + + +@tvm.ir.register_op_attr("strided_slice", "target.tensorrt") +def strided_slice_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if strided_slice is supported by TensorRT.""" + if args[0].checked_type.dtype != "float32": + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if not trt_version_annotate_fn((5, 1, 5))(attrs, args, "strided_slice"): + return False + if get_tensorrt_use_implicit_batch_mode(): + batch_dim_begin_modified = attrs.begin[0] is not None and int(attrs.begin[0]) != 0 + batch_dim_end_modified = ( + attrs.end[0] is not None + and int(attrs.end[0]) != -1 + and int(attrs.end[0]) != int(args[0].checked_type.shape[0]) + ) + if batch_dim_begin_modified or batch_dim_end_modified: + 
logger.info("strided_slice: can't modify batch dimension.") + return False + if any([x is not None and x <= 0 for x in attrs.strides]): + logger.info("strided_slice: stride must be positive") + return False + return True + + +@tvm.ir.register_op_attr("nn.adaptive_max_pool2d", "target.tensorrt") +def adapative_max_pool2d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.adaptive_max_pool2d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]): + logger.info("nn.adaptive_max_pool2d: output size must be (1, 1).") + return False + return True + + +@tvm.ir.register_op_attr("nn.adaptive_avg_pool2d", "target.tensorrt") +def adapative_avg_pool2d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.adaptive_avg_pool2d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]): + logger.info("nn.adaptive_avg_pool2d: output size must be (1, 1).") + return False + return True + + +@tvm.ir.register_op_attr("nn.conv3d", "target.tensorrt") +def conv3d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.conv3d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.conv3d"): + return False + if attrs.data_layout != "NCDHW": + logger.info("nn.conv3d: data_layout is %s but must be NCDHW.", attrs.data_layout) + return False + if attrs.kernel_layout != "OIDHW": + logger.info("nn.conv3d: kernel_layout is %s but must be OIDHW.", attrs.kernel_layout) + return False + if attrs.out_layout and attrs.out_layout != "NCDHW": + logger.info("nn.conv3d: out_layout is %s but must be NCDHW.", attrs.out_layout) + return False + return True + + +@tvm.ir.register_op_attr("nn.max_pool3d", "target.tensorrt") +def max_pool_3d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.max_pool3d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.max_pool3d"): + return False + if attrs.layout != "NCDHW": + logger.info("nn.max_pool3d: layout is %s but must be NCDHW.", attrs.layout) + return False + return True + + +@tvm.ir.register_op_attr("nn.avg_pool3d", "target.tensorrt") +def avg_pool_3d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.avg_pool3d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.avg_pool3d"): + return False + if attrs.layout != "NCDHW": + logger.info("nn.avg_pool3d: layout is %s but must be NCDHW.", attrs.layout) + return False + return True + + +@tvm.ir.register_op_attr("nn.conv3d_transpose", "target.tensorrt") +def conv3d_transpose_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.conv3d_transpose is supported by TensorRT.""" + if 
any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.conv3d_transpose"): + return False + if attrs.data_layout != "NCDHW": + logger.info("nn.conv3d_transpose: data_layout is %s but must be NCDHW.", attrs.data_layout) + return False + if attrs.kernel_layout != "OIDHW": + logger.info( + "nn.conv3d_transpose: kernel_layout is %s but must be OIDHW.", attrs.kernel_layout + ) + return False + if attrs.out_layout and attrs.out_layout != "NCDHW": + logger.info("nn.conv3d_transpose: out_layout is %s but must be NCDHW.", attrs.out_layout) + return False + if attrs.dilation and any([rate != 1 for rate in map(int, attrs.dilation)]): + logger.info("nn.conv3d_transpose: dilation rate must be 1.") + return False + if attrs.output_padding and any([x != 0 for x in map(int, attrs.output_padding)]): + logger.info("nn.conv3d_transpose: output padding is not supported.") + return False + return True + + +def is_valid_subgraph(params, body): + """Final check on whether the subgraph is valid and should be offloaded to TensorRT.""" + # Remove invalid subgraphs for implicit batch mode. + if get_tensorrt_use_implicit_batch_mode(): + input_batch_sizes = [] + for var in params: + # In implicit batch mode, all inputs must have same batch size + if isinstance(var.checked_type, relay.TupleType): + for tupe_type in var.checked_type.fields: + # Scalar inputs not allowed + if len(tupe_type.shape) == 0: + logger.info("tensorrt: scalar inputs not supported") + return False + input_batch_sizes.append(int(tupe_type.shape[0])) + else: + # Scalar inputs not allowed + if len(var.checked_type.shape) == 0: + logger.info("tensorrt: scalar inputs not supported") + return False + input_batch_sizes.append(int(var.checked_type.shape[0])) + if len(input_batch_sizes) > 1 and len(set(input_batch_sizes)) != 1: + logger.info("tensorrt: inputs have different batch sizes") + return False + # Remove subgraphs with no multiply-accumulates + if get_tensorrt_remove_no_mac_subgraphs() and relay.analysis.get_total_mac_number(body) == 0: + return False + return True + + +def prune_tensorrt_subgraphs(mod): + """ + Removes invalid subgraphs and those with no multiply-accumulates (if remove_no_max_subgraphs + is set). + """ + + class SubgraphRemover(ExprMutator): + """ + Reverts subgraphs in subgraphs_to_remove back to TVM instead of using an external codegen. + """ + + def __init__(self, subgraphs_to_remove, mod, new_mod): + ExprMutator.__init__(self) + self.subgraphs_to_remove = subgraphs_to_remove + self.mod = mod + self.new_mod = new_mod + + def visit_call(self, call): + if isinstance(call.op, GlobalVar): + name = call.op.name_hint + if name in self.subgraphs_to_remove: + # "Inline" the subgraph back into new main function. + func = self.mod[name] + var_map = {} + for arg, param in zip(call.args, func.params): + var_map[param] = super().visit(arg) + new_body = relay.bind(func.body, var_map) + return new_body + if name != "main": + # Copy the GlobalVar (subgraph function) to the new module and call. 
+ args = [] + for arg in call.args: + args.append(super().visit(arg)) + subgraph_gv = relay.GlobalVar(name) + self.new_mod[subgraph_gv] = self.mod[name] + return subgraph_gv(*args) + return super().visit_call(call) + + subgraphs_to_remove = [] + # Remove invalid subgraphs + for subgraph in mod.get_global_vars(): + name = subgraph.name_hint + if not mod[name].attrs or mod[name].attrs["Compiler"] != "tensorrt": + continue + if not is_valid_subgraph(mod[name].params, mod[name].body): + subgraphs_to_remove.append(name) + # Create new pruned module + new_mod = tvm.IRModule() + new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"]) + return new_mod + + +class RemoveDropout(ExprMutator): + """ + Removes all nn.dropout from an expr. + """ + + def visit_tuple_getitem(self, op): + visit = super().visit_tuple_getitem(op) + if ( + isinstance(visit.tuple_value, Call) + and visit.tuple_value.op.name == "nn.dropout" + and visit.index == 0 + ): + return visit.tuple_value.args[0] + return visit + + +@transform.function_pass(opt_level=0) +class RemoveDropoutPass: + def transform_function(self, func, mod, _): + return RemoveDropout().visit(func) diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc new file mode 100644 index 000000000000..f692da3f31ac --- /dev/null +++ b/src/relay/backend/contrib/tensorrt/codegen.cc @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/tensorrt/codegen.cc + * \brief Implementation of the TensorRT JSON serializer. + */ +#include +#include +#include + +#include +#include +#include + +#include "../../utils.h" +#include "../codegen_json/codegen_json.h" + +#if TVM_GRAPH_RUNTIME_TENSORRT +#include "NvInfer.h" +#endif + +namespace tvm { +namespace relay { +namespace contrib { + +/*! \brief Attributes to store the compiler options for TensorRT. 
*/ +struct TensorRTCompilerConfigNode : public tvm::AttrsNode { + Array tensorrt_version; + bool use_implicit_batch; + size_t max_workspace_size; + bool remove_no_mac_subgraphs; + + TVM_DECLARE_ATTRS(TensorRTCompilerConfigNode, "ext.attrs.TensorRTCompilerConfigNode") { + TVM_ATTR_FIELD(tensorrt_version) + .describe("TensorRT version as (major, minor, patch).") + .set_default(Array({6, 0, 1})); + TVM_ATTR_FIELD(use_implicit_batch).set_default(true); + TVM_ATTR_FIELD(max_workspace_size).set_default(size_t(1) << 30); + TVM_ATTR_FIELD(remove_no_mac_subgraphs).set_default(false); + } +}; + +class TensorRTCompilerConfig : public Attrs { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorRTCompilerConfig, Attrs, + TensorRTCompilerConfigNode); +}; + +TVM_REGISTER_NODE_TYPE(TensorRTCompilerConfigNode); +TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.tensorrt.options", TensorRTCompilerConfig); + +/*! + * \brief Generates an TensorRTModule from a relay expression by serializing the expression to a + * json representation. TensorRT is not required here because use of TensorRT APIs is deferred until + * runtime. + */ +class TensorRTJSONSerializer : public backend::contrib::JSONSerializer { + using JSONGraphNode = tvm::runtime::json::JSONGraphNode; + using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; + + public: + TensorRTJSONSerializer(const std::string& symbol, const Expr& expr) + : JSONSerializer(symbol, expr) {} + + std::vector VisitExpr_(const CallNode* cn) { + std::string name; + if (const auto* op_node = cn->op.as()) { + name = op_node->name; + } else { + return JSONSerializer::VisitExpr_(cn); + } + + std::vector inputs; + for (const auto& arg : cn->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + auto node = std::make_shared(name, /* name_ */ + "kernel", /* op_type_ */ + inputs, 1 /* num_outputs_ */); + if (name == "nn.pad") { + SetPadNodeAttribute(node, cn); + } else if (name == "strided_slice") { + SetStridedSliceNodeAttribute(node, cn); + } else { + SetCallNodeAttribute(node, cn); + } + // These attributes are global to the whole module. + SaveGlobalAttributes(node); + return AddNode(node, GetRef(cn)); + } + + void SetPadNodeAttribute(std::shared_ptr node, const CallNode* cn) { + const auto* pad_attr = cn->attrs.as(); + CHECK(pad_attr); + auto p = pad_attr->pad_width; + const int dim_h = (p.size() == 5) ? 3 : 2; + const int dim_w = (p.size() == 5) ? 
4 : 3; + std::vector padding = {std::to_string(p[dim_h][0].as()->value), + std::to_string(p[dim_w][0].as()->value), + std::to_string(p[dim_h][1].as()->value), + std::to_string(p[dim_w][1].as()->value)}; + std::vector padding_attr; + padding_attr.emplace_back(padding); + node->SetAttr("padding", padding_attr); + } + + void SetStridedSliceNodeAttribute(std::shared_ptr node, const CallNode* cn) { + const auto* attrs = cn->attrs.as(); + CHECK(attrs && attrs->begin && attrs->end && attrs->strides) + << "StridedSlice must have static begin, end, and strides."; + const bool default_strides = + !attrs->strides.value().defined() || attrs->strides.value().size() == 0; + auto ishape = backend::GetShape(cn->args[0]->checked_type()); + + auto process_slice_index = [](Integer x, int default_value, int dim_value) { + if (!x.defined()) return default_value; + int value = x.as()->value; + if (value < 0) value += dim_value; + return value; + }; + + std::vector start, size, strides; + for (size_t i = 0; i < attrs->begin.value().size(); ++i) { + const int begin_value = process_slice_index(attrs->begin.value()[i], 0, ishape[i]); + const int end_value = process_slice_index(attrs->end.value()[i], ishape[i], ishape[i]); + const int stride_value = (default_strides || i >= attrs->strides.value().size() || + !attrs->strides.value()[i].defined()) + ? 1 + : attrs->strides.value()[i].as()->value; + CHECK_GT(stride_value, 0); + const int size_value = (end_value - begin_value + stride_value - 1) / stride_value; + CHECK_GE(begin_value, 0); + CHECK_GT(size_value, 0); + start.push_back(std::to_string(begin_value)); + size.push_back(std::to_string(size_value)); + strides.push_back(std::to_string(stride_value)); + } + std::vector start_attr, size_attr, strides_attr; + start_attr.emplace_back(start); + size_attr.emplace_back(size); + strides_attr.emplace_back(strides); + node->SetAttr("start", start_attr); + node->SetAttr("size", size_attr); + node->SetAttr("strides", strides_attr); + } + + void SaveGlobalAttributes(std::shared_ptr node) { + auto ctx = transform::PassContext::Current(); + auto cfg = ctx->GetConfig("relay.ext.tensorrt.options"); + if (!cfg.defined()) { + cfg = AttrsWithDefaultValues(); + } + CHECK_EQ(cfg.value()->tensorrt_version.size(), 3); + std::vector tensorrt_version = {std::to_string(cfg.value()->tensorrt_version[0]), + std::to_string(cfg.value()->tensorrt_version[1]), + std::to_string(cfg.value()->tensorrt_version[2])}; + std::vector use_implicit_batch = {std::to_string(cfg.value()->use_implicit_batch)}; + std::vector max_workspace_size = {std::to_string(cfg.value()->max_workspace_size)}; + std::vector tensorrt_version_attr, use_implicit_batch_attr, max_workspace_size_attr; + tensorrt_version_attr.emplace_back(tensorrt_version); + use_implicit_batch_attr.emplace_back(use_implicit_batch); + max_workspace_size_attr.emplace_back(max_workspace_size); + node->SetAttr("tensorrt_version", tensorrt_version_attr); + node->SetAttr("use_implicit_batch", use_implicit_batch_attr); + node->SetAttr("max_workspace_size", max_workspace_size_attr); + } +}; + +/*! + * \brief Create a runtime module for TensorRT. + * \param ref The ext_func Relay expression/module to be executed using extern ops. + * \return A runtime module. 
+ */ +runtime::Module TensorRTCompiler(const ObjectRef& ref) { + CHECK(ref->IsInstance()) << "The input ref is expected to be a Relay function."; + Function func = Downcast(ref); + std::string func_name = backend::GetExtSymbol(func); + + TensorRTJSONSerializer serializer(func_name, func); + serializer.serialize(); + std::string graph_json = serializer.GetJSON(); + auto param_names = serializer.GetParams(); + const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create"); + CHECK(pf != nullptr) << "Cannot find TensorRT runtime module create function."; + runtime::Module lib = (*pf)(func_name, graph_json, param_names); + return lib; +} + +TVM_REGISTER_GLOBAL("relay.ext.tensorrt").set_body_typed(TensorRTCompiler); + +/*! + * \brief Check whether TensorRT graph runtime is enabled. + * \return True if enabled, False if not. + */ +inline constexpr bool IsTensorRTRuntimeEnabled() { +#if TVM_GRAPH_RUNTIME_TENSORRT + return true; +#else + return false; +#endif // TVM_GRAPH_RUNTIME_TENSORRT +} + +/*! + * \brief Get TensorRT version that TVM is built against. + * \return Array of three integers for major, minor, and patch, or empty array if TensorRT graph + * runtime is not enabled. + */ +Array GetTensorRTVersion() { +#if TVM_GRAPH_RUNTIME_TENSORRT + return {Integer(NV_TENSORRT_MAJOR), Integer(NV_TENSORRT_MINOR), Integer(NV_TENSORRT_PATCH)}; +#else + return {}; +#endif // TVM_GRAPH_RUNTIME_TENSORRT +} + +TVM_REGISTER_GLOBAL("relay.op.is_tensorrt_runtime_enabled") + .set_body_typed(IsTensorRTRuntimeEnabled); +TVM_REGISTER_GLOBAL("relay.op.get_tensorrt_version").set_body_typed(GetTensorRTVersion); + +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc new file mode 100644 index 000000000000..bf0dbfe724ed --- /dev/null +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc @@ -0,0 +1,222 @@ +/* * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file runtime/contrib/tensorrt/tensorrt_builder.cc + * \brief The TensorRTBuilder class can be used to convert a JSONRuntime graph into a TRT engine + * which can be used for inference. + */ + +#include "tensorrt_builder.h" + +#include + +#include +#include + +#include "tensorrt_logger.h" +#include "tensorrt_ops.h" +#include "tensorrt_utils.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger, size_t max_workspace_size, + bool use_implicit_batch, bool use_fp16, int batch_size) + : max_workspace_size_(max_workspace_size), + use_implicit_batch_(use_implicit_batch), + use_fp16_(use_fp16), + batch_size_(batch_size) { + // Create TRT builder and network. 
+ builder_ = nvinfer1::createInferBuilder(*logger); +#if TRT_VERSION_GE(6, 0, 1) + // Use INetworkV2. + auto flags = + 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); + if (use_implicit_batch_) { + flags = 0U; + builder_->setMaxBatchSize(batch_size_); + } + network_ = builder_->createNetworkV2(flags); +#else + // Use INetwork with implicit batch. + builder_->setMaxBatchSize(batch_size_); + builder_->setMaxWorkspaceSize(max_workspace_size_); + builder_->setFp16Mode(use_fp16_); + network_ = builder_->createNetwork(); +#endif +} + +void TensorRTBuilder::AddInput(int nid, const JSONGraphNode& node) { + auto node_name = node.GetOpName(); + auto shapes = node.GetOpShape(); + auto dtypes = node.GetOpDataType(); + CHECK_EQ(shapes.size(), dtypes.size()); + node_output_map_[nid] = {}; + for (size_t i = 0; i < shapes.size(); ++i) { + const std::string name = node_name + "_" + std::to_string(i); + auto shape = shapes[i]; + // Remove batch dim when not in explicit batch mode. + if (use_implicit_batch_ && shape.size() > 1) { + shape.erase(shape.begin()); + } + nvinfer1::Dims dims = VectorToTrtDims(shape); + CHECK(TypeMatch(dtypes[i], kDLFloat, 32)) << "Only FP32 inputs are supported."; + auto input_tensor = network_->addInput(name.c_str(), nvinfer1::DataType::kFLOAT, dims); + node_output_map_[nid].push_back(TensorRTOpInput(input_tensor)); + network_input_names_.push_back(input_tensor->getName()); + } +} + +void TensorRTBuilder::AddConstant(int nid, const DLTensor* data) { + nvinfer1::Weights weight = GetDLTensorAsWeights(data, kDLCPU); + std::vector shape(data->shape, data->shape + data->ndim); + // Remove batch dim when not in explicit batch mode. + if (use_implicit_batch_ && shape.size() > 1 && shape[0] == 1) { + shape.erase(shape.begin()); + } + node_output_map_[nid] = {TensorRTOpInput(weight, shape)}; +} + +void TensorRTBuilder::AddOutput(const JSONGraphNodeEntry& node) { + auto it = node_output_map_.find(node.id_); + CHECK(it != node_output_map_.end()) << "Output was not found."; + auto out_tensor = it->second[node.index_].tensor; + std::string name = "tensorrt_output_" + std::to_string(network_output_names_.size()); + out_tensor->setName(name.c_str()); + network_->markOutput(*out_tensor); + network_output_names_.push_back(out_tensor->getName()); +} + +void TensorRTBuilder::AddLayer(int nid, const JSONGraphNode& node) { + TensorRTOpConverterParams params(network_, node, &trt_weights_); + // Look up converter. + auto it = GetOpConverters()->find(params.op_name); + CHECK(it != GetOpConverters()->end()) + << "Unsupported operator conversion to TRT, op name: " << params.op_name; + const auto converter = it->second; + // Get inputs. + for (size_t i = 0; i < node.GetInputs().size(); ++i) { + auto in_node = node.GetInputs()[i]; + auto it = node_output_map_.find(in_node.id_); + CHECK(it != node_output_map_.end()) << "Input was not found."; + auto input = it->second[in_node.index_]; + if (!converter->variable_input_count) { + if (converter->input_types[i] == kTensor && input.type == kWeight) { + input = TensorRTOpInput(GetInputAsTensor(input)); + } else if (converter->input_types[i] == kWeight && input.type == kTensor) { + LOG(FATAL) << "Input " << i << " for " << params.op_name + << " requires weights but got a tensor."; + } + } + params.inputs.push_back(input); + } + CHECK(converter->variable_input_count || converter->input_types.size() == params.inputs.size()) + << "Op expected a different number of inputs."; + + // Convert op to TRT. 
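+  // The converter appends the corresponding TRT layer(s) to network_ and fills
+  // params.outputs with the resulting ITensor*s, which are recorded in node_output_map_
+  // below so that later nodes can consume them by (node id, output index).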
+ converter->Convert(¶ms); + + // Get outputs. + node_output_map_[nid] = {}; + for (auto out : params.outputs) { + node_output_map_[nid].push_back(TensorRTOpInput(out)); + } +} + +TensorRTEngineAndContext TensorRTBuilder::BuildEngine() { + // Process graph to create INetworkDefinition. +// Build engine. +#if TRT_VERSION_GE(6, 0, 1) + config_ = builder_->createBuilderConfig(); + config_->setMaxWorkspaceSize(max_workspace_size_); + if (use_fp16_) { + config_->setFlag(nvinfer1::BuilderFlag::kFP16); + } + // Add profiles. + if (!use_implicit_batch_) { + auto profile = builder_->createOptimizationProfile(); + for (int i = 0; i < network_->getNbInputs(); ++i) { + auto name = network_->getInput(i)->getName(); + auto dims = network_->getInput(i)->getDimensions(); + profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN, dims); + profile->setDimensions(name, nvinfer1::OptProfileSelector::kOPT, dims); + profile->setDimensions(name, nvinfer1::OptProfileSelector::kMAX, dims); + } + config_->addOptimizationProfile(profile); + } + nvinfer1::ICudaEngine* engine = builder_->buildEngineWithConfig(*network_, *config_); +#else + nvinfer1::ICudaEngine* engine = builder_->buildCudaEngine(*network_); +#endif + CHECK_EQ(engine->getNbBindings(), network_input_names_.size() + network_output_names_.size()); + nvinfer1::IExecutionContext* context = engine->createExecutionContext(); + CleanUp(); + return {engine, context, network_input_names_, network_output_names_}; +} + +nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr, + DLDeviceType src_device) { + CHECK_EQ(dptr->ctx.device_type, src_device); + CHECK(static_cast(dptr->dtype.code) == kDLFloat || + static_cast(dptr->dtype.code) == kDLInt); + const auto trt_dtype = static_cast(dptr->dtype.code) == kDLFloat + ? nvinfer1::DataType::kFLOAT + : nvinfer1::DataType::kINT32; + const size_t weight_bytes = GetDataSize(*dptr); + nvinfer1::Weights weight{trt_dtype, nullptr, 0}; + size_t count = 1; + for (tvm_index_t i = 0; i < dptr->ndim; ++i) { + count *= dptr->shape[i]; + } + CHECK_EQ(count * 4, weight_bytes); + weight.count = count; + weight.values = new float[count]; + CHECK_EQ(TVMArrayCopyToBytes(const_cast(dptr), const_cast(weight.values), + weight_bytes), + 0) + << TVMGetLastError(); + trt_weights_.push_back(weight); + return weight; +} + +nvinfer1::ITensor* TensorRTBuilder::GetInputAsTensor(const TensorRTOpInput& input) { + if (input.type == kTensor) return input.tensor; + auto dims = VectorToTrtDims(input.weight_shape); + return network_->addConstant(dims, input.weight)->getOutput(0); +} + +void TensorRTBuilder::CleanUp() { + network_->destroy(); +#if TRT_VERSION_GE(6, 0, 1) + config_->destroy(); +#endif + builder_->destroy(); + for (auto weight : trt_weights_) { + if (weight.type == nvinfer1::DataType::kFLOAT) { + delete[] static_cast(weight.values); + } else { + delete[] static_cast(weight.values); + } + } +} + +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.h b/src/runtime/contrib/tensorrt/tensorrt_builder.h new file mode 100644 index 000000000000..efb4d8175650 --- /dev/null +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.h @@ -0,0 +1,159 @@ +/* * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file runtime/contrib/tensorrt/tensorrt_builder.h
+ * \brief The TensorRTBuilder class can be used to convert a JSONRuntime graph into a TRT engine
+ * which can be used for inference.
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_BUILDER_H_
+#define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_BUILDER_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "../json/json_node.h"
+#include "NvInfer.h"
+#include "tensorrt_logger.h"
+#include "tensorrt_ops.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
+
+/*!
+ * \brief The product of TensorRTBuilder which provides everything needed to
+ * perform inference.
+ */
+struct TensorRTEngineAndContext {
+  nvinfer1::ICudaEngine* engine;
+  nvinfer1::IExecutionContext* context;
+  std::vector<std::string> inputs;
+  std::vector<std::string> outputs;
+};
+
+/*!
+ * \brief Converts a JSONRuntime graph into a TensorRT engine and execution context. Inputs,
+ * constants, layers, and outputs can be added to construct the TensorRT network definition.
+ * BuildEngine() will then use the network definition to build the TensorRT engine and context
+ * which can be used to run inference. This phase can take a long time because TensorRT will
+ * query the performance of all available kernels and fusions to optimize the engine.
+ */
+class TensorRTBuilder {
+ public:
+  /*!
+   * \brief Create TensorRT builder.
+   * \param logger TensorRT logger to use for errors and warnings.
+   * \param max_workspace_size Workspace size parameter for TensorRT engine build phase.
+   * \param use_implicit_batch Whether to use implicit batch mode (default).
+   * \param use_fp16 Whether to automatically convert the network to FP16 precision.
+   * \param batch_size Batch size to build the engine for when use_implicit_batch is true.
+   */
+  TensorRTBuilder(TensorRTLogger* logger, size_t max_workspace_size, bool use_implicit_batch,
+                  bool use_fp16, int batch_size);
+
+  /*!
+   * \brief Add TensorRT input(s) for input node in network definition.
+   * \param nid The input node id.
+   * \param node The input node.
+   */
+  void AddInput(int nid, const JSONGraphNode& node);
+
+  /*!
+   * \brief Add TensorRT weight for input constant in network definition.
+   * \param nid The input node id.
+   * \param data The data tensor on CPU.
+   */
+  void AddConstant(int nid, const DLTensor* data);
+
+  /*!
+   * \brief Add TensorRT layer for op node in network definition.
+   * \param nid The input node id.
+   * \param node The op node.
+   */
+  void AddLayer(int nid, const JSONGraphNode& node);
+
+  /*!
+   * \brief Mark TensorRT output in network definition.
+   * \param entry The output node entry.
+   */
+  void AddOutput(const JSONGraphNodeEntry& entry);
+
+  /*!
+   * \brief Takes network definition and "compiles" a TensorRT engine which can be used for
+   * inference. This step is time consuming.
+   * \return TRT engine, context, and input/output information.
+ */ + TensorRTEngineAndContext BuildEngine(); + + private: + /*! \brief Convert a DLTensor to a TensorRT weight. */ + nvinfer1::Weights GetDLTensorAsWeights(const DLTensor* dptr, DLDeviceType src_device); + + /*! \brief Convert an input to a Tensor if it is a Weight */ + nvinfer1::ITensor* GetInputAsTensor(const TensorRTOpInput& input); + + /*! \brief Clean up resources used to create engine. */ + void CleanUp(); + + /*! \brief Maps a node to its outputs. */ + std::unordered_map> node_output_map_; + + /*! \brief TensorRT builder. */ + nvinfer1::IBuilder* builder_; + +#if TRT_VERSION_GE(6, 0, 1) + /*! \brief TensorRT builder config. */ + nvinfer1::IBuilderConfig* config_; +#endif + + /*! \brief TensorRT network definition. */ + nvinfer1::INetworkDefinition* network_; + + /*! \brief List of all weights held in memory. */ + std::vector trt_weights_; + + /*! \brief Max workspace size in bytes for TRT. */ + size_t max_workspace_size_; + + /*! \brief Whether to use implicit batch mode. */ + bool use_implicit_batch_; + + /*! \brief Whether to automatically convert model to 16-bit floating point precision. */ + bool use_fp16_; + + /*! \brief Batch size to optimize for. */ + int batch_size_; + + /*! \brief Input names. */ + std::vector network_input_names_; + + /*! \brief Output names. */ + std::vector network_output_names_; +}; + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_BUILDER_H_ diff --git a/src/runtime/contrib/tensorrt/tensorrt_logger.h b/src/runtime/contrib/tensorrt/tensorrt_logger.h new file mode 100644 index 000000000000..53b6dfeea763 --- /dev/null +++ b/src/runtime/contrib/tensorrt/tensorrt_logger.h @@ -0,0 +1,78 @@ +/* * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file runtime/contrib/tensorrt/tensorrt_logger.h + * \brief Contains TensorRTLogger class which is required by TRT and used to + * print info, warnings, and errors. + */ + +#ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_LOGGER_H_ +#define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_LOGGER_H_ + +#include + +#include "NvInfer.h" +#include "tensorrt_utils.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +/*! \brief Logger for TensorRT info/warning/errors. 
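+ * Messages less severe than the configured threshold (kWARNING by default) are dropped;
+ * everything else is forwarded to TVM's LOG/DLOG macros so that TensorRT diagnostics show
+ * up in the normal TVM log output.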
*/ +class TensorRTLogger : public nvinfer1::ILogger { + public: + TensorRTLogger() : TensorRTLogger(Severity::kWARNING) {} + explicit TensorRTLogger(Severity severity) : reportable_severity(severity) {} + void log(Severity severity, const char* msg) override { + // suppress messages with severity enum value greater than the reportable + if (severity > reportable_severity) return; + + switch (severity) { + case Severity::kINTERNAL_ERROR: + LOG(ERROR) << "INTERNAL_ERROR: " << msg; + break; + case Severity::kERROR: + LOG(ERROR) << "ERROR: " << msg; + break; + case Severity::kWARNING: + LOG(WARNING) << "WARNING: " << msg; + break; + case Severity::kINFO: + LOG(INFO) << "INFO: " << msg; + break; +#if TRT_VERSION_GE(5, 1, 5) + case Severity::kVERBOSE: + DLOG(INFO) << "VERBOSE: " << msg; + break; +#endif + default: + LOG(INFO) << "UNKNOWN: " << msg; + break; + } + } + + private: + Severity reportable_severity{Severity::kWARNING}; +}; + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_LOGGER_H_ diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc new file mode 100644 index 000000000000..a1da6c39f68e --- /dev/null +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc @@ -0,0 +1,1070 @@ +/* * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file runtime/contrib/tensorrt/tensorrt_ops.cc + * \brief Converters from Relay ops into TensorRT layers. Converters should + * inherit from TensorRTOpConverter and implement the Convert() method. + */ + +#include "tensorrt_ops.h" + +#include +#include +#include +#include +#include +#include + +#include "../json/json_node.h" +#include "NvInfer.h" +#include "tensorrt_utils.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +TensorRTOpConverter::TensorRTOpConverter(const std::vector& input_types, + bool variable_input_count) + : input_types(input_types), variable_input_count(variable_input_count) {} + +nvinfer1::ITensor* TensorRTOpConverter::Reshape(TensorRTOpConverterParams* params, + nvinfer1::ITensor* input, + const std::vector& new_shape) const { + auto layer = params->network->addShuffle(*input); + CHECK(layer != nullptr); + layer->setReshapeDimensions(VectorToTrtDims(new_shape)); + return layer->getOutput(0); +} + +nvinfer1::ITensor* TensorRTOpConverter::Transpose(TensorRTOpConverterParams* params, + nvinfer1::ITensor* input, + const std::vector& order) const { + auto layer = params->network->addShuffle(*input); + CHECK(layer != nullptr); + nvinfer1::Permutation perm; + if (TRT_HAS_IMPLICIT_BATCH(params)) { + // Batch dimension cannot be modified. 
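+    // In implicit batch mode the incoming order still includes the batch axis (which must
+    // stay at position 0), so it is dropped and every remaining axis is shifted down by one,
+    // e.g. a Relay order of {0, 3, 1, 2} (NHWC -> NCHW) becomes the TRT permutation {2, 0, 1}.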
+ CHECK_EQ(input->getDimensions().nbDims, order.size() - 1); + CHECK_EQ(order[0], 0); + for (size_t i = 0; i < order.size(); ++i) { + perm.order[i] = order[i + 1] - 1; + } + } else { + CHECK_EQ(input->getDimensions().nbDims, order.size()); + for (size_t i = 0; i < order.size(); ++i) { + perm.order[i] = order[i]; + } + } + layer->setFirstTranspose(perm); + return layer->getOutput(0); +} + +int TensorRTOpConverter::ConvertAxis(TensorRTOpConverterParams* params, int axis, + int input_rank) const { + // Add 1 for missing batch dim. + if (TRT_HAS_IMPLICIT_BATCH(params)) { + input_rank += 1; + } + CHECK(axis >= -input_rank && axis < input_rank); + if (axis < 0) axis += input_rank; + if (TRT_HAS_IMPLICIT_BATCH(params)) { + // Can't modify batch dimenson. + CHECK_NE(axis, 0); + // Subtract 1 for implicit batch dim. + axis -= 1; + } + return axis; +} + +nvinfer1::ITensor* TensorRTOpConverter::CreateScalar( + TensorRTOpConverterParams* params, float value, const nvinfer1::Dims& broadcast_to_dims) const { + nvinfer1::Dims dims; + dims.nbDims = broadcast_to_dims.nbDims; + std::fill_n(dims.d, dims.nbDims, 1); + float* values = new float[1]; + values[0] = value; + nvinfer1::Weights weights{nvinfer1::DataType::kFLOAT, static_cast(values), 1}; + params->trt_weights->push_back(weights); + return params->network->addConstant(dims, weights)->getOutput(0); +} + +void TensorRTOpConverter::GetPadding(const std::vector& padding, + bool* use_asymmetric_padding, nvinfer1::DimsHW* prepadding, + nvinfer1::DimsHW* postpadding) const { + CHECK(padding.size() == 1 || padding.size() == 2 || padding.size() == 4); + if (padding.size() == 4) { + // four int : padding width in the order of (top, left, bottom, right). + *prepadding = nvinfer1::DimsHW(std::stoi(padding[0]), std::stoi(padding[1])); + *postpadding = nvinfer1::DimsHW(std::stoi(padding[2]), std::stoi(padding[3])); + *use_asymmetric_padding = true; + } else if (padding.size() == 2) { + // two int : bottom, right will use same padding as top, left + *prepadding = nvinfer1::DimsHW(std::stoi(padding[0]), std::stoi(padding[1])); + *postpadding = *prepadding; + *use_asymmetric_padding = false; + } else { + // one int : same padding used on all sides + *prepadding = nvinfer1::DimsHW(std::stoi(padding[0]), std::stoi(padding[0])); + *postpadding = *prepadding; + *use_asymmetric_padding = false; + } +} + +void TensorRTOpConverter::GetPadding3D(const std::vector& padding, + bool* use_asymmetric_padding, nvinfer1::Dims* prepadding, + nvinfer1::Dims* postpadding) const { + CHECK(padding.size() == 1 || padding.size() == 3 || padding.size() == 6); + if (padding.size() == 6) { + // six int : padding width in the order of (front, top, left, back, bottom, right) + *prepadding = + nvinfer1::Dims3(std::stoi(padding[0]), std::stoi(padding[1]), std::stoi(padding[2])); + *postpadding = + nvinfer1::Dims3(std::stoi(padding[3]), std::stoi(padding[4]), std::stoi(padding[5])); + *use_asymmetric_padding = true; + } else if (padding.size() == 3) { + // three int : back, bottom, right will use same padding as front, top, left + *prepadding = + nvinfer1::Dims3(std::stoi(padding[0]), std::stoi(padding[1]), std::stoi(padding[2])); + *postpadding = *prepadding; + *use_asymmetric_padding = false; + } else { + // one int : same padding used on all sides + *prepadding = + nvinfer1::Dims3(std::stoi(padding[0]), std::stoi(padding[0]), std::stoi(padding[0])); + *postpadding = *prepadding; + *use_asymmetric_padding = false; + } +} + +class ActivationOpConverter : public TensorRTOpConverter { + public: + 
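+  // Maps nn.relu, sigmoid and tanh (plus clip and nn.leaky_relu on TensorRT >= 5.1.5) onto a
+  // single IActivationLayer; clip's a_min/a_max and leaky_relu's alpha are forwarded through
+  // setAlpha()/setBeta().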
ActivationOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + static const std::unordered_map op_map = { + {"nn.relu", nvinfer1::ActivationType::kRELU}, + {"sigmoid", nvinfer1::ActivationType::kSIGMOID}, + {"tanh", nvinfer1::ActivationType::kTANH}, +#if TRT_VERSION_GE(5, 1, 5) + {"clip", nvinfer1::ActivationType::kCLIP}, + {"nn.leaky_relu", nvinfer1::ActivationType::kLEAKY_RELU}, +#endif + }; + auto it = op_map.find(params->op_name); + CHECK(it != op_map.end()) << "Unsupported activation type " << params->op_name; + nvinfer1::IActivationLayer* act_layer = + params->network->addActivation(*params->inputs.at(0).tensor, it->second); +#if TRT_VERSION_GE(5, 1, 5) + if (params->op_name == "clip") { + float a_min = std::stof(params->node.GetAttr>("a_min")[0]); + float a_max = std::stof(params->node.GetAttr>("a_max")[0]); + act_layer->setAlpha(a_min); + act_layer->setBeta(a_max); + } else if (params->op_name == "nn.leaky_relu") { + float alpha = std::stof(params->node.GetAttr>("alpha")[0]); + act_layer->setAlpha(alpha); + } +#endif + CHECK(act_layer != nullptr); + params->outputs.push_back(act_layer->getOutput(0)); + } +}; + +class ElementWiseBinaryOpConverter : public TensorRTOpConverter { + public: + ElementWiseBinaryOpConverter() : TensorRTOpConverter({kTensor, kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + static const std::unordered_map op_map = { + {"add", nvinfer1::ElementWiseOperation::kSUM}, + {"subtract", nvinfer1::ElementWiseOperation::kSUB}, + {"multiply", nvinfer1::ElementWiseOperation::kPROD}, + {"divide", nvinfer1::ElementWiseOperation::kDIV}, + {"power", nvinfer1::ElementWiseOperation::kPOW}, + {"maximum", nvinfer1::ElementWiseOperation::kMAX}, + {"minimum", nvinfer1::ElementWiseOperation::kMIN}}; + auto it = op_map.find(params->op_name); + CHECK(it != op_map.end()) << "Unsupported elementwise type " << params->op_name; + // Broadcast + auto input0 = params->inputs.at(0).tensor; + auto input0_dims = TrtDimsToVector(input0->getDimensions()); + auto input1 = params->inputs.at(1).tensor; + auto input1_dims = TrtDimsToVector(input1->getDimensions()); + const bool need_broadcast = input0_dims.size() != input1_dims.size(); + if (need_broadcast) { + if (input0_dims.size() < input1_dims.size()) { + std::vector new_shape(input0_dims); + while (new_shape.size() < input1_dims.size()) new_shape.insert(new_shape.begin(), 1); + input0 = Reshape(params, input0, new_shape); + } else if (input1_dims.size() < input0_dims.size()) { + std::vector new_shape(input1_dims); + while (new_shape.size() < input0_dims.size()) new_shape.insert(new_shape.begin(), 1); + input1 = Reshape(params, input1, new_shape); + } + } + + nvinfer1::IElementWiseLayer* elemwise_layer = + params->network->addElementWise(*input0, *input1, it->second); + CHECK(elemwise_layer != nullptr); + params->outputs.push_back(elemwise_layer->getOutput(0)); + } +}; + +class Conv2DOpConverter : public TensorRTOpConverter { + public: + Conv2DOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); + auto weight_shape = params->inputs.at(1).weight_shape; + CHECK_EQ(params->node.GetAttr>("data_layout")[0], "NCHW"); + CHECK(params->node.GetAttr>("out_layout")[0] == "" || + params->node.GetAttr>("out_layout")[0] == "NCHW"); + CHECK_EQ(params->node.GetAttr>("kernel_layout")[0], 
"OIHW"); + auto str_strides = params->node.GetAttr>("strides"); + auto str_dilation = params->node.GetAttr>("dilation"); + auto str_padding = params->node.GetAttr>("padding"); + int groups = std::stoi(params->node.GetAttr>("groups")[0]); + int channels = std::stoi(params->node.GetAttr>("channels")[0]); + // TRT conv2d op doesn't support asymmetric padding before 5.1, so we + // workaround by adding a padding layer before the pooling op. + nvinfer1::DimsHW prepadding, postpadding; + bool use_asymmetric_padding; + GetPadding(str_padding, &use_asymmetric_padding, &prepadding, &postpadding); +#if !TRT_VERSION_GE(5, 1, 5) + if (use_asymmetric_padding) { + auto pad_layer = params->network->addPadding(*input_tensor, prepadding, postpadding); + CHECK(pad_layer != nullptr); + input_tensor = pad_layer->getOutput(0); + // No need for conv op to do any padding. + use_asymmetric_padding = false; + prepadding = nvinfer1::DimsHW(0, 0); + } +#endif + + const auto kernel_size = nvinfer1::DimsHW(weight_shape[2], weight_shape[3]); + nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + auto conv_layer = params->network->addConvolution(*input_tensor, channels, kernel_size, + params->inputs.at(1).weight, bias); + CHECK(conv_layer != nullptr); + if (use_asymmetric_padding) { +#if TRT_VERSION_GE(5, 1, 5) + conv_layer->setPrePadding(prepadding); + conv_layer->setPostPadding(postpadding); +#endif + } else { + conv_layer->setPadding(prepadding); + } + CHECK_EQ(str_strides.size(), 2); + const auto strides = nvinfer1::DimsHW(std::stoi(str_strides[0]), std::stoi(str_strides[1])); + conv_layer->setStride(strides); + CHECK_EQ(str_dilation.size(), 2); + const auto dilation = nvinfer1::DimsHW(std::stoi(str_dilation[0]), std::stoi(str_dilation[1])); + conv_layer->setDilation(dilation); + conv_layer->setNbGroups(groups); + params->outputs.push_back(conv_layer->getOutput(0)); + } +}; + +#if TRT_VERSION_GE(6, 0, 1) +class Conv3DOpConverter : public TensorRTOpConverter { + public: + Conv3DOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); + auto weight_shape = params->inputs.at(1).weight_shape; + CHECK_EQ(params->node.GetAttr>("data_layout")[0], "NCDHW"); + CHECK(params->node.GetAttr>("out_layout")[0] == "" || + params->node.GetAttr>("out_layout")[0] == "NCDHW"); + CHECK_EQ(params->node.GetAttr>("kernel_layout")[0], "OIDHW"); + auto str_strides = params->node.GetAttr>("strides"); + auto str_dilation = params->node.GetAttr>("dilation"); + auto str_padding = params->node.GetAttr>("padding"); + int groups = std::stoi(params->node.GetAttr>("groups")[0]); + + nvinfer1::Dims prepadding, postpadding; + bool use_asymmetric_padding; + GetPadding3D(str_padding, &use_asymmetric_padding, &prepadding, &postpadding); + + // Could use attrs->channels.as()->value + const int num_outputs = weight_shape[0]; + const auto kernel_size = nvinfer1::Dims3(weight_shape[2], weight_shape[3], weight_shape[4]); + nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + auto conv_layer = params->network->addConvolutionNd(*input_tensor, num_outputs, kernel_size, + params->inputs.at(1).weight, bias); + CHECK(conv_layer != nullptr); + if (use_asymmetric_padding) { + conv_layer->setPrePadding(prepadding); + conv_layer->setPostPadding(postpadding); + } else { + conv_layer->setPaddingNd(prepadding); + } + CHECK_EQ(str_strides.size(), 3); + const auto strides 
= nvinfer1::Dims3(std::stoi(str_strides[0]), std::stoi(str_strides[1]), + std::stoi(str_strides[2])); + conv_layer->setStrideNd(strides); + CHECK_EQ(str_dilation.size(), 3); + const auto dilation = nvinfer1::Dims3(std::stoi(str_dilation[0]), std::stoi(str_dilation[1]), + std::stoi(str_dilation[2])); + conv_layer->setDilationNd(dilation); + conv_layer->setNbGroups(groups); + params->outputs.push_back(conv_layer->getOutput(0)); + } +}; +#endif // TRT_VERSION_GE(6, 0, 1) + +class DenseOpConverter : public TensorRTOpConverter { + public: + DenseOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); + CHECK(input_dims.size() > 0 && input_dims.size() <= 3); + const size_t required_rank = TRT_HAS_IMPLICIT_BATCH(params) ? 3 : 4; + const bool need_reshape_on_input = input_dims.size() != required_rank; + if (need_reshape_on_input) { + // Add dims of size 1 until rank is required_rank. + std::vector new_shape(input_dims); + while (new_shape.size() < required_rank) new_shape.insert(new_shape.end(), 1); + input_tensor = Reshape(params, input_tensor, new_shape); + } + // Weights are in KC format. + CHECK_EQ(params->inputs.at(1).weight_shape.size(), 2); + const int num_units = params->inputs.at(1).weight_shape[0]; + nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + nvinfer1::IFullyConnectedLayer* fc_layer = params->network->addFullyConnected( + *input_tensor, num_units, params->inputs.at(1).weight, bias); + CHECK(fc_layer != nullptr); + auto output_tensor = fc_layer->getOutput(0); + if (need_reshape_on_input) { + // Remove added dims. + input_dims[input_dims.size() - 1] = num_units; + output_tensor = Reshape(params, output_tensor, input_dims); + } + params->outputs.push_back(output_tensor); + } +}; + +class BatchNormOpConverter : public TensorRTOpConverter { + public: + BatchNormOpConverter() : TensorRTOpConverter({kTensor, kWeight, kWeight, kWeight, kWeight}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + auto gamma = params->inputs.at(1).weight; + auto beta = params->inputs.at(2).weight; + auto mean = params->inputs.at(3).weight; + auto var = params->inputs.at(4).weight; + CHECK_EQ(gamma.count, beta.count); + CHECK_EQ(gamma.count, mean.count); + CHECK_EQ(gamma.count, var.count); + const float epsilon = std::stof(params->node.GetAttr>("epsilon")[0]); + const int axis = std::stoi(params->node.GetAttr>("axis")[0]); + const bool scale = std::stoi(params->node.GetAttr>("scale")[0]); + const bool center = std::stoi(params->node.GetAttr>("center")[0]); + CHECK(axis == 1 || axis == 3); + const bool need_transpose = axis == 3; + + void* weight_scale_ptr = new float[gamma.count]; + nvinfer1::Weights weight_scale{nvinfer1::DataType::kFLOAT, weight_scale_ptr, gamma.count}; + params->trt_weights->push_back(weight_scale); + void* weight_shift_ptr = new float[gamma.count]; + nvinfer1::Weights weight_shift{nvinfer1::DataType::kFLOAT, weight_shift_ptr, gamma.count}; + params->trt_weights->push_back(weight_shift); + nvinfer1::Weights power{nvinfer1::DataType::kFLOAT, nullptr, 0}; + + // fill in the content of weights for the Scale layer + const float* gamma_ptr = reinterpret_cast(gamma.values); + const float* beta_ptr = reinterpret_cast(beta.values); + const float* mean_ptr = reinterpret_cast(mean.values); + const float* var_ptr = reinterpret_cast(var.values); + 
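+    // Fold the batch norm into one IScaleLayer:
+    //   scale = gamma / sqrt(var + epsilon)   (gamma omitted when scale == false)
+    //   shift = beta - mean * scale           (beta omitted when center == false)
+    // so that y = x * scale + shift reproduces gamma * (x - mean) / sqrt(var + eps) + beta.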
float* scale_ptr = reinterpret_cast(weight_scale_ptr); + float* shift_ptr = reinterpret_cast(weight_shift_ptr); + for (int i = 0; i < gamma.count; ++i) { + scale_ptr[i] = 1.0 / std::sqrt(var_ptr[i] + epsilon); + if (scale) { + scale_ptr[i] *= gamma_ptr[i]; + } + shift_ptr[i] = -mean_ptr[i] * scale_ptr[i]; + if (center) { + shift_ptr[i] += beta_ptr[i]; + } + } + if (need_transpose) { + input = Transpose(params, input, {0, 3, 1, 2}); + } + nvinfer1::IScaleLayer* scale_layer = params->network->addScale( + *input, nvinfer1::ScaleMode::kCHANNEL, weight_shift, weight_scale, power); + CHECK(scale_layer != nullptr); + auto output = scale_layer->getOutput(0); + if (need_transpose) { + output = Transpose(params, output, {0, 2, 3, 1}); + } + params->outputs.push_back(output); + } +}; + +class BatchFlattenOpConverter : public TensorRTOpConverter { + public: + BatchFlattenOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + std::vector new_shape{-1}; + if (!TRT_HAS_IMPLICIT_BATCH(params)) { + new_shape.insert(new_shape.begin(), params->inputs.at(0).tensor->getDimensions().d[0]); + } + params->outputs.push_back(Reshape(params, params->inputs.at(0).tensor, new_shape)); + } +}; + +class SoftmaxOpConverter : public TensorRTOpConverter { + public: + SoftmaxOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + const int input_rank = input->getDimensions().nbDims; + const int original_axis = std::stoi(params->node.GetAttr>("axis")[0]); + const int axis = ConvertAxis(params, original_axis, input_rank); + nvinfer1::ISoftMaxLayer* softmax_layer = params->network->addSoftMax(*input); + softmax_layer->setAxes(1 << axis); + CHECK(softmax_layer != nullptr); + params->outputs.push_back(softmax_layer->getOutput(0)); + } +}; + +class PoolingOpConverter : public TensorRTOpConverter { + public: + PoolingOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + static const std::unordered_map op_map = { + {"nn.max_pool2d", nvinfer1::PoolingType::kMAX}, + {"nn.avg_pool2d", nvinfer1::PoolingType::kAVERAGE}}; + auto it = op_map.find(params->op_name); + CHECK(it != op_map.end()) << "Unsupported pooling type " << params->op_name << " in TensorRT"; + CHECK_EQ(params->node.GetAttr>("layout")[0], "NCHW"); + auto str_pool_size = params->node.GetAttr>("pool_size"); + auto str_padding = params->node.GetAttr>("padding"); + auto str_strides = params->node.GetAttr>("strides"); + nvinfer1::DimsHW prepadding, postpadding; + bool use_asymmetric_padding; + GetPadding(str_padding, &use_asymmetric_padding, &prepadding, &postpadding); + bool ceil_mode = std::stoi(params->node.GetAttr>("ceil_mode")[0]); + +// TRT pooling op doesn't support asymmetric padding before 5.1, so we +// workaround by adding a padding layer before the pooling op. +#if !TRT_VERSION_GE(5, 1, 5) + if (use_asymmetric_padding) { + auto pad_layer = params->network->addPadding(*input, prepadding, postpadding); + CHECK(pad_layer != nullptr); + input = pad_layer->getOutput(0); + // No need for pooling op to do any padding. 
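+    // The IPaddingLayer added above has already applied the asymmetric padding, so the
+    // pooling layer below is configured with zero padding. On TRT >= 5.1.5 this branch is
+    // compiled out and setPrePadding/setPostPadding are used instead.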
+ use_asymmetric_padding = false; + prepadding = nvinfer1::DimsHW(0, 0); + } +#endif + + nvinfer1::DimsHW window_size = + nvinfer1::DimsHW(std::stoi(str_pool_size[0]), std::stoi(str_pool_size[1])); + auto pool_layer = params->network->addPooling(*input, it->second, window_size); + CHECK(pool_layer != nullptr); + nvinfer1::DimsHW strides = + nvinfer1::DimsHW(std::stoi(str_strides[0]), std::stoi(str_strides[1])); + pool_layer->setStride(strides); + if (use_asymmetric_padding) { +#if TRT_VERSION_GE(5, 1, 5) + pool_layer->setPrePadding(prepadding); + pool_layer->setPostPadding(postpadding); +#endif + } else { + pool_layer->setPadding(prepadding); + } + if (params->op_name == "nn.avg_pool2d") { + bool count_include_pad = + std::stoi(params->node.GetAttr>("count_include_pad")[0]); + // count_include_pad=True is useless if there is no padding. TRT doesn't + // like count_include_pad in combination with strides even when there is + // no padding or assymetric padding even, so turn off inclusive to avoid + // error message. Note: Padding will always be symmetric with + // count_include_pad since partitioner will prevent unsupported case. + if (prepadding.h() == 0 && prepadding.w() == 0) { + count_include_pad = false; + } + pool_layer->setAverageCountExcludesPadding(!count_include_pad); + } +#if TRT_VERSION_GE(5, 1, 5) + if (ceil_mode) { + pool_layer->setPaddingMode(nvinfer1::PaddingMode::kEXPLICIT_ROUND_UP); + } +#else + CHECK(!ceil_mode); +#endif + params->outputs.push_back(pool_layer->getOutput(0)); + } +}; + +#if TRT_VERSION_GE(6, 0, 1) +class Pooling3DOpConverter : public TensorRTOpConverter { + public: + Pooling3DOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + static const std::unordered_map op_map = { + {"nn.max_pool3d", nvinfer1::PoolingType::kMAX}, + {"nn.avg_pool3d", nvinfer1::PoolingType::kAVERAGE}}; + auto it = op_map.find(params->op_name); + CHECK(it != op_map.end()) << "Unsupported pooling type " << params->op_name << " in TensorRT"; + CHECK_EQ(params->node.GetAttr>("layout")[0], "NCDHW"); + auto str_pool_size = params->node.GetAttr>("pool_size"); + auto str_padding = params->node.GetAttr>("padding"); + auto str_strides = params->node.GetAttr>("strides"); + nvinfer1::DimsHW prepadding, postpadding; + bool use_asymmetric_padding; + GetPadding3D(str_padding, &use_asymmetric_padding, &prepadding, &postpadding); + bool ceil_mode = std::stoi(params->node.GetAttr>("ceil_mode")[0]); + nvinfer1::Dims window_size = nvinfer1::Dims3( + std::stoi(str_pool_size[0]), std::stoi(str_pool_size[1]), std::stoi(str_pool_size[2])); + auto pool_layer = params->network->addPoolingNd(*input, it->second, window_size); + CHECK(pool_layer != nullptr); + nvinfer1::Dims strides = nvinfer1::Dims3(std::stoi(str_strides[0]), std::stoi(str_strides[1]), + std::stoi(str_strides[2])); + pool_layer->setStrideNd(strides); + if (use_asymmetric_padding) { + pool_layer->setPrePadding(prepadding); + pool_layer->setPostPadding(postpadding); + } else { + pool_layer->setPaddingNd(prepadding); + } + if (params->op_name == "nn.avg_pool3d") { + bool count_include_pad = + std::stoi(params->node.GetAttr>("count_include_pad")[0]); + pool_layer->setAverageCountExcludesPadding(!count_include_pad); + } + if (ceil_mode) { + pool_layer->setPaddingMode(nvinfer1::PaddingMode::kEXPLICIT_ROUND_UP); + } + params->outputs.push_back(pool_layer->getOutput(0)); + } +}; +#endif // TRT_VERSION_GE(6, 0, 1) + +class GlobalPoolingOpConverter : 
public TensorRTOpConverter { + public: + GlobalPoolingOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); + static const std::unordered_map op_map = { + {"nn.global_max_pool2d", nvinfer1::PoolingType::kMAX}, + {"nn.global_avg_pool2d", nvinfer1::PoolingType::kAVERAGE}}; + auto it = op_map.find(params->op_name); + CHECK(it != op_map.end()) << "Unsupported pooling type " << params->op_name << " in TensorRT"; + CHECK_EQ(params->node.GetAttr>("layout")[0], "NCHW"); + const int h = TRT_HAS_IMPLICIT_BATCH(params) ? input_dims[1] : input_dims[2]; + const int w = TRT_HAS_IMPLICIT_BATCH(params) ? input_dims[2] : input_dims[3]; + auto pool_layer = + params->network->addPooling(*input_tensor, it->second, nvinfer1::DimsHW(h, w)); + CHECK(pool_layer != nullptr); + params->outputs.push_back(pool_layer->getOutput(0)); + } +}; + +class ExpandDimsOpConverter : public TensorRTOpConverter { + public: + ExpandDimsOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); + const int original_axis = std::stoi(params->node.GetAttr>("axis")[0]); + const int num_newaxis = + std::stoi(params->node.GetAttr>("num_newaxis")[0]); + const int axis = ConvertAxis(params, original_axis, input_dims.size() + 1); + for (int i = 0; i < num_newaxis; ++i) { + input_dims.insert(input_dims.begin() + axis, 1); + } + params->outputs.push_back(Reshape(params, params->inputs.at(0).tensor, input_dims)); + } +}; + +class SqueezeOpConverter : public TensorRTOpConverter { + public: + SqueezeOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); + auto str_axis = params->node.GetAttr>("axis"); + for (size_t i = 0; i < str_axis.size(); ++i) { + const int axis = ConvertAxis(params, std::stoi(str_axis[i]), input_dims.size()); + input_dims[axis] = 0; + } + input_dims.erase(std::remove(input_dims.begin(), input_dims.end(), 0), input_dims.end()); + params->outputs.push_back(Reshape(params, params->inputs.at(0).tensor, input_dims)); + } +}; + +class UnaryOpConverter : public TensorRTOpConverter { + public: + UnaryOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + // The following ops are supported by TRT but don't exist in relay yet: + // recip, tan, sinh, cosh, asin, acos, asinh, acosh, atanh + static const std::unordered_map op_map = { + {"exp", nvinfer1::UnaryOperation::kEXP}, + {"log", nvinfer1::UnaryOperation::kLOG}, + {"sqrt", nvinfer1::UnaryOperation::kSQRT}, + {"abs", nvinfer1::UnaryOperation::kABS}, + {"negative", nvinfer1::UnaryOperation::kNEG}, +#if TRT_VERSION_GE(5, 1, 5) + {"sin", nvinfer1::UnaryOperation::kSIN}, + {"cos", nvinfer1::UnaryOperation::kCOS}, + {"atan", nvinfer1::UnaryOperation::kATAN}, + {"ceil", nvinfer1::UnaryOperation::kCEIL}, + {"floor", nvinfer1::UnaryOperation::kFLOOR}, +#endif + }; + auto it = op_map.find(params->op_name); + CHECK(it != op_map.end()) << "Unsupported unary type " << params->op_name; + nvinfer1::IUnaryLayer* unary_layer = + params->network->addUnary(*params->inputs.at(0).tensor, it->second); + CHECK(unary_layer != nullptr); + 
params->outputs.push_back(unary_layer->getOutput(0)); + } +}; + +class ConcatOpConverter : public TensorRTOpConverter { + public: + ConcatOpConverter() : TensorRTOpConverter({}, /*variable_input_count=*/true) {} + + void Convert(TensorRTOpConverterParams* params) const { + const int num_inputs = params->inputs.size(); + CHECK_GT(num_inputs, 0); + const int input_rank = params->inputs[0].tensor->getDimensions().nbDims; + std::vector input_tensors; + for (auto input : params->inputs) { + CHECK(input.type == kTensor); + CHECK_EQ(input_rank, input.tensor->getDimensions().nbDims); + input_tensors.push_back(input.tensor); + } + + const int original_axis = std::stoi(params->node.GetAttr>("axis")[0]); + const int axis = ConvertAxis(params, original_axis, input_rank); + + nvinfer1::IConcatenationLayer* concat_layer = + params->network->addConcatenation(input_tensors.data(), input_tensors.size()); + CHECK(concat_layer != nullptr); + concat_layer->setAxis(axis); + params->outputs.push_back(concat_layer->getOutput(0)); + } +}; + +class BiasAddOpConverter : public TensorRTOpConverter { + public: + BiasAddOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); + const size_t required_rank = TRT_HAS_IMPLICIT_BATCH(params) ? 3 : 4; + CHECK(input_dims.size() > 0 && input_dims.size() <= required_rank); + const bool need_reshape_on_input = input_dims.size() != required_rank; + if (need_reshape_on_input) { + // Add dims of size 1 until rank is required_rank. + std::vector new_shape(input_dims); + while (new_shape.size() < required_rank) new_shape.insert(new_shape.end(), 1); + input_tensor = Reshape(params, input_tensor, new_shape); + } + + nvinfer1::Weights shift{nvinfer1::DataType::kFLOAT, nullptr, 0}; + nvinfer1::Weights power{nvinfer1::DataType::kFLOAT, nullptr, 0}; + nvinfer1::IScaleLayer* scale_layer = params->network->addScale( + *input_tensor, nvinfer1::ScaleMode::kCHANNEL, params->inputs.at(1).weight, shift, power); + CHECK(scale_layer != nullptr); + auto output_tensor = scale_layer->getOutput(0); + if (need_reshape_on_input) { + // Remove added dims. + output_tensor = Reshape(params, output_tensor, input_dims); + } + params->outputs.push_back(output_tensor); + } +}; + +class Conv2DTransposeOpConverter : public TensorRTOpConverter { + public: + Conv2DTransposeOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto weight_shape = params->inputs.at(1).weight_shape; + CHECK_EQ(params->node.GetAttr>("data_layout")[0], "NCHW"); + CHECK(params->node.GetAttr>("out_layout")[0] == "" || + params->node.GetAttr>("out_layout")[0] == "NCHW"); + CHECK_EQ(params->node.GetAttr>("kernel_layout")[0], "OIHW"); + auto str_dilation = params->node.GetAttr>("dilation"); + CHECK(std::stoi(str_dilation[0]) == 1 && std::stoi(str_dilation[1]) == 1); + auto str_strides = params->node.GetAttr>("strides"); + auto str_padding = params->node.GetAttr>("padding"); + auto str_output_padding = params->node.GetAttr>("output_padding"); + int groups = std::stoi(params->node.GetAttr>("groups")[0]); + + // TRT deconv op doesn't support asymmetric padding before 5.1, so we + // workaround by adding a padding layer before the pooling op. 
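+    // (Same pre-5.1.5 asymmetric padding workaround as the conv2d and pooling converters,
+    // applied here before the deconvolution layer.)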
+ nvinfer1::DimsHW prepadding, postpadding; + bool use_asymmetric_padding; + GetPadding(str_padding, &use_asymmetric_padding, &prepadding, &postpadding); +#if !TRT_VERSION_GE(5, 1, 5) + if (use_asymmetric_padding) { + auto pad_layer = params->network->addPadding(*input_tensor, prepadding, postpadding); + CHECK(pad_layer != nullptr); + input_tensor = pad_layer->getOutput(0); + // No need for conv op to do any padding. + use_asymmetric_padding = false; + prepadding = nvinfer1::DimsHW(0, 0); + } +#endif + + // Could use conv2d_attr->channels.as()->value + const int num_outputs = weight_shape[1]; + const auto kernel_size = nvinfer1::DimsHW(weight_shape[2], weight_shape[3]); + nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + auto deconv_layer = params->network->addDeconvolution(*input_tensor, num_outputs, kernel_size, + params->inputs.at(1).weight, bias); + CHECK(deconv_layer != nullptr); + if (use_asymmetric_padding) { +#if TRT_VERSION_GE(5, 1, 5) + deconv_layer->setPrePadding(prepadding); + deconv_layer->setPostPadding(postpadding); +#endif + } else { + deconv_layer->setPadding(prepadding); + } + const auto strides = nvinfer1::DimsHW(std::stoi(str_strides[0]), std::stoi(str_strides[1])); + deconv_layer->setStride(strides); + deconv_layer->setNbGroups(groups); + nvinfer1::ITensor* output = deconv_layer->getOutput(0); + // Output padding. + if (str_output_padding.size()) { + GetPadding(str_output_padding, &use_asymmetric_padding, &prepadding, &postpadding); + if (prepadding.h() != 0 || prepadding.w() != 0 || postpadding.h() != 0 || + postpadding.w() != 0) { + // Output padding for Conv2D transpose is always asymmetric and applied to post only. + prepadding = nvinfer1::DimsHW(0, 0); + auto pad_layer = params->network->addPadding(*output, prepadding, postpadding); + output = pad_layer->getOutput(0); + } + } + params->outputs.push_back(output); + } +}; + +#if TRT_VERSION_GE(6, 0, 1) +class Conv3DTransposeOpConverter : public TensorRTOpConverter { + public: + Conv3DTransposeOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto weight_shape = params->inputs.at(1).weight_shape; + CHECK_EQ(params->node.GetAttr>("data_layout")[0], "NCDHW"); + CHECK(params->node.GetAttr>("out_layout")[0] == "" || + params->node.GetAttr>("out_layout")[0] == "NCDHW"); + CHECK_EQ(params->node.GetAttr>("kernel_layout")[0], "OIDHW"); + auto str_dilation = params->node.GetAttr>("dilation"); + CHECK_EQ(str_dilation.size(), 3); + CHECK(std::stoi(str_dilation[0]) == 1 && std::stoi(str_dilation[1]) == 1 && + std::stoi(str_dilation[2]) == 1); + auto str_strides = params->node.GetAttr>("strides"); + auto str_padding = params->node.GetAttr>("padding"); + auto str_output_padding = params->node.GetAttr>("output_padding"); + int groups = std::stoi(params->node.GetAttr>("groups")[0]); + nvinfer1::Dims prepadding, postpadding; + bool use_asymmetric_padding; + GetPadding3D(str_padding, &use_asymmetric_padding, &prepadding, &postpadding); + + // Could use attrs->channels.as()->value + const int num_outputs = weight_shape[1]; + const auto kernel_size = nvinfer1::Dims3(weight_shape[2], weight_shape[3], weight_shape[4]); + nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + auto deconv_layer = params->network->addDeconvolutionNd(*input_tensor, num_outputs, kernel_size, + params->inputs.at(1).weight, bias); + CHECK(deconv_layer != nullptr); + if (use_asymmetric_padding) { + 
deconv_layer->setPrePadding(prepadding); + deconv_layer->setPostPadding(postpadding); + } else { + deconv_layer->setPaddingNd(prepadding); + } + CHECK_EQ(str_strides.size(), 3); + const auto strides = nvinfer1::Dims3(std::stoi(str_strides[0]), std::stoi(str_strides[1]), + std::stoi(str_strides[2])); + deconv_layer->setStrideNd(strides); + deconv_layer->setNbGroups(groups); + nvinfer1::ITensor* output = deconv_layer->getOutput(0); + // Output padding. + if (str_output_padding.size()) { + GetPadding3D(str_output_padding, &use_asymmetric_padding, &prepadding, &postpadding); + // Are any post-padding values non-zero? + CHECK(!std::any_of(postpadding.d, postpadding.d + postpadding.nbDims, [](int x) { + return x != 0; + })) << "TRT does not support padding on 3 dimensions."; + } + params->outputs.push_back(output); + } +}; +#endif // TRT_VERSION_GE(6, 0, 1) + +class TransposeOpConverter : public TensorRTOpConverter { + public: + TransposeOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + auto str_axes = params->node.GetAttr>("axes"); + std::vector order; + for (size_t i = 0; i < str_axes.size(); ++i) { + order.push_back(std::stoi(str_axes[i])); + } + params->outputs.push_back(Transpose(params, input, order)); + } +}; + +class LayoutTransformOpConverter : public TensorRTOpConverter { + public: + LayoutTransformOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + auto src = params->node.GetAttr>("src_layout")[0]; + auto dst = params->node.GetAttr>("dst_layout")[0]; + std::vector order; + if (src == "NCHW" && dst == "NHWC") { + order = {0, 2, 3, 1}; + } else if (src == "NHWC" && dst == "NCHW") { + order = {0, 3, 1, 2}; + } else if (src == "NDHWC" && dst == "NCDHW") { + order = {0, 4, 1, 2, 3}; + } else if (src == "NCDHW" && dst == "NDHWC") { + order = {0, 2, 3, 4, 1}; + } + params->outputs.push_back(Transpose(params, input, order)); + } +}; + +class ReshapeOpConverter : public TensorRTOpConverter { + public: + ReshapeOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + CHECK_EQ(std::stoi(params->node.GetAttr>("reverse")[0]), false); + auto str_newshape = params->node.GetAttr>("newshape"); + std::vector new_shape; + const int start_index = TRT_HAS_IMPLICIT_BATCH(params) ? 
1 : 0; + for (size_t i = start_index; i < str_newshape.size(); ++i) { + const int value = std::stoi(str_newshape[i]); + CHECK_GE(value, -1); + new_shape.push_back(value); + } + params->outputs.push_back(Reshape(params, input, new_shape)); + } +}; + +class PadOpConverter : public TensorRTOpConverter { + public: + PadOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + auto str_paddding = params->node.GetAttr>("padding"); + nvinfer1::DimsHW prepadding = + nvinfer1::DimsHW(std::stoi(str_paddding[0]), std::stoi(str_paddding[1])); + nvinfer1::DimsHW postpadding = + nvinfer1::DimsHW(std::stoi(str_paddding[2]), std::stoi(str_paddding[3])); + auto pad_layer = params->network->addPadding(*input, prepadding, postpadding); + params->outputs.push_back(pad_layer->getOutput(0)); + } +}; + +class ReduceOpConverter : public TensorRTOpConverter { + public: + ReduceOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + static const std::unordered_map op_map = { + {"sum", nvinfer1::ReduceOperation::kSUM}, + {"prod", nvinfer1::ReduceOperation::kPROD}, + {"max", nvinfer1::ReduceOperation::kMAX}, + {"min", nvinfer1::ReduceOperation::kMIN}, + {"mean", nvinfer1::ReduceOperation::kAVG}}; + auto it = op_map.find(params->op_name); + CHECK(it != op_map.end()) << "Unsupported reduce type " << params->op_name; + + auto input = params->inputs.at(0).tensor; + CHECK_EQ(std::stoi(params->node.GetAttr>("exclude")[0]), false); + bool keepdims = std::stoi(params->node.GetAttr>("keepdims")[0]); + auto str_axis = params->node.GetAttr>("axis"); + // TODO(trevmorr): Support reduce to scalar. + CHECK_GT(str_axis.size(), 0); + uint32_t reduce_axes = 0; + for (size_t i = 0; i < str_axis.size(); ++i) { + const int axis = ConvertAxis(params, std::stoi(str_axis[i]), input->getDimensions().nbDims); + reduce_axes |= 1 << axis; + } + auto reduce_layer = params->network->addReduce(*input, it->second, reduce_axes, keepdims); + params->outputs.push_back(reduce_layer->getOutput(0)); + } +}; + +#if TRT_VERSION_GE(5, 1, 5) +class StridedSliceOpConverter : public TensorRTOpConverter { + public: + StridedSliceOpConverter() : TensorRTOpConverter({kTensor, kWeight, kWeight, kWeight}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input->getDimensions()); + auto str_start = params->node.GetAttr>("start"); + auto str_size = params->node.GetAttr>("size"); + auto str_strides = params->node.GetAttr>("strides"); + std::vector start, size, strides; + std::transform(str_start.begin(), str_start.end(), std::back_inserter(start), + [](const std::string& s) { return std::stoi(s); }); + std::transform(str_size.begin(), str_size.end(), std::back_inserter(size), + [](const std::string& s) { return std::stoi(s); }); + std::transform(str_strides.begin(), str_strides.end(), std::back_inserter(strides), + [](const std::string& s) { return std::stoi(s); }); + if (TRT_HAS_IMPLICIT_BATCH(params)) { + start.erase(start.begin()); + size.erase(size.begin()); + strides.erase(strides.begin()); + } + auto slice_layer = params->network->addSlice(*input, VectorToTrtDims(start), + VectorToTrtDims(size), VectorToTrtDims(strides)); + params->outputs.push_back(slice_layer->getOutput(0)); + } +}; +#endif + +class AdaptivePoolingOpConverter : public TensorRTOpConverter { + public: + AdaptivePoolingOpConverter() : 
TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); + static const std::unordered_map op_map = { + {"nn.adaptive_max_pool2d", nvinfer1::PoolingType::kMAX}, + {"nn.adaptive_avg_pool2d", nvinfer1::PoolingType::kAVERAGE}}; + auto it = op_map.find(params->op_name); + CHECK(it != op_map.end()) << "Unsupported pooling type " << params->op_name << " in TensorRT"; + CHECK_EQ(params->node.GetAttr>("layout")[0], "NCHW"); + + // This is an approximation of adaptive pooling. Results will not be + // mathematically exact except when output_size is (1, 1). + // Annotation rules will only allow output size of (1, 1). + auto output_size = nvinfer1::DimsHW(1, 1); + const int h = TRT_HAS_IMPLICIT_BATCH(params) ? input_dims[1] : input_dims[2]; + const int w = TRT_HAS_IMPLICIT_BATCH(params) ? input_dims[2] : input_dims[3]; + const auto stride = nvinfer1::DimsHW(h / output_size.h(), w / output_size.w()); + const auto window_size = nvinfer1::DimsHW(h - (output_size.h() - 1) * stride.h(), + w - (output_size.w() - 1) * stride.w()); + auto pool_layer = params->network->addPooling(*input_tensor, it->second, window_size); + CHECK(pool_layer != nullptr); + pool_layer->setStride(stride); + params->outputs.push_back(pool_layer->getOutput(0)); + } +}; + +const std::shared_ptr>> +GetOpConverters() { + static auto map = + std::make_shared>>(); + if (!map->empty()) return map; + map->emplace("nn.relu", std::make_shared()); + map->emplace("sigmoid", std::make_shared()); + map->emplace("tanh", std::make_shared()); + map->emplace("nn.batch_norm", std::make_shared()); + map->emplace("nn.softmax", std::make_shared()); + map->emplace("nn.conv2d", std::make_shared()); + map->emplace("nn.dense", std::make_shared()); + map->emplace("nn.bias_add", std::make_shared()); + map->emplace("add", std::make_shared()); + map->emplace("subtract", std::make_shared()); + map->emplace("multiply", std::make_shared()); + map->emplace("divide", std::make_shared()); + map->emplace("power", std::make_shared()); + map->emplace("maximum", std::make_shared()); + map->emplace("minimum", std::make_shared()); + map->emplace("nn.max_pool2d", std::make_shared()); + map->emplace("nn.avg_pool2d", std::make_shared()); + map->emplace("nn.global_max_pool2d", std::make_shared()); + map->emplace("nn.global_avg_pool2d", std::make_shared()); + map->emplace("exp", std::make_shared()); + map->emplace("log", std::make_shared()); + map->emplace("sqrt", std::make_shared()); + map->emplace("abs", std::make_shared()); + map->emplace("negative", std::make_shared()); + map->emplace("nn.batch_flatten", std::make_shared()); + map->emplace("expand_dims", std::make_shared()); + map->emplace("squeeze", std::make_shared()); + map->emplace("concatenate", std::make_shared()); + map->emplace("nn.conv2d_transpose", std::make_shared()); + map->emplace("transpose", std::make_shared()); + map->emplace("layout_transform", std::make_shared()); + map->emplace("reshape", std::make_shared()); + map->emplace("nn.pad", std::make_shared()); + map->emplace("sum", std::make_shared()); + map->emplace("prod", std::make_shared()); + map->emplace("max", std::make_shared()); + map->emplace("min", std::make_shared()); + map->emplace("mean", std::make_shared()); + map->emplace("nn.adaptive_max_pool2d", std::make_shared()); + map->emplace("nn.adaptive_avg_pool2d", std::make_shared()); +#if TRT_VERSION_GE(5, 1, 5) + map->emplace("clip", 
std::make_shared()); + map->emplace("nn.leaky_relu", std::make_shared()); + map->emplace("sin", std::make_shared()); + map->emplace("cos", std::make_shared()); + map->emplace("atan", std::make_shared()); + map->emplace("ceil", std::make_shared()); + map->emplace("floor", std::make_shared()); + map->emplace("strided_slice", std::make_shared()); +#endif // TRT_VERSION_GE(5, 1, 5) +#if TRT_VERSION_GE(6, 0, 1) + map->emplace("nn.conv3d", std::make_shared()); + map->emplace("nn.max_pool3d", std::make_shared()); + map->emplace("nn.avg_pool3d", std::make_shared()); + map->emplace("nn.conv3d_transpose", std::make_shared()); +#endif // TRT_VERSION_GE(6, 0, 1) + return map; +} + +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.h b/src/runtime/contrib/tensorrt/tensorrt_ops.h new file mode 100644 index 000000000000..e9871d42146c --- /dev/null +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.h @@ -0,0 +1,207 @@ +/* * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file runtime/contrib/tensorrt/tensorrt_ops.h + * \brief Converters from Relay ops into TensorRT layers. Converters should + * inherit from TensorRTOpConverter and implement the Convert() method. + */ + +#ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_OPS_H_ +#define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_OPS_H_ + +#include +#include +#include +#include +#include +#include + +#include "../json/json_node.h" +#include "NvInfer.h" +#include "tensorrt_utils.h" + +#if TRT_VERSION_GE(6, 0, 1) +#define TRT_HAS_IMPLICIT_BATCH(params) (params->network->hasImplicitBatchDimension()) +#else +#define TRT_HAS_IMPLICIT_BATCH(params) (true) +#endif + +namespace tvm { +namespace runtime { +namespace contrib { + +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; + +/*! + * \brief An input to a op may be either kTensor in the case of nvinfer::ITensor* + * or kWeight for nvinfer1::Weights. + */ +enum TensorRTInputType { + kTensor, + kWeight, +}; + +/*! + * \brief An input to a TensorRTOpConverter. The type of the input is either kTensor + * or kWeight. For kTensor, "tensor" contains the input tensor. For kWeight, + * "weight" contains the input weight and "weight_shape" contains the shape. + */ +struct TensorRTOpInput { + /*! \brief If type is kTensor, will store input tensor. */ + nvinfer1::ITensor* tensor; + + /*! \brief If type is kWeight, will store input weight. */ + nvinfer1::Weights weight; + + /*! \brief Whether the input is in tensor or weight. */ + TensorRTInputType type; + + /*! \brief If type is kWeight, will store weight shape. 
*/ + std::vector<int> weight_shape; + + explicit TensorRTOpInput(nvinfer1::ITensor* tensor) + : tensor(tensor), weight({nvinfer1::DataType::kFLOAT, nullptr, 0}), type(kTensor) {} + TensorRTOpInput(nvinfer1::Weights weight, const std::vector<int>& shape) + : tensor(nullptr), weight(weight), type(kWeight), weight_shape(shape) {} +}; + +/*! \brief Parameters to convert an Op from Relay to TensorRT. */ +struct TensorRTOpConverterParams { + /*! \brief The TRT network that the new layer should be added to. */ + nvinfer1::INetworkDefinition* network; + /*! \brief The corresponding serialized node. */ + const JSONGraphNode& node; + /*! \brief The type of op. */ + std::string op_name; + /*! \brief Inputs to the op. */ + std::vector<TensorRTOpInput> inputs; + /*! \brief Outputs of the op should be populated here during Convert(). */ + std::vector<nvinfer1::ITensor*> outputs; + /*! \brief Any newly allocated weights should be stored here also. */ + std::vector<nvinfer1::Weights>* trt_weights; + + TensorRTOpConverterParams(nvinfer1::INetworkDefinition* network, const JSONGraphNode& node, + std::vector<nvinfer1::Weights>* trt_weights) + : network(network), node(node), trt_weights(trt_weights) { + op_name = node.GetOpName(); + } +}; + +/*! \brief Base class for an op converter from Relay to TRT. */ +class TensorRTOpConverter { + public: + /*! \brief Used to specify whether each input is a tensor or a weight. */ + const std::vector<TensorRTInputType> input_types; + /*! \brief If set to true, any number of tensor inputs can be used for the op. + */ + const bool variable_input_count; + + /*! + * \brief Converter subclasses should call this constructor to set + * input_types or variable_input_count. + * \param input_types For each input to the op, there should be a + * corresponding entry in input_types to determine whether that input should + * be a tensor or a weight. TensorRTBuilder will prepare inputs in + * TensorRTOpConverter according to this. + * \param variable_input_count If the op can have multiple inputs, set this to + * true. The input_types vector will be ignored and any number of input tensors + * can be used for this op. All inputs will be tensors and not weights. + */ + explicit TensorRTOpConverter(const std::vector<TensorRTInputType>& input_types, + bool variable_input_count = false); + + /*! + * \brief Convert to TRT. Implementation should use inputs and attributes + * from the JSONGraphNode to add the corresponding TRT layers to network. Outputs + * should be pushed to the outputs vector. + * \param params Parameters for this op. + */ + virtual void Convert(TensorRTOpConverterParams* params) const = 0; + + /*! + * \brief Helper function to reshape a tensor. + * \param params Parameters for this op. + * \param input Tensor to reshape. + * \param new_shape New shape, does not include batch dim. + * \return Reshaped tensor. + */ + nvinfer1::ITensor* Reshape(TensorRTOpConverterParams* params, nvinfer1::ITensor* input, + const std::vector<int>& new_shape) const; + + /*! + * \brief Helper function to transpose a tensor. + * \param params Parameters for this op. + * \param input Tensor to transpose. + * \param order New order of axes, does include batch dim. + * \return Transposed tensor. + */ + nvinfer1::ITensor* Transpose(TensorRTOpConverterParams* params, nvinfer1::ITensor* input, + const std::vector<int>& order) const; + + /*! + * \brief Helper function to convert an axis to TRT format. + * \param axis Axis from TVM. + * \param input_rank Rank of input, does not include batch dim. + * \return Axis in TRT format. + */ + int ConvertAxis(TensorRTOpConverterParams* params, int axis, int input_rank) const; + + /*! 
+ * \brief Create constant that is broadcastable. + * \param params Parameters for this op. + * \param value Value of scalar. + * \param broadcast_to_dims Dims that scalar should be broadcastable against. + * \return Constant tensor. + */ + nvinfer1::ITensor* CreateScalar(TensorRTOpConverterParams* params, float value, + const nvinfer1::Dims& broadcast_to_dims) const; + + /*! + * \brief Get pre/post padding values from padding attributes array. + * \param padding Serialized padding from op attributes. + * \param padding_is_asymmetric True if both pre and post are needed for asymmetric padding. + * \param prepadding Prepadding value or symmetric padding values if !padding_is_asymmetric. + * \param postpadding Postpadding value if padding_is_asymmetric. + */ + void GetPadding(const std::vector& padding, bool* use_asymmetric_padding, + nvinfer1::DimsHW* prepadding, nvinfer1::DimsHW* postpadding) const; + + /*! + * \brief Get pre/post padding values from padding attributes array for volumetric ops. + * \param padding Serialized padding from op attributes. + * \param padding_is_asymmetric True if both pre and post are needed for asymmetric padding. + * \param prepadding Prepadding value or symmetric padding values if !padding_is_asymmetric. + * \param postpadding Postpadding value if padding_is_asymmetric. + */ + void GetPadding3D(const std::vector& padding, bool* use_asymmetric_padding, + nvinfer1::Dims* prepadding, nvinfer1::Dims* postpadding) const; +}; + +/*! + * \brief Get the map of available TensorRTOpConverters, where the key is the name of the relay op. + * \return Map of TensorRTOpConverters. + */ +const std::shared_ptr>> +GetOpConverters(); + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_OPS_H_ diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc new file mode 100644 index 000000000000..72c025695f7d --- /dev/null +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/tensorrt/tensorrt_runtime.cc + * \brief JSON runtime implementation for TensorRT. + */ + +#include +#include +#include + +#include + +#include "../../file_utils.h" +#include "../json/json_node.h" +#include "../json/json_runtime.h" + +#ifdef TVM_GRAPH_RUNTIME_TENSORRT +#include "NvInfer.h" +#include "tensorrt_builder.h" +#endif + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace tvm::runtime::json; + +class TensorRTRuntime : public JSONRuntimeBase { + public: + /*! + * \brief The TensorRT runtime module. Deserialize the provided functions + * on creation and store in the layer cache. 
+ * + * \param symbol_name The name of the function. + * \param graph_json The serialized JSON representation of a sub-graph. + * \param const_names The names of each constant in the sub-graph. + */ + explicit TensorRTRuntime(const std::string& symbol_name, const std::string& graph_json, + const Array<String>& const_names) + : JSONRuntimeBase(symbol_name, graph_json, const_names), + use_implicit_batch_(true), + max_workspace_size_(size_t(1) << 30) {} + + /*! + * \brief The type key of the module. + * + * \return module type key. + */ + const char* type_key() const override { return "tensorrt"; } + + /*! + * \brief Initialize runtime. Create TensorRT layers from the JSON + * representation. + * + * \param consts The constant params from compiled model. + */ + void Init(const Array<NDArray>& consts) override { + CHECK_EQ(consts.size(), const_idx_.size()) + << "The number of input constants must match the number required."; + LoadGlobalAttributes(); + if (GetCachedEnginesFromDisk()) return; + SetupConstants(consts); + BuildEngine(); + CacheEngineToDisk(); + } + + void LoadGlobalAttributes() { + // These settings are global to the entire subgraph. Codegen will add them as attributes to all + // op nodes, so read them from the first one found. + for (size_t i = 0; i < nodes_.size(); ++i) { + if (nodes_[i].HasAttr("use_implicit_batch") && nodes_[i].HasAttr("max_workspace_size")) { + use_implicit_batch_ = + std::stoi(nodes_[i].GetAttr<std::vector<std::string>>("use_implicit_batch")[0]); + // Allow max_workspace_size to be overridden at runtime. + size_t runtime_max_workspace_size = + dmlc::GetEnv("TVM_TENSORRT_MAX_WORKSPACE_SIZE", size_t(0)); + if (runtime_max_workspace_size != 0) { + max_workspace_size_ = runtime_max_workspace_size; + } else { + max_workspace_size_ = + std::stoul(nodes_[i].GetAttr<std::vector<std::string>>("max_workspace_size")[0]); + } + return; + } + } + } + +#ifdef TVM_GRAPH_RUNTIME_TENSORRT + /*! \brief Run inference using built engine. */ + void Run() override { + auto& engine_and_context = trt_engine_cache_.at(symbol_name_); + auto engine = engine_and_context.engine; + auto context = engine_and_context.context; + std::vector<void*> bindings(engine->getNbBindings(), nullptr); + + for (size_t i = 0; i < input_nodes_.size(); ++i) { + auto nid = input_nodes_[i]; + if (nodes_[nid].GetOpType() == "input") { + for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) { + uint32_t eid = EntryID(nid, j); + const std::string name = nodes_[nid].GetOpName() + "_" + std::to_string(j); + int binding_index = engine->getBindingIndex(name.c_str()); + CHECK_NE(binding_index, -1); + bindings[binding_index] = data_entry_[eid]->data; + } + } + } + + for (size_t i = 0; i < outputs_.size(); ++i) { + uint32_t eid = EntryID(outputs_[i]); + const std::string& name = engine_and_context.outputs[i]; + int binding_index = engine->getBindingIndex(name.c_str()); + CHECK_NE(binding_index, -1); + bindings[binding_index] = data_entry_[eid]->data; + } + +#if TRT_VERSION_GE(6, 0, 1) + if (use_implicit_batch_) { + CHECK(context->execute(batch_size_, bindings.data())) << "Running TensorRT failed."; + } else { + CHECK(context->executeV2(bindings.data())) << "Running TensorRT failed."; + } +#else + CHECK(context->execute(batch_size_, bindings.data())) << "Running TensorRT failed."; +#endif + } + + private: + /*! + * \brief Build TensorRT engine from JSON representation. 
+ */ + void BuildEngine() { + DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_; + const bool use_fp16 = dmlc::GetEnv("TVM_TENSORRT_USE_FP16", false); + batch_size_ = GetBatchSize(); + TensorRTBuilder builder(&logger_, max_workspace_size_, use_implicit_batch_, use_fp16, + batch_size_); + + // Add inputs and constants. + for (size_t i = 0; i < input_nodes_.size(); ++i) { + auto nid = input_nodes_[i]; + const auto& node = nodes_[nid]; + std::string name = node.GetOpName(); + if (node.GetOpType() == "input") { + builder.AddInput(nid, node); + } else { + CHECK_EQ(node.GetOpType(), "const"); + uint32_t eid = EntryID(nid, 0); + builder.AddConstant(nid, data_entry_[eid]); + } + } + + // Add layers. + for (size_t nid = 0; nid < nodes_.size(); ++nid) { + const auto& node = nodes_[nid]; + if (node.GetOpType() != "kernel") continue; + builder.AddLayer(nid, node); + } + + // Add outputs. + for (size_t i = 0; i < outputs_.size(); ++i) { + builder.AddOutput(outputs_[i]); + } + + // Build engine. + trt_engine_cache_[symbol_name_] = builder.BuildEngine(); + DLOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_; + } + + /*! \brief If TVM_TENSORRT_CACHE_DIR is set, will check that directory for + * already built TRT engines and load into trt_engine_cache_ so they don't + * have to be built at first inference. + */ + bool GetCachedEnginesFromDisk() { + std::string cache_dir = dmlc::GetEnv("TVM_TENSORRT_CACHE_DIR", std::string("")); + if (cache_dir.empty()) return false; + std::string key = GetSubgraphKey(); + std::string path = cache_dir + "/" + key + ".plan"; + // Check if engine is in the cache. + std::ifstream infile(path, std::ios::binary); + if (!infile.good()) return false; + DLOG(INFO) << "Loading cached TensorRT engine from " << path; + infile.close(); + std::string serialized_engine; + LoadBinaryFromFile(path, &serialized_engine); + // Deserialize engine + nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger_); + TensorRTEngineAndContext engine_and_context; + engine_and_context.engine = + runtime->deserializeCudaEngine(&serialized_engine[0], serialized_engine.size(), nullptr); + engine_and_context.context = engine_and_context.engine->createExecutionContext(); + // Load metadata + std::string meta_path = cache_dir + "/" + key + ".meta"; + std::string serialized_meta; + LoadBinaryFromFile(meta_path, &serialized_meta); + std::istringstream is(serialized_meta); + dmlc::JSONReader reader(&is); + dmlc::JSONObjectReadHelper helper; + helper.DeclareField("inputs", &engine_and_context.inputs); + helper.DeclareField("outputs", &engine_and_context.outputs); + helper.ReadAllFields(&reader); + trt_engine_cache_[symbol_name_] = engine_and_context; + return true; + } + + /*! \brief If TVM_TENSORRT_CACHE_DIR is set, will save the engine to that + * directory so it can be loaded later. 
+ */ + void CacheEngineToDisk() { + std::string cache_dir = dmlc::GetEnv("TVM_TENSORRT_CACHE_DIR", std::string("")); + if (cache_dir.empty()) return; + std::string key = GetSubgraphKey(); + std::string path = cache_dir + "/" + key + ".plan"; + DLOG(INFO) << "Caching TensorRT engine to " << path; + // Serialize engine to disk + nvinfer1::IHostMemory* serialized_engine = trt_engine_cache_[symbol_name_].engine->serialize(); + SaveBinaryToFile(path, std::string(static_cast<const char*>(serialized_engine->data()), + serialized_engine->size())); + serialized_engine->destroy(); + // Serialize metadata + std::ostringstream os; + dmlc::JSONWriter writer(&os); + writer.BeginObject(); + writer.WriteObjectKeyValue("inputs", trt_engine_cache_[symbol_name_].inputs); + writer.WriteObjectKeyValue("outputs", trt_engine_cache_[symbol_name_].outputs); + writer.EndObject(); + std::string meta_path = cache_dir + "/" + key + ".meta"; + SaveBinaryToFile(meta_path, os.str()); + } + + std::string GetSubgraphKey() { + // Using this key will only allow a single model per TVM_TENSORRT_CACHE_DIR directory. We could + // instead use a hash of graph_json and all weights to allow many models in the same directory, + // but the cost of computing the hash is high. + return symbol_name_ + (dmlc::GetEnv("TVM_TENSORRT_USE_FP16", false) ? "_fp16" : "_fp32"); + } + + /*! \brief Get the batch size when in implicit_batch mode. */ + int GetBatchSize() { + if (!use_implicit_batch_) return -1; + for (size_t i = 0; i < input_nodes_.size(); ++i) { + auto nid = input_nodes_[i]; + if (nodes_[nid].GetOpType() == "input") { + // Get batch size from first input. + return nodes_[nid].GetOpShape()[0][0]; + } + } + return -1; + } + + /*! \brief Map of function name to TRT engine if built already. */ + std::unordered_map<std::string, TensorRTEngineAndContext> trt_engine_cache_; + + /*! \brief TensorRT logger. */ + TensorRTLogger logger_; + + /*! \brief Batch size that the engine is optimized for. */ + int batch_size_; + +#else + void Run() override { + LOG(FATAL) << "TensorRT runtime is not enabled. " + << "Please build with USE_TENSORRT_RUNTIME."; + } + + void BuildEngine() { + LOG(WARNING) << "TensorRT runtime is not enabled. " + << "Please build with USE_TENSORRT_RUNTIME."; + } + + bool GetCachedEnginesFromDisk() { return false; } + + void CacheEngineToDisk() {} +#endif + + bool use_implicit_batch_; + + size_t max_workspace_size_; +}; + +runtime::Module TensorRTRuntimeCreate(const String& symbol_name, const String& graph_json, + const Array<String>& const_names) { + auto n = make_object<TensorRTRuntime>(symbol_name, graph_json, const_names); + return runtime::Module(n); +} + +TVM_REGISTER_GLOBAL("runtime.tensorrt_runtime_create").set_body_typed(TensorRTRuntimeCreate); + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_tensorrt") + .set_body_typed(JSONRuntimeBase::LoadFromBinary<TensorRTRuntime>); + +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/tensorrt/tensorrt_utils.h b/src/runtime/contrib/tensorrt/tensorrt_utils.h new file mode 100644 index 000000000000..ab9b169f26d6 --- /dev/null +++ b/src/runtime/contrib/tensorrt/tensorrt_utils.h @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file runtime/contrib/tensorrt/tensorrt_utils.h + * \brief Helper functions used by TensorRTBuilder or TensorRTOpConverters. + */ + +#ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_UTILS_H_ +#define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_UTILS_H_ + +#include <string> +#include <vector> + +#include "NvInfer.h" + +// There is a conflict between cpplint and clang-format-10. +// clang-format off +#define TRT_VERSION_GE(major, minor, patch) \ + ((NV_TENSORRT_MAJOR > major) || (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \ + (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && NV_TENSORRT_PATCH >= patch)) +// clang-format on + +namespace tvm { +namespace runtime { +namespace contrib { + +/*! + * \brief Helper function to convert a vector to TRT Dims. + * \param vec Vector. + * \return TRT Dims. + */ +template <typename T> +inline nvinfer1::Dims VectorToTrtDims(const std::vector<T>& vec) { + nvinfer1::Dims dims; + // Dims(nbDims=0, d[0]=1) is used to represent a scalar in TRT. + dims.d[0] = 1; + dims.nbDims = vec.size(); + for (size_t i = 0; i < vec.size(); ++i) { + dims.d[i] = vec[i]; + } + return dims; +} + +/*! + * \brief Helper function to convert TRT Dims to a vector. + * \param dims TRT Dims. + * \return Vector. + */ +inline std::vector<int> TrtDimsToVector(const nvinfer1::Dims& dims) { + return std::vector<int>(dims.d, dims.d + dims.nbDims); +} + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif  // TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_UTILS_H_ diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py new file mode 100644 index 000000000000..6f615397db58 --- /dev/null +++ b/tests/python/contrib/test_tensorrt.py @@ -0,0 +1,905 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import numpy as np +import time +import pytest + +import tvm +import tvm.relay.testing +from tvm import relay +from tvm.relay.op.contrib import tensorrt +from tvm.contrib import graph_runtime + + +def skip_codegen_test(): + """Skip test if TensorRT and CUDA codegen are not present""" + if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: + print("Skip because CUDA is not enabled.") + return True + if not tvm.get_global_func("relay.ext.tensorrt", True): + print("Skip because TensorRT codegen is not available.") + return True + return False + + +def skip_runtime_test(): + if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: + print("Skip because CUDA is not enabled.") + return True + if not tensorrt.is_tensorrt_runtime_enabled(): + print("Skip because TensorRT runtime is not available.") + return True + return False + + +def run_and_verify_func(config): + """Test a Relay func by compiling, running, and comparing TVM and TRT outputs. + + Parameters + ---------- + config : Tuple[relay.Function, Dict[str, NDArray], List[str]] + A tuple containing 1) The function to test, 2) A dictionary of var names to input shapes and + 3) A list of which vars should be considered params. + """ + if skip_codegen_test(): + return + f, input_shapes, is_param = config + params = {x: np.random.uniform(-1, 1, input_shapes[x]).astype(np.float32) for x in is_param} + input_dict = { + k: np.random.uniform(-1, 1, v).astype(np.float32) + for k, v in input_shapes.items() + if k not in is_param + } + + # Run TRT + mod = tvm.IRModule() + mod["main"] = f + mod, config = tensorrt.partition_for_tensorrt(mod, params) + with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): + graph, lib, graph_params = relay.build(mod, "cuda", params=params) + if skip_runtime_test(): + return + mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0)) + mod.set_input(**graph_params) + mod.run(**input_dict) + results = [mod.get_output(i) for i in range(mod.get_num_outputs())] + + # Run reference + mod = tvm.IRModule() + mod["main"] = f + with tvm.transform.PassContext(opt_level=3): + graph, lib, graph_params = relay.build(mod, "cuda", params=params) + mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0)) + mod.set_input(**graph_params) + mod.run(**input_dict) + ref_results = [mod.get_output(i) for i in range(mod.get_num_outputs())] + + assert len(results) == len(ref_results) + for i in range(len(results)): + res = results[i].asnumpy() + ref_res = ref_results[i].asnumpy() + assert res.shape == ref_res.shape + tvm.testing.assert_allclose(res, ref_res, rtol=1e-3, atol=1e-3) + + +def run_and_verify_model(model): + if skip_codegen_test(): + return + + def compile_and_run(i_data, input_shape, dtype, use_trt=True, num_iteration=1): + import mxnet as mx + from mxnet.gluon.model_zoo.vision import get_model + + def check_trt_used(graph): + import json + + graph = json.loads(graph) + num_trt_subgraphs = sum( + [ + 1 + for n in graph["nodes"] + if n.get("attrs", {}).get("func_name", "").startswith("tensorrt_") + ] + ) + assert num_trt_subgraphs >= 1 + + block = get_model(model, pretrained=True) + mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype) + + if use_trt: + mod, config = tensorrt.partition_for_tensorrt(mod, params) + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.tensorrt.options": config} + ): + graph, lib, params = relay.build(mod, "cuda", params=params) + check_trt_used(graph) + else: + with tvm.transform.PassContext(opt_level=3): + graph, 
lib, params = relay.build(mod, "cuda", params=params) + + if skip_runtime_test(): + return + mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0)) + mod.set_input(**params) + # Warmup + for i in range(10): + mod.run(data=i_data) + # Time + times = [] + for i in range(num_iteration): + start_time = time.time() + mod.run(data=i_data) + res = mod.get_output(0) + times.append(time.time() - start_time) + latency = 1000.0 * np.mean(times) + print(model, latency) + return res + + dtype = "float32" + input_shape = (1, 3, 224, 224) + i_data = np.random.uniform(-1, 1, input_shape).astype(dtype) + res = compile_and_run(i_data, input_shape, dtype, use_trt=True) + if skip_runtime_test(): + return + ref_res = compile_and_run(i_data, input_shape, dtype, use_trt=False) + tvm.testing.assert_allclose(res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-3) + + +def test_tensorrt_simple(): + if skip_codegen_test(): + return + dtype = "float32" + xshape = (1, 3, 2, 2) + yshape = (1, 3, 1, 1) + zshape = (1, 1, 1, 1) + x = relay.var("x", shape=(xshape), dtype=dtype) + y = relay.var("y", shape=(yshape), dtype=dtype) + z = relay.var("z", shape=(zshape), dtype=dtype) + w = z * (x + y) + out = relay.nn.relu(w) + f = relay.Function([x, y, z], out) + + mod = tvm.IRModule() + mod["main"] = f + mod, config = tensorrt.partition_for_tensorrt(mod) + with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): + graph, lib, params = relay.build(mod, "cuda") + if skip_runtime_test(): + return + mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0)) + x_data = np.random.uniform(-1, 1, xshape).astype(dtype) + y_data = np.random.uniform(-1, 1, yshape).astype(dtype) + z_data = np.random.uniform(-1, 1, zshape).astype(dtype) + mod.run(x=x_data, y=y_data, z=z_data) + results = [mod.get_output(i).asnumpy() for i in range(mod.get_num_outputs())] + + +def test_tensorrt_not_compatible(): + if skip_codegen_test(): + return + dtype = "float32" + xshape = (1, 32, 14, 14) + x = relay.var("x", shape=(xshape), dtype=dtype) + y = relay.add(x, x) + z = relay.erf(y) + out = relay.nn.relu(z) + f = relay.Function([x], out) + mod = tvm.IRModule() + mod["main"] = f + mod, config = tensorrt.partition_for_tensorrt(mod) + with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): + graph, lib, params = relay.build(mod, "cuda") + if skip_runtime_test(): + return + mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0)) + x_data = np.random.uniform(-1, 1, xshape).astype(dtype) + mod.run(x=x_data) + results = [mod.get_output(i).asnumpy() for i in range(mod.get_num_outputs())] + + +def test_tensorrt_serialize(): + if skip_codegen_test(): + return + import mxnet + from mxnet.gluon.model_zoo.vision import get_model + + block = get_model("resnet18_v1", pretrained=True) + mod, params = relay.frontend.from_mxnet( + block, shape={"data": (1, 3, 224, 224)}, dtype="float32" + ) + # Compile + mod, config = tensorrt.partition_for_tensorrt(mod, params) + with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): + lib = relay.build(mod, "cuda", params=params) + # Serialize + lib.export_library("compiled.so") + # Deserialize + loaded_lib = tvm.runtime.load_module("compiled.so") + # Run + if skip_runtime_test(): + return + gen_module = tvm.contrib.graph_runtime.GraphModule(loaded_lib["default"](tvm.gpu(0))) + i_data = np.random.uniform(0, 1, (1, 3, 224, 224)).astype("float32") + gen_module.run(data=i_data) + + +def test_conv2d(): + def get_graph( + x_shape=(1, 32, 8, 8), + 
k_shape=(16, 32, 3, 3), + groups=1, + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + ): + x = relay.var("x", shape=(x_shape), dtype="float32") + kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + out = relay.nn.conv2d( + x, + kernel, + channels=k_shape[0], + kernel_size=k_shape[2:4], + groups=groups, + padding=padding, + strides=strides, + dilation=dilation, + ) + f = relay.Function([x, kernel], out) + return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] + + for k_shape, groups in [((16, 32, 3, 3), 1), ((32, 1, 3, 3), 32)]: + for padding in [(0, 0), (1, 1)]: + for strides in [(1, 1), (2, 2)]: + for dilation in [(1, 1), (2, 2)]: + run_and_verify_func( + get_graph( + k_shape=k_shape, + groups=groups, + padding=padding, + strides=strides, + dilation=dilation, + ) + ) + + +def test_conv2d_nhwc(): + def get_graph(x_shape=(1, 8, 8, 32), k_shape=(3, 3, 32, 16)): + x = relay.var("x", shape=(x_shape), dtype="float32") + kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + out = relay.nn.conv2d( + x, + kernel, + channels=16, + kernel_size=(3, 3), + data_layout="NHWC", + kernel_layout="HWIO", + ) + f = relay.Function([x, kernel], out) + return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] + + run_and_verify_func(get_graph()) + + +def test_conv2d_weights_const(): + def get_graph( + x_shape=(1, 32, 8, 8), + k_shape=(16, 32, 3, 3), + groups=1, + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + ): + x = relay.var("x", shape=(x_shape), dtype="float32") + kernel = relay.const(np.ones(k_shape).astype("float32")) + out = relay.nn.conv2d( + x, + kernel, + channels=k_shape[0], + kernel_size=k_shape[2:4], + groups=groups, + padding=padding, + strides=strides, + dilation=dilation, + ) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph()) + + +def test_conv2d_weights_transposed(): + def get_graph(x_shape=(1, 32, 9, 9), k_shape=(3, 3, 32, 16), order=(3, 2, 0, 1)): + x = relay.var("x", shape=(x_shape), dtype="float32") + kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + kernel_t = relay.transpose(kernel, order) + # Conv2d requires constant weights in TensorRT, so the weights should be transposed by + # FoldConstant. + out = relay.nn.conv2d(x, kernel_t, channels=k_shape[order[0]], kernel_size=(3, 3)) + f = relay.Function([x, kernel], out) + return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] + + run_and_verify_func(get_graph()) + + +def test_dense(): + def get_graph(x_shape=(1, 16), k_shape=(32, 16)): + x = relay.var("x", shape=(x_shape), dtype="float32") + kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + # Dense requires constant weights in TensorRT, so the weights are transposed by us. 
+ out = relay.nn.dense(x, kernel, units=k_shape[0]) + f = relay.Function([x, kernel], out) + return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] + + run_and_verify_func(get_graph()) + + +def test_bias_add(): + def get_graph(x_shape=(1, 16), channels=16): + x = relay.var("x", shape=(x_shape), dtype="float32") + bias = relay.var("bias", shape=(channels,), dtype="float32") + out = relay.nn.bias_add(x, bias) + f = relay.Function([x, bias], out) + return f, {"x": x_shape, "bias": (channels,)}, ["bias"] + + run_and_verify_func(get_graph()) + run_and_verify_func(get_graph((1, 6, 3, 4), 6)) + + +def test_pool2d(): + def get_graph( + op, + x_shape=(1, 3, 32, 32), + pool_size=(2, 2), + strides=(2, 2), + padding=(0, 0), + ceil_mode=False, + count_include_pad=None, + ): + x = relay.var("x", shape=(x_shape), dtype="float32") + if count_include_pad is not None: + out = op( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + else: + out = op( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + ) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + for pool_size in [(2, 2), (3, 3)]: + for strides in [(1, 1), (2, 2)]: + for padding in [(0, 0), (1, 1), (0, 0, 1, 1)]: + for ceil_mode in [False, True]: + # Skip "the padding size is larger than or equal to the filter size for exclusive-counting pooling" + if pool_size == (2, 2) and padding == (0, 0, 1, 1): + continue + for count_include_pad in [False, True]: + # Skip "inclusive-counted blended or average pooling is not supported in combination with asymmetric padding" + if count_include_pad and (padding == (0, 0, 1, 1) or strides == (2, 2)): + continue + run_and_verify_func( + get_graph( + relay.nn.avg_pool2d, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + ) + run_and_verify_func( + get_graph( + relay.nn.max_pool2d, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + ) + ) + + +def test_global_pool2d(): + def get_graph(op, x_shape=(1, 3, 32, 32)): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = op(x) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph(relay.nn.global_max_pool2d)) + run_and_verify_func(get_graph(relay.nn.global_avg_pool2d)) + + +def test_batch_flatten(): + def get_graph(x_shape=(1, 3, 4, 6)): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = relay.nn.batch_flatten(x) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph()) + + +def test_expand_dims(): + def get_graph(x_shape=(1, 3), axis=1, num_newaxis=1): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = relay.expand_dims(x, axis, num_newaxis) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph()) + + +def test_squeeze(): + def get_graph(x_shape, axis): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = relay.squeeze(x, axis=axis) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph((1, 5, 1, 1), (2, 3))) + run_and_verify_func(get_graph((1, 3, 1), (-1,))) + + +def test_concatenate(): + def get_graph(input_shapes, axis): + concat_inputs = [] + shapes_dict = {} + for i in range(len(input_shapes)): + name = "input_{}".format(i) + concat_inputs.append(relay.var(name, shape=(input_shapes[i]), dtype="float32")) + 
shapes_dict[name] = input_shapes[i] + out = relay.concatenate(concat_inputs, axis) + f = relay.Function(concat_inputs, out) + return f, shapes_dict, [] + + run_and_verify_func(get_graph([(1, 2, 6, 6), (1, 3, 6, 6)], axis=1)) + + +def test_conv2d_transpose(): + def get_graph( + x_shape=(1, 32, 8, 8), + k_shape=(32, 16, 3, 3), + groups=1, + padding=(0, 0), + strides=(1, 1), + ): + x = relay.var("x", shape=(x_shape), dtype="float32") + kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + out = relay.nn.conv2d_transpose( + x, + kernel, + channels=k_shape[1], + kernel_size=k_shape[2:4], + groups=groups, + padding=padding, + strides=strides, + ) + f = relay.Function([x, kernel], out) + return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] + + for padding in [(0, 0), (1, 1)]: + for strides in [(1, 1), (2, 2)]: + run_and_verify_func(get_graph(padding=padding, strides=strides)) + + +def test_reshape(): + def get_graph(x_shape, new_shape): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = relay.reshape(x, new_shape) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph((1, 1, 1, 10), (-1, 10))) + run_and_verify_func(get_graph((1, 10, 2, 3), (1, -1))) + run_and_verify_func(get_graph((1, 1, 2, 3), (1, 6))) + + +def test_transpose(): + def get_graph(x_shape, order): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = relay.transpose(x, order) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph((1, 16, 7, 7), [0, 2, 3, 1])) + run_and_verify_func(get_graph((1, 7, 7, 16), [0, 3, 1, 2])) + + +def test_float_const(): + def get_graph(x_shape=(1, 16)): + x = relay.var("x", shape=(x_shape), dtype="float32") + beta = relay.const(1, dtype="float32") + out = relay.multiply(x, beta) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph()) + + +def test_pad(): + def get_graph(x_shape, pad_width): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = relay.nn.pad(x, pad_width=pad_width) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0]])) + run_and_verify_func(get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [1, 1], [1, 1]])) + run_and_verify_func(get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [0, 1], [2, 0]])) + run_and_verify_func(get_graph((1, 8, 3, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]])) + + +def test_softmax(): + def get_graph(x_shape, axis): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = relay.nn.softmax(x, axis=axis) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph((1, 1000), axis=1)) + run_and_verify_func(get_graph((1, 1000), axis=-1)) + run_and_verify_func(get_graph((1, 3, 4), axis=-2)) + run_and_verify_func(get_graph((1, 3, 4), axis=1)) + + +def test_batch_norm(): + def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5): + x = relay.var("x", shape=(x_shape), dtype="float32") + beta = relay.var("beta", shape=(param_shape), dtype="float32") + gamma = relay.var("gamma", shape=(param_shape), dtype="float32") + moving_mean = relay.var("moving_mean", shape=(param_shape), dtype="float32") + moving_var = relay.var("moving_var", shape=(param_shape), dtype="float32") + out, _, _ = relay.nn.batch_norm( + x, + gamma=gamma, + beta=beta, + moving_mean=moving_mean, + moving_var=moving_var, + axis=axis, + center=True, + scale=True, + epsilon=epsilon, + ) + f = relay.Function([x, gamma, beta, 
moving_mean, moving_var], out) + return ( + f, + { + "x": x_shape, + "beta": param_shape, + "gamma": param_shape, + "moving_mean": param_shape, + "moving_var": param_shape, + }, + ["beta", "gamma", "moving_mean", "moving_var"], + ) + + run_and_verify_func(get_graph((1, 64, 56, 56), (64,))) + run_and_verify_func(get_graph((1, 56, 56, 64), (64,), axis=3, epsilon=1.001e-05)) + + +def test_unary(): + def get_graph(op, x_shape=(1, 8, 3, 3)): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = op(x) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + for op in [ + relay.nn.relu, + relay.sigmoid, + relay.tanh, + relay.exp, + relay.log, + relay.sqrt, + relay.abs, + relay.negative, + relay.sin, + relay.cos, + relay.atan, + relay.ceil, + relay.floor, + ]: + run_and_verify_func(get_graph(op)) + + +def test_clip(): + def get_graph(x_shape=(1, 8, 3, 3)): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = relay.clip(x, a_min=-0.2, a_max=0.4) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph()) + + +def test_leaky_relu(): + def get_graph(x_shape=(1, 8, 3, 3)): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = relay.nn.leaky_relu(x, alpha=0.1) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph()) + + +def test_binary(): + def get_graph(op, x_shape, y_shape, y_is_const=False): + x = relay.var("x", shape=(x_shape), dtype="float32") + if y_is_const: + y = relay.const(np.ones(y_shape).astype("float32")) + out = op(x, y) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + y = relay.var("y", shape=(y_shape), dtype="float32") + out = op(x, y) + f = relay.Function([x, y], out) + return f, {"x": x_shape, "y": y_shape}, [] + + for op in [relay.add, relay.subtract, relay.multiply, relay.divide, relay.power]: + for y_is_const in [True, False]: + run_and_verify_func(get_graph(op, (1, 8, 3, 3), (1, 8, 3, 3), y_is_const)) + run_and_verify_func(get_graph(op, (1, 8, 1, 3), (1, 8, 3, 1), y_is_const)) + run_and_verify_func(get_graph(op, (1, 10), (10,), y_is_const)) + run_and_verify_func(get_graph(op, (1, 1, 1, 10), (10,), y_is_const)) + run_and_verify_func(get_graph(op, (1, 1, 1), (3,), y_is_const)) + + +def test_reduce(): + def get_graph(op, x_shape=(1, 2, 3, 4), axis=(2, 3), keepdims=False): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = op(x, axis=axis, keepdims=keepdims) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + for op in [relay.sum, relay.prod, relay.max, relay.min, relay.mean]: + for keepdims in [True, False]: + run_and_verify_func(get_graph(op, axis=(1), keepdims=keepdims)) + run_and_verify_func(get_graph(op, axis=(2, 3), keepdims=keepdims)) + run_and_verify_func(get_graph(op, axis=(1, 2), keepdims=keepdims)) + run_and_verify_func(get_graph(op, axis=(1, 2, 3), keepdims=keepdims)) + + +def test_strided_slice(): + def get_graph(x_shape, begin, end, strides=None): + x = relay.var("x", shape=(x_shape), dtype="float32") + if strides: + out = relay.strided_slice( + x, + relay.expr.const(begin, dtype="int32"), + relay.expr.const(end, dtype="int32"), + relay.expr.const(strides, dtype="int32"), + ) + else: + out = relay.strided_slice( + x, + relay.expr.const(begin, dtype="int32"), + relay.expr.const(end, dtype="int32"), + ) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph((1, 3, 6, 7), [0, 0, 0, 0], [1, 1, 6, 7])) + run_and_verify_func(get_graph((1, 3, 6, 7), [0, 1, 0, 0], [1, 2, 6, 6])) + 
run_and_verify_func(get_graph((1, 10), [0, 0], [1, 10], [1, 2])) + + +def test_adaptive_pool2d(): + def get_graph(op, x_shape=(1, 3, 32, 32), out_size=(1, 1)): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = op(x, out_size) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph(relay.nn.adaptive_max_pool2d)) + run_and_verify_func(get_graph(relay.nn.adaptive_avg_pool2d)) + + +def test_multiple_outputs(): + def get_graph(): + x = relay.var("x", shape=(1, 3), dtype="float32") + y = relay.var("y", shape=(1, 3), dtype="float32") + z = relay.add(x, y) + w = relay.add(z, y) + out = relay.Tuple((z, w)) + f = relay.Function([x, y], out) + return f, {"x": (1, 3), "y": (1, 3)}, [] + + run_and_verify_func(get_graph()) + + +def test_conv3d(): + def get_graph( + x_shape=(1, 32, 8, 8, 8), + k_shape=(16, 32, 3, 3, 3), + groups=1, + padding=(0, 0, 0), + strides=(1, 1, 1), + dilation=(1, 1, 1), + ): + x = relay.var("x", shape=(x_shape), dtype="float32") + kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + out = relay.nn.conv3d( + x, + kernel, + channels=k_shape[0], + kernel_size=k_shape[2:], + groups=groups, + padding=padding, + strides=strides, + dilation=dilation, + ) + f = relay.Function([x, kernel], out) + return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] + + run_and_verify_func(get_graph()) + run_and_verify_func(get_graph(padding=(0, 0, 0, 1, 1, 1))) + + +def test_pool3d(): + def get_graph( + op, + x_shape=(1, 3, 8, 32, 32), + pool_size=(2, 2, 2), + strides=(2, 2, 2), + padding=(0, 0, 0), + ceil_mode=False, + count_include_pad=None, + ): + x = relay.var("x", shape=(x_shape), dtype="float32") + if count_include_pad is not None: + out = op( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + else: + out = op( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + ) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph(relay.nn.avg_pool3d)) + run_and_verify_func(get_graph(relay.nn.max_pool3d)) + run_and_verify_func(get_graph(relay.nn.max_pool3d, padding=(0, 0, 0, 1, 1, 1))) + run_and_verify_func(get_graph(relay.nn.max_pool3d, strides=(1, 1, 1))) + + +def test_conv3d_transpose(): + def get_graph( + x_shape=(1, 32, 8, 8, 8), + k_shape=(32, 16, 3, 3, 3), + groups=1, + padding=(0, 0, 0), + strides=(1, 1, 1), + output_padding=(0, 0, 0), + ): + x = relay.var("x", shape=(x_shape), dtype="float32") + kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + out = relay.nn.conv3d_transpose( + x, + kernel, + channels=k_shape[1], + kernel_size=k_shape[2:5], + groups=groups, + padding=padding, + strides=strides, + output_padding=output_padding, + ) + f = relay.Function([x, kernel], out) + return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] + + run_and_verify_func(get_graph()) + run_and_verify_func(get_graph(strides=(2, 2, 2))) + run_and_verify_func(get_graph(strides=(2, 2, 2), output_padding=(1, 1, 1))) + + +def test_alexnet(): + run_and_verify_model("alexnet") + + +def test_resnet18_v1(): + run_and_verify_model("resnet18_v1") + + +def test_resnet18_v2(): + run_and_verify_model("resnet18_v2") + + +def test_squeezenet(): + run_and_verify_model("squeezenet1.0") + + +def test_mobilenet(): + run_and_verify_model("mobilenet0.25") + + +def test_mobilenet_v2(): + run_and_verify_model("mobilenetv2_0.25") + + +def test_vgg11(): + run_and_verify_model("vgg11") + + +def 
test_densenet121(): + run_and_verify_model("densenet121") + + +if __name__ == "__main__": + test_tensorrt_not_compatible() + test_tensorrt_simple() + test_tensorrt_serialize() + + # Op tests + test_conv2d() + test_conv2d_nhwc() + test_conv2d_weights_const() + test_conv2d_weights_transposed() + test_dense() + test_bias_add() + test_pool2d() + test_global_pool2d() + test_batch_flatten() + test_expand_dims() + test_squeeze() + test_concatenate() + test_conv2d_transpose() + test_reshape() + test_transpose() + test_float_const() + test_pad() + test_softmax() + test_batch_norm() + test_unary() + test_clip() + test_leaky_relu() + test_binary() + test_reduce() + test_strided_slice() + test_adaptive_pool2d() + test_multiple_outputs() + test_conv3d() + test_pool3d() + test_conv3d_transpose() + + # Integration tests + test_alexnet() + test_resnet18_v1() + test_resnet18_v2() + test_squeezenet() + test_mobilenet() + test_mobilenet_v2() + test_vgg11() + test_densenet121() diff --git a/tests/scripts/task_config_build_gpu.sh b/tests/scripts/task_config_build_gpu.sh index dedb561c56c6..0072fb59cf11 100755 --- a/tests/scripts/task_config_build_gpu.sh +++ b/tests/scripts/task_config_build_gpu.sh @@ -45,3 +45,4 @@ echo set\(USE_VTA_FSIM ON\) >> config.cmake echo set\(USE_BLAS openblas\) >> config.cmake echo set\(CMAKE_CXX_COMPILER g++\) >> config.cmake echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake +echo set\(USE_TENSORRT_CODEGEN ON\) >> config.cmake